Invited Talk
in
Workshop on Theoretical Foundations of Foundation Models (TF2M)
Ananda Theertha Suresh (Google Research): Accelerating language model inference using optimal transport: Theory and algorithms
Autoregressive sampling from large language models has led to state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time, making it slow and even prohibitively expensive for certain tasks. One way to speed up sampling is speculative decoding (Leviathan et al., 2022): use a small model to sample a draft (a block or sequence of tokens), and then score all tokens in the draft with the large language model in parallel. A subset of the tokens in the draft is accepted (and the rest rejected) based on a statistical test that guarantees the final output follows the distribution of the large model.
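The token-level accept/reject step described above can be sketched as follows. This is a minimal illustration with a hypothetical 4-token vocabulary and made-up probabilities, not the talk's implementation: a draft token sampled from the small model's distribution q is accepted with probability min(1, p(x)/q(x)); on rejection, a replacement token is drawn from the normalized residual max(p - q, 0), so the output is distributed exactly according to the large model's distribution p.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical next-token distributions over a 4-token vocabulary.
p = np.array([0.5, 0.2, 0.2, 0.1])  # large (target) model
q = np.array([0.3, 0.4, 0.2, 0.1])  # small (draft) model

def speculative_token(p, q, rng):
    """Return one token distributed exactly according to p,
    using a draft token sampled from q."""
    x = rng.choice(len(q), p=q)               # draft token from the small model
    if rng.random() < min(1.0, p[x] / q[x]):  # accept with prob min(1, p/q)
        return x
    residual = np.maximum(p - q, 0.0)         # on rejection, resample from
    residual /= residual.sum()                # the normalized residual
    return rng.choice(len(p), p=residual)

# Empirically, the output frequencies match p regardless of q.
samples = np.array([speculative_token(p, q, rng) for _ in range(20000)])
freq = np.bincount(samples, minlength=4) / len(samples)
```

The point of the construction is that correctness never depends on how good the draft model is; q only affects how often the cheap accept path is taken.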
In this talk, we provide a principled understanding of speculative decoding through the lens of distribution coupling and optimal transport theory. This formulation enables us to improve upon speculative decoding in three ways. First, we propose an optimal draft-acceptance algorithm that provides additional wall-clock speedup without incurring additional computational cost. Next, we ask whether latency can be improved further with extra parallel computation. We answer this question affirmatively by showing that if we have multiple drafts from the small model, we can use them to improve the speedup further, albeit at the cost of extra parallel computation. Finally, we demonstrate that the speedup can be improved further if we slightly relax the requirement that the final output exactly follow the distribution of the large model. We provide theoretical guarantees for the proposed algorithms and demonstrate their practicality on standard datasets.
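One concrete fact behind the coupling view, stated here as a small numerical check rather than the talk's result: the single-draft acceptance probability equals the overlap between the two next-token distributions, sum_x min(p(x), q(x)) = 1 - TV(p, q), which is exactly the success probability of a maximal coupling of p and q. The distributions below are hypothetical.

```python
import numpy as np

p = np.array([0.5, 0.2, 0.2, 0.1])  # target model (hypothetical numbers)
q = np.array([0.3, 0.4, 0.2, 0.1])  # draft model (hypothetical numbers)

# Probability that the standard speculative-decoding test accepts a draft token.
accept_prob = np.minimum(p, q).sum()

# Total variation distance between the two distributions.
tv = 0.5 * np.abs(p - q).sum()

# Maximal-coupling identity: P(accept) = 1 - TV(p, q).
print(accept_prob, 1.0 - tv)  # → 0.8 0.8
```

Under this identity, "optimal acceptance" is an optimal transport question: among all couplings of the draft and target distributions, maximize the probability that the two coordinates agree.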