Poster at ES-FoMo II: 2nd Workshop on Efficient Systems for Foundation Models
Just read twice: closing the recall gap for recurrent language models
Simran Arora · Aman Timalsina · Aaryan Singhal · Sabri Eyuboglu · Xinyi Zhao · Ashish Rao · Atri Rudra · Christopher Re
Abstract:
Recurrent large language models that compete with Transformers in language modeling perplexity are emerging at a rapid rate (e.g., Mamba, RWKV). Excitingly, these architectures use a constant amount of memory during inference. However, due to the limited memory, recurrent LMs cannot recall all the information in long contexts, leading to brittle in-context learning (ICL) quality. It is not clear how to improve recall quality while preserving efficiency. We observe that the order in which information is shown to the LM impacts the selection difficulty. To formalize this, we show that the hardness of information recall reduces to the hardness of set disjointness (SD), a decades-old and canonical problem in communication complexity theory that requires a streaming algorithm (e.g., a recurrent model) to decide whether the input sets are disjoint. We use this connection to show, both empirically and theoretically, that the recurrent memory required to solve SD changes with set order. Our analysis suggests that, to mitigate the reliance on data order, we can either put information in the right order in-context or process prompts non-causally. Towards that end, we first propose (1) JRT-Prompt, where information is repeated multiple times in the prompt, showing the model all data orders. This gives $11.1 \pm 1.2$ points of improvement, averaged across the ICL tasks, with $11.9\times$ higher throughput than FlashAttention-2 (FA2) at length $32\mathrm{k}$ and batch size $16$. We then propose (2) JRT-RNN, which uses non-causal cross-linear-attention to process prompts and provides $13.7$ (at $360\mathrm{M}$ params., $30\mathrm{B}$ tokens) and $6.9$ (at $1.3\mathrm{B}$ params., $50\mathrm{B}$ tokens) point average improvements over the decoder baseline, with $19.2\times$ higher throughput than FA2. Code is available at: https://github.com/HazyResearch/prefix-linear-attention
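To make the JRT-Prompt idea concrete, the sketch below builds a prompt in which the in-context information is repeated before the question is asked, so a recurrent model streams over every data order. This is a minimal illustration assuming a simple text template; the exact prompt format, separators, and repetition count used in the paper's experiments may differ.

```python
# Minimal sketch of the JRT-Prompt strategy described in the abstract:
# repeat the context multiple times so the model "reads it twice" before
# answering. The template here (separators, "Question:"/"Answer:" markers,
# and n_repeats default) is an assumption for illustration only.

def jrt_prompt(context: str, question: str, n_repeats: int = 2) -> str:
    """Build a prompt that shows the context n_repeats times, then the question."""
    repeated = "\n\n".join(context for _ in range(n_repeats))
    return f"{repeated}\n\nQuestion: {question}\nAnswer:"

if __name__ == "__main__":
    docs = "Key: 41ab -> apple\nKey: 9cd2 -> pear"
    print(jrt_prompt(docs, "What value is stored under key 9cd2?"))
```

Because the context appears both before and after every other piece of context, the relevant fact is always seen after the model has encountered the query-relevant ordering, which is the intuition the set-disjointness analysis formalizes.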