SpecTr: Fast Speculative Decoding via Optimal Transport
Ziteng Sun · Ananda Suresh · Jae Ro · Ahmad Beirami · Himanshu Jain · Felix Xinnan Yu · Michael Riley · Sanjiv Kumar
Event URL: https://openreview.net/forum?id=d0mGsaheuT
Autoregressive sampling from large language models has been shown to achieve state-of-the-art results in several natural language tasks. However, autoregressive sampling generates tokens one at a time, making it slow, and even prohibitive in certain tasks. One way to speed up decoding is *speculative decoding*: use a smaller model to sample a *draft* (a block or sequence of tokens), and then score all tokens in the draft in parallel with the desired large language model. The tokens in the draft are either accepted or rejected based on a statistical method that guarantees the final output is a valid sample from the large model. In this work, we provide a principled understanding of speculative decoding through the lens of optimal transport (OT) with *membership cost*. This framework can be viewed as an extension of the well-known *maximal-coupling* problem. The new formulation enables us to generalize the sampling method to allow for a set of $k$ candidates at the token level, leading to an improved optimal membership cost. The optimal solution can be computed via linear programming, whose best-known runtime is exponential in $k$. We then propose an approximate solution whose acceptance probability is $(1-1/e)$-optimal multiplicatively. Moreover, it can be computed in time almost linear in the size of the token vocabulary. Using this new OT algorithm, we develop a new autoregressive sampling algorithm called *SpecTr*, which creates multiple drafts of the next few tokens with the small language model and scores all of them in parallel with the large language model. We then accept one of the drafts or reject all of them based on their respective scores. We experimentally demonstrate that the proposed approach achieves a speedup of 3X, a further 1.36X speedup over speculative decoding, on standard benchmarks.
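The accept/reject step the abstract refers to can be illustrated with the standard single-draft (k = 1) token-level rule that SpecTr generalizes: accept the drafted token with probability min(1, p(x)/q(x)), and on rejection resample from the normalized residual max(p - q, 0), which keeps the output an exact sample from the large model. This is a minimal sketch of that baseline rule only, not the paper's k-candidate OT algorithm; the function name and NumPy setup are illustrative.

```python
import numpy as np

def speculative_accept_step(p, q, draft_token, rng):
    """One token-level accept/reject step of (single-draft) speculative decoding.

    p: large (target) model distribution over the vocabulary, sums to 1
    q: small (draft) model distribution over the vocabulary, sums to 1
    draft_token: token index previously sampled from q
    Returns (token, accepted). The emitted token is an exact sample from p.
    """
    # Accept the drafted token with probability min(1, p[x] / q[x]).
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token, True
    # On rejection, resample from the residual distribution
    # proportional to max(p - q, 0); this corrects the mismatch
    # so the marginal of the emitted token is exactly p.
    residual = np.maximum(p - q, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(p), p=residual)), False
```

Acceptance is most likely where the draft model already agrees with the large model, which is why a well-matched small model yields long accepted blocks and large speedups.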
Author Information
Ziteng Sun (Google Research)
Ananda Suresh (Google Research)
Jae Ro (Google)
Ahmad Beirami (Google Research)
Himanshu Jain (Google)
Felix Xinnan Yu (Google)
Michael Riley (Google)
Sanjiv Kumar (Google Research, NY)
More from the Same Authors
- 2021 : Learning with User-Level Privacy
  Daniel A Levy · Ziteng Sun · Kareem Amin · Satyen Kale · Alex Kulesza · Mehryar Mohri · Ananda Theertha Suresh
- 2021 : FERMI: Fair Empirical Risk Minimization Via Exponential Rényi Mutual Information
  Andrew Lowy · Rakesh Pavan · Sina Baharlouei · Meisam Razaviyayn · Ahmad Beirami
- 2023 : Robustness through Data Augmentation Loss Consistency
  Tianjian Huang · Shaunak Halbe · Chinnadhurai Sankar · Pooyan Amini · Satwik Kottur · Alborz Geramifard · Meisam Razaviyayn · Ahmad Beirami
- 2023 : Robust Deep Learning via Layerwise Tilted Exponentials
  Bhagyashree Puranik · Ahmad Beirami · Yao Qin · Upamanyu Madhow
- 2023 : Robustness through Loss Consistency Regularization
  Tianjian Huang · Shaunak Halbe · Chinnadhurai Sankar · Pooyan Amini · Satwik Kottur · Alborz Geramifard · Meisam Razaviyayn · Ahmad Beirami
- 2023 : Towards A Scalable Solution for Compositional Multi-Group Fair Classification
  James Atwood · Tina Tian · Ben Packer · Meghana Deodhar · Jilin Chen · Alex Beutel · Flavien Prost · Ahmad Beirami
- 2023 : Federated Heavy Hitter Recovery under Linear Sketching
  Adria Gascon · Peter Kairouz · Ziteng Sun · Ananda Suresh
- 2023 : Let’s Do a Thought Experiment: Using Counterfactuals to Improve Moral Reasoning
  Xiao Ma · Swaroop Mishra · Ahmad Beirami · Alex Beutel · Jilin Chen
- 2023 : Robustness through Loss Consistency Regularization
  Tianjian Huang · Shaunak A Halbe · Chinnadhurai Sankar · Pooyan Amini · Satwik Kottur · Alborz Geramifard · Meisam Razaviyayn · Ahmad Beirami
- 2023 Workshop: Neural Conversational AI Workshop - What’s left to TEACH (Trustworthy, Enhanced, Adaptable, Capable and Human-centric) chatbots?
  Hyundong Cho · Nayeon Lee · Ninareh Mehrabi · Hsuan Su · Jonathan May · Hung-yi Lee · Ahmad Beirami
- 2023 Poster: Subset-Based Instance Optimality in Private Estimation
  Travis Dick · Alex Kulesza · Ziteng Sun · Ananda Suresh
- 2023 Poster: Federated Heavy Hitter Recovery under Linear Sketching
  Adria Gascon · Peter Kairouz · Ziteng Sun · Ananda Suresh
- 2023 Poster: Algorithms for bounding contribution for histogram estimation under user-level privacy
  Yuhan Liu · Ananda Suresh · Wennan Zhu · Peter Kairouz · Marco Gruteser
- 2023 Poster: User-level Private Stochastic Convex Optimization with Optimal Rates
  Raef Bassily · Ziteng Sun
- 2023 Poster: Efficient Training of Language Models using Few-Shot Learning
  Sashank Jakkam Reddi · Sobhan Miryoosefi · Stefani Karp · Shankar Krishnan · Satyen Kale · Seungyeon Kim · Sanjiv Kumar
- 2022 Poster: In defense of dual-encoders for neural ranking
  Aditya Menon · Sadeep Jayasumana · Ankit Singh Rawat · Seungyeon Kim · Sashank Jakkam Reddi · Sanjiv Kumar
- 2022 Spotlight: In defense of dual-encoders for neural ranking
  Aditya Menon · Sadeep Jayasumana · Ankit Singh Rawat · Seungyeon Kim · Sashank Jakkam Reddi · Sanjiv Kumar
- 2022 Poster: The Fundamental Price of Secure Aggregation in Differentially Private Federated Learning
  Wei-Ning Chen · Christopher Choquette Choo · Peter Kairouz · Ananda Suresh
- 2022 Poster: Robust Training of Neural Networks Using Scale Invariant Architectures
  Zhiyuan Li · Srinadh Bhojanapalli · Manzil Zaheer · Sashank Jakkam Reddi · Sanjiv Kumar
- 2022 Spotlight: The Fundamental Price of Secure Aggregation in Differentially Private Federated Learning
  Wei-Ning Chen · Christopher Choquette Choo · Peter Kairouz · Ananda Suresh
- 2022 Oral: Robust Training of Neural Networks Using Scale Invariant Architectures
  Zhiyuan Li · Srinadh Bhojanapalli · Manzil Zaheer · Sashank Jakkam Reddi · Sanjiv Kumar
- 2022 Poster: Correlated Quantization for Distributed Mean Estimation and Optimization
  Ananda Suresh · Ziteng Sun · Jae Ro · Felix Xinnan Yu
- 2022 Spotlight: Correlated Quantization for Distributed Mean Estimation and Optimization
  Ananda Suresh · Ziteng Sun · Jae Ro · Felix Xinnan Yu
- 2021 Workshop: Information-Theoretic Methods for Rigorous, Responsible, and Reliable Machine Learning (ITR3)
  Ahmad Beirami · Flavio Calmon · Berivan Isik · Haewon Jeong · Matthew Nokleby · Cynthia Rush
- 2021 : Opening Remarks
  Ahmad Beirami
- 2021 Poster: Robust Testing and Estimation under Manipulation Attacks
  Jayadev Acharya · Ziteng Sun · Huanyu Zhang
- 2021 Spotlight: Robust Testing and Estimation under Manipulation Attacks
  Jayadev Acharya · Ziteng Sun · Huanyu Zhang
- 2021 Poster: A statistical perspective on distillation
  Aditya Menon · Ankit Singh Rawat · Sashank Jakkam Reddi · Seungyeon Kim · Sanjiv Kumar
- 2021 Poster: Disentangling Sampling and Labeling Bias for Learning in Large-output Spaces
  Ankit Singh Rawat · Aditya Menon · Wittawat Jitkrittum · Sadeep Jayasumana · Felix Xinnan Yu · Sashank Jakkam Reddi · Sanjiv Kumar
- 2021 Spotlight: A statistical perspective on distillation
  Aditya Menon · Ankit Singh Rawat · Sashank Jakkam Reddi · Seungyeon Kim · Sanjiv Kumar
- 2021 Spotlight: Disentangling Sampling and Labeling Bias for Learning in Large-output Spaces
  Ankit Singh Rawat · Aditya Menon · Wittawat Jitkrittum · Sadeep Jayasumana · Felix Xinnan Yu · Sashank Jakkam Reddi · Sanjiv Kumar
- 2021 Poster: Ditto: Fair and Robust Federated Learning Through Personalization
  Tian Li · Shengyuan Hu · Ahmad Beirami · Virginia Smith
- 2021 Spotlight: Ditto: Fair and Robust Federated Learning Through Personalization
  Tian Li · Shengyuan Hu · Ahmad Beirami · Virginia Smith
- 2020 Poster: Does label smoothing mitigate label noise?
  Michal Lukasik · Srinadh Bhojanapalli · Aditya Menon · Sanjiv Kumar
- 2020 Poster: Low-Rank Bottleneck in Multi-head Attention Models
  Srinadh Bhojanapalli · Chulhee Yun · Ankit Singh Rawat · Sashank Jakkam Reddi · Sanjiv Kumar
- 2020 Poster: Accelerating Large-Scale Inference with Anisotropic Vector Quantization
  Ruiqi Guo · Philip Sun · Erik Lindgren · Quan Geng · David Simcha · Felix Chern · Sanjiv Kumar
- 2020 Poster: Federated Learning with Only Positive Labels
  Felix Xinnan Yu · Ankit Singh Rawat · Aditya Menon · Sanjiv Kumar
- 2020 Poster: Context Aware Local Differential Privacy
  Jayadev Acharya · Kallista Bonawitz · Peter Kairouz · Daniel Ramage · Ziteng Sun
- 2019 : Structured matrices for efficient deep learning
  Sanjiv Kumar
- 2019 Poster: Escaping Saddle Points with Adaptive Gradient Methods
  Matthew Staib · Sashank Jakkam Reddi · Satyen Kale · Sanjiv Kumar · Suvrit Sra
- 2019 Poster: Agnostic Federated Learning
  Mehryar Mohri · Gary Sivek · Ananda Suresh
- 2019 Oral: Agnostic Federated Learning
  Mehryar Mohri · Gary Sivek · Ananda Suresh
- 2019 Oral: Escaping Saddle Points with Adaptive Gradient Methods
  Matthew Staib · Sashank Jakkam Reddi · Satyen Kale · Sanjiv Kumar · Suvrit Sra
- 2019 Poster: Learning a Compressed Sensing Measurement Matrix via Gradient Unrolling
  Shanshan Wu · Alexandros Dimakis · Sujay Sanghavi · Felix Xinnan Yu · Daniel Holtmann-Rice · Dmitry Storcheus · Afshin Rostamizadeh · Sanjiv Kumar
- 2019 Oral: Learning a Compressed Sensing Measurement Matrix via Gradient Unrolling
  Shanshan Wu · Alexandros Dimakis · Sujay Sanghavi · Felix Xinnan Yu · Daniel Holtmann-Rice · Dmitry Storcheus · Afshin Rostamizadeh · Sanjiv Kumar
- 2019 Poster: Communication Complexity in Locally Private Distribution Estimation and Heavy Hitters
  Jayadev Acharya · Ziteng Sun
- 2019 Oral: Communication Complexity in Locally Private Distribution Estimation and Heavy Hitters
  Jayadev Acharya · Ziteng Sun
- 2018 Poster: Loss Decomposition for Fast Learning in Large Output Spaces
  En-Hsu Yen · Satyen Kale · Felix Xinnan Yu · Daniel Holtmann-Rice · Sanjiv Kumar · Pradeep Ravikumar
- 2018 Oral: Loss Decomposition for Fast Learning in Large Output Spaces
  En-Hsu Yen · Satyen Kale · Felix Xinnan Yu · Daniel Holtmann-Rice · Sanjiv Kumar · Pradeep Ravikumar
- 2018 Poster: INSPECTRE: Privately Estimating the Unseen
  Jayadev Acharya · Gautam Kamath · Ziteng Sun · Huanyu Zhang
- 2018 Oral: INSPECTRE: Privately Estimating the Unseen
  Jayadev Acharya · Gautam Kamath · Ziteng Sun · Huanyu Zhang
- 2017 Poster: Stochastic Generative Hashing
  Bo Dai · Ruiqi Guo · Sanjiv Kumar · Niao He · Le Song
- 2017 Talk: Stochastic Generative Hashing
  Bo Dai · Ruiqi Guo · Sanjiv Kumar · Niao He · Le Song
- 2017 Poster: Distributed Mean Estimation with Limited Communication
  Ananda Theertha Suresh · Felix Xinnan Yu · Sanjiv Kumar · Brendan McMahan
- 2017 Poster: A Unified Maximum Likelihood Approach for Estimating Symmetric Properties of Discrete Distributions
  Jayadev Acharya · Hirakendu Das · Alon Orlitsky · Ananda Suresh
- 2017 Talk: A Unified Maximum Likelihood Approach for Estimating Symmetric Properties of Discrete Distributions
  Jayadev Acharya · Hirakendu Das · Alon Orlitsky · Ananda Suresh
- 2017 Talk: Distributed Mean Estimation with Limited Communication
  Ananda Theertha Suresh · Felix Xinnan Yu · Sanjiv Kumar · Brendan McMahan