Training–Inference Consistent Segmented Execution for Long-Context LLMs
Xianpeng Shang ⋅ Jiang Li ⋅ Zehua Duo ⋅ Qianyi Cai ⋅ Xiangdong Su
Abstract
Transformer-based large language models face severe scalability challenges in long-context generation due to the computational and memory costs of full-context attention. Under practical compute and memory budgets, many inference-efficient long-context methods adopt bounded-context or segment-level execution only at inference time, while the model is still trained under full-context attention; this creates a mismatch between training and inference in both execution and state-transition semantics. Motivated by this mismatch, we propose a training-consistent segment-level generation framework in which training and inference follow the same segment-level forward execution semantics. During training, consistency with inference is enforced by restricting gradient propagation to the KV states carried over from the immediately preceding segment, while allowing head-specific access to earlier KV states in the forward pass without involving them in gradient propagation. Across long-context benchmarks, our approach matches full-context attention in quality, achieves competitive latency--memory trade-offs against strong inference-efficient baselines, and substantially improves scalability at very long context lengths (e.g., approximately $6\times$ lower peak prefill memory at 128K than full-context attention with FlashAttention).
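The bounded gradient horizon described above can be made concrete with a short sketch. The snippet below is a minimal PyTorch illustration, not the authors' implementation: all function names are ours, and the head-specific access policy and intra-segment causal masking are elided. It shows the key idea that the forward pass attends over all cached KV segments, but only the immediately preceding segment's KV retains a gradient path; older segments are detached.

```python
import torch
import torch.nn.functional as F

def build_kv_context(past_kv, cur_kv):
    """Concatenate cached KV states for segment-level attention.

    Only the immediately preceding segment keeps its gradient; all older
    segments are detached, so backpropagation cannot reach them even though
    they remain visible in the forward pass. `past_kv` is a list of (K, V)
    pairs of shape (B, H, L, E), oldest first; `cur_kv` is the current
    segment's (K, V). (Illustrative helper, not part of the paper's API.)
    """
    ks, vs = [], []
    for i, (k, v) in enumerate(past_kv):
        if i == len(past_kv) - 1:
            # Immediately preceding segment: gradients flow through its KV.
            ks.append(k)
            vs.append(v)
        else:
            # Older segments: forward-only, no gradient propagation.
            ks.append(k.detach())
            vs.append(v.detach())
    k_cur, v_cur = cur_kv
    ks.append(k_cur)
    vs.append(v_cur)
    # Concatenate along the sequence dimension.
    return torch.cat(ks, dim=-2), torch.cat(vs, dim=-2)

def segment_attention(q, past_kv, cur_kv):
    # The forward pass sees the full cached context; the gradient horizon is
    # bounded by build_kv_context. Intra-segment causal masking is omitted
    # here for brevity.
    k, v = build_kv_context(past_kv, cur_kv)
    return F.scaled_dot_product_attention(q, k, v)
```

Under this rule, training cost per segment stays bounded regardless of how many segments precede it, matching the segment-level state-transition semantics the model sees at inference time.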