Statistically Optimal Scaling for Token Merging in Transformers
Qing Zhou ⋅ Hongyuan Zhang ⋅ Tao Yang ⋅ Junyu Gao ⋅ Qi Wang
Abstract
Token merging accelerates Transformer inference by clustering similar tokens to reduce the sequence length to a retention ratio $r$ of the original, but it distorts attention outputs, inducing covariate shift in the residual streams and causing performance to collapse under high compression. Existing heuristics, such as proportional attention, are effective under mild compression but degrade sharply at aggressive ratios because they leave energy drift and biased attention distributions unaddressed. We reframe token merging as a statistical reconstruction problem in high dimensions and introduce an asymptotic radial-angular decomposition of the reconstruction error, an analytical framework that decouples magnitude and distributional distortions. Minimizing this decomposed risk under minimal assumptions of finite second moments and variance stationarity yields closed-form optimal corrections governed by a single scaling factor $\sqrt{r}$: scaling merged values and shrinking merged logits toward the cluster-size prior. Together these corrections calibrate both energy balance and distributional fidelity. Extensive experiments on vision Transformers demonstrate superior accuracy and robustness across compression levels.
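As a minimal illustrative reading of these corrections (notation and the exact shrinkage form are our own assumptions, not taken from the paper's body): let a merged token $j$ represent a cluster of $s_j$ original tokens with key $k_j$ and value $v_j$, let $q$ be a query, $d$ the head dimension, and $r$ the retention ratio. One plausible instantiation of the $\sqrt{r}$-governed corrections is
\[
\tilde{\ell}_j \;=\; \sqrt{r}\,\frac{q^{\top} k_j}{\sqrt{d}} \;+\; \bigl(1-\sqrt{r}\bigr)\log s_j,
\qquad
\tilde{v}_j \;=\; \sqrt{r}\, v_j,
\]
i.e., the merged logit is shrunk toward the cluster-size prior $\log s_j$ (the term proportional attention adds in full) while the merged value is rescaled by the same factor $\sqrt{r}$, so that a single scalar controls both the distributional and the energy correction.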