Retrieval-Aware Distillation for Transformer-SSM Hybrids
Abstract
State-space models (SSMs) offer efficient sequence modeling but exhibit a large performance gap relative to Transformers on benchmarks that require in-context retrieval. This gap has been attributed to a small set of attention heads, called Gather-and-Aggregate (G&A) heads, which SSMs struggle to implement and which are believed to drive the disparity. Leveraging this insight, we propose retrieval-aware distillation, a strategy that converts a pretrained Transformer into a hybrid student by preserving only these retrieval-critical components. We identify the essential attention heads via ablation on a synthetic retrieval task and distill the rest into recurrent heads, yielding a model whose non-uniform attention placement is tailored to retrieval demands. We empirically show that preserving just 2% of attention heads (10 retrieval-critical heads in a 1B model) enables the hybrid to recover teacher-level performance, reducing memory overhead by up to 6x compared with prior distillation methods that retain 30–50% of heads. Furthermore, we show that the large recurrent states in SSMs often serve to compensate for missing retrieval ability: once retrieval is handled by these few preserved heads, the SSM backbone can be substantially simplified while maintaining performance, even under an 8x reduction in state dimension. Overall, the results show that strategically concentrating attention can close the Transformer–SSM gap at a fraction of the memory cost.
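As a rough illustration of the head-identification step described above, the sketch below scores each attention head on a toy synthetic key-value retrieval task by the accuracy drop incurred when that head is zeroed out, then keeps only the top-scoring ~2% of heads. The toy model and all names (`GatedAttention`, `ToyLM`, `retrieval_batch`) are assumptions for illustration; the paper's actual task, teacher model, and distillation procedure are not reproduced here.

```python
# Minimal sketch, assuming a toy untrained model and a simplified synthetic
# key-value retrieval task; not the paper's actual code.
import torch
import torch.nn as nn

VOCAB, DIM, HEADS, LAYERS, PAIRS = 64, 128, 4, 2, 8

class GatedAttention(nn.Module):
    """Self-attention whose heads can be individually zeroed out (ablated)."""
    def __init__(self, dim, heads):
        super().__init__()
        self.heads, self.dh = heads, dim // heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.gate = torch.ones(heads)  # 1 = keep head, 0 = ablate head

    def forward(self, x):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (z.view(b, t, self.heads, self.dh).transpose(1, 2)
                   for z in (q, k, v))
        att = ((q @ k.transpose(-2, -1)) / self.dh**0.5).softmax(-1) @ v
        att = att * self.gate.view(1, -1, 1, 1)  # zero the ablated heads
        return self.out(att.transpose(1, 2).reshape(b, t, d))

class ToyLM(nn.Module):
    """Stand-in for the pretrained Transformer teacher."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, DIM)
        self.blocks = nn.ModuleList(GatedAttention(DIM, HEADS) for _ in range(LAYERS))
        self.lm_head = nn.Linear(DIM, VOCAB)

    def forward(self, ids):
        x = self.emb(ids)
        for blk in self.blocks:
            x = x + blk(x)
        return self.lm_head(x)

def retrieval_batch(n=256):
    """Key-value pairs followed by a query key; the answer is the paired value."""
    keys = torch.randint(0, VOCAB // 2, (n, PAIRS))
    vals = torch.randint(VOCAB // 2, VOCAB, (n, PAIRS))
    seq = torch.stack([keys, vals], dim=-1).flatten(1)  # k1 v1 k2 v2 ...
    idx = torch.randint(0, PAIRS, (n, 1))
    query, target = keys.gather(1, idx), vals.gather(1, idx).squeeze(1)
    return torch.cat([seq, query], dim=1), target

@torch.no_grad()
def accuracy(model, ids, target):
    return (model(ids)[:, -1].argmax(-1) == target).float().mean().item()

model = ToyLM()                      # in practice: the pretrained teacher
ids, target = retrieval_batch()
base = accuracy(model, ids, target)

# Score every head by how much retrieval accuracy drops when it is ablated.
scores = {}
for li, blk in enumerate(model.blocks):
    for h in range(HEADS):
        blk.gate[h] = 0.0
        scores[(li, h)] = base - accuracy(model, ids, target)
        blk.gate[h] = 1.0

# Keep the top ~2% of heads as attention; the rest are candidates for
# distillation into recurrent (SSM) heads.
k = max(1, round(0.02 * LAYERS * HEADS))
keep = sorted(scores, key=scores.get, reverse=True)[:k]
print("retrieval-critical heads (layer, head):", keep)
```

With a real pretrained teacher, the heads with the largest accuracy drops would correspond to the retrieval-critical (G&A) heads that the hybrid student preserves as attention, while all remaining heads are distilled into recurrent heads.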