Prototype Transformer: Towards Language Model Architectures Interpretable by Design
Abstract
While state-of-the-art language models (LMs) surpass the vast majority of humans in certain domains, their reasoning remains largely opaque, reducing trust and risking deception and hallucination. In this work, we introduce the Prototype Transformer (ProtoT), an autoregressive LM architecture that replaces the quadratic-cost self-attention of the transformer with a linear-cost module based on prototypes (learned parameter vectors). In ProtoT, the prototypes create communication channels that aggregate contextual information at different time scales. We show that during training the prototypes automatically come to capture nameable concepts (e.g. “woman”), which makes it possible to interpret the model’s reasoning and to make targeted edits to its behavior. Compared to baselines, ProtoT scales well with model and data size, is robust to input perturbations, and performs well on text generation and downstream tasks (GLUE). Approaching the performance of state-of-the-art architectures, ProtoT paves the way toward well-performing autoregressive LMs that are interpretable by design.
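To make the abstract concrete, the following is a minimal sketch of how a prototype-based, linear-cost replacement for causal self-attention *could* look. The mechanism shown here (tokens write into per-prototype exponential moving averages, whose decay rates set different time scales, and read the result back) is an illustrative assumption, not the paper's actual ProtoT module; the function name `prototype_mixing` and all details are hypothetical. Its cost is O(T·K·d) for T tokens and K prototypes, linear in sequence length rather than quadratic.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def prototype_mixing(X, P, decay):
    """Causal prototype-based token mixing (illustrative sketch).

    X:     (T, d) token representations
    P:     (K, d) learned prototype vectors
    decay: (K,)   per-prototype EMA rates in (0, 1); each rate sets
                  the time scale over which that prototype aggregates
    Returns (T, d) mixed representations; cost is linear in T.
    """
    T, d = X.shape
    K = P.shape[0]
    write = softmax(X @ P.T, axis=-1)   # (T, K): how much each token writes to each prototype
    S = np.zeros((K, d))                # prototype memories (communication channels)
    out = np.zeros_like(X)
    for t in range(T):
        # Each prototype keeps an exponential moving average of the
        # tokens written to it; only past tokens are visible (causal).
        S = decay[:, None] * S + write[t][:, None] * X[t]
        read = softmax(X[t] @ P.T)      # (K,): how much the token reads from each prototype
        out[t] = read @ S
    return out
```

Because every token interacts only with the K prototype channels (never with other tokens directly), inspecting or editing a single prototype, e.g. one that has come to encode a concept such as "woman", is a natural hook for the interpretability and targeted behavior edits the abstract describes.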