Learning When to Attend: Conditional Memory Access for Long-Context LLMs
Sakshi Choudhary ⋅ Aditya Chattopadhyay ⋅ Luca Zancato ⋅ Elvis Nunez ⋅ Matthew Trager ⋅ Wei Xia ⋅ Stefano Soatto
Abstract
Language models struggle to generalize beyond the context lengths seen during pretraining, limiting performance on long-horizon reasoning and retrieval. Continued pretraining on long-context data can mitigate this limitation, but it is prohibitively expensive due to the quadratic scaling of Attention with sequence length. In practice, most tokens do not require Global Attention over the entire sequence and can rely on local context alone. Based on this insight, we propose L2A, a sequence modeling layer that enables token-wise long-term conditional memory access by deciding \textit{when} to invoke Global Attention. We evaluate L2A on Qwen 2.5 and Qwen 3 models, extending their effective context length from 32K to 128K tokens, matching standard long-context training to within 1.5–3\% while skipping Global Attention for $\sim$80\% of tokens, and outperforming prior baselines. We also design custom Triton kernels to efficiently realize this token-wise conditional attention on GPUs, achieving up to $\sim$2× improvements in training throughput and time-to-first-token over FlashAttention-2. Moreover, L2A enables post-training pruning of highly sparse Global Attention layers, reducing KV cache memory by up to 50\% with negligible performance loss.
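As a rough illustration of token-wise conditional attention as described in the abstract (not the paper's actual L2A layer or its Triton kernels), the following PyTorch sketch gates each token between a local sliding-window attention path and a full causal Global Attention path. The gating module, threshold, and window size are assumptions introduced here for illustration only.

```python
# Minimal sketch of token-wise conditional attention routing.
# All names (gate, threshold, local_window) are illustrative assumptions,
# not the L2A implementation described in the paper.
import torch
import torch.nn as nn


class ConditionalAttentionSketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int, local_window: int = 512):
        super().__init__()
        self.local_window = local_window
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Per-token gate scoring whether a token needs Global Attention.
        self.gate = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        _, T, _ = x.shape
        idx = torch.arange(T, device=x.device)
        # Local path: causal attention restricted to a sliding window
        # (True entries in the mask are positions a token may NOT attend to).
        local_mask = (idx[None, :] > idx[:, None]) | (
            (idx[:, None] - idx[None, :]) >= self.local_window
        )
        local_out, _ = self.attn(x, x, x, attn_mask=local_mask)
        # Global path: full causal attention over the entire prefix.
        causal_mask = idx[None, :] > idx[:, None]
        global_out, _ = self.attn(x, x, x, attn_mask=causal_mask)
        # Hard per-token decision: take the global output only for tokens
        # whose gate score exceeds the threshold, the local output otherwise.
        g = (torch.sigmoid(self.gate(x)) > threshold).float()
        return g * global_out + (1.0 - g) * local_out
```

Note that this sketch computes both paths and merely selects between them, so it saves no compute; a real implementation would skip the global path entirely for gated-off tokens (which is what the paper's custom kernels are for), and training the hard gate would require a soft relaxation or a straight-through estimator.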