

Poster

Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?

Khashayar Gatmiry · Nikunj Saunshi · Sashank J. Reddi · Stefanie Jegelka · Sanjiv Kumar


Abstract:

The ability of Transformers to do reasoning and few-shot learning, without any fine-tuning, is widely conjectured to stem from their ability to implicitly simulate multi-step algorithms -- such as gradient descent -- with their weights in a single forward pass. Recently, there has been progress in understanding this complex phenomenon from an expressivity point of view, by demonstrating that Transformers can express such multi-step algorithms. However, our knowledge about the more fundamental aspect of learnability, beyond single-layer models, is very limited. In particular, can training Transformers enable convergence to algorithmic solutions? In this work we resolve this question for in-context linear regression with linear looped Transformers -- a multi-layer model with weight sharing that is conjectured to have an inductive bias to learn fixed-point iterative algorithms. More specifically, for this setting we show that the global minimizer of the population training loss implements multi-step preconditioned gradient descent, with a preconditioner that adapts to the data distribution. Furthermore, we show fast convergence for gradient flow on the regression loss, despite the non-convexity of the landscape, by proving a novel gradient dominance condition. To our knowledge, this is the first theoretical analysis for multi-layer Transformers in this setting. We further validate our theoretical findings through synthetic experiments.
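
To make the phrase "multi-step preconditioned gradient descent" concrete, here is a minimal NumPy sketch of that algorithm on a single in-context linear regression task. It is an illustration only, not the paper's construction or the looped Transformer itself. The function name `preconditioned_gd`, the choice of the inverse input covariance as preconditioner, and all sizes and step counts are assumptions made for this example; the abstract does not state the exact form of the learned preconditioner.

```python
import numpy as np

def preconditioned_gd(X, y, P, num_steps):
    """Run num_steps of preconditioned gradient descent on the
    least-squares loss (1 / 2n) * ||X w - y||^2, starting from w = 0.

    P is a fixed d x d preconditioner (hypothetical choice, see below).
    """
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_steps):
        grad = X.T @ (X @ w - y) / n   # gradient of the least-squares loss
        w = w - P @ grad               # one preconditioned step
    return w

rng = np.random.default_rng(0)
d, n = 5, 200

# Anisotropic inputs: coordinate i has standard deviation scales[i].
scales = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
Sigma = np.diag(scales ** 2)           # population covariance of the inputs
X = rng.normal(size=(n, d)) * scales   # in-context examples
w_star = rng.normal(size=d)            # ground-truth regressor for this task
y = X @ w_star                         # noiseless labels

# Assumption for illustration: precondition with the inverse covariance,
# one simple way a preconditioner can "adapt to the data distribution".
P = np.linalg.inv(Sigma)

w_hat = preconditioned_gd(X, y, P, num_steps=20)
error = np.linalg.norm(w_hat - w_star)

# Prediction on a fresh query point, as in in-context learning.
query = rng.normal(size=d) * scales
prediction = query @ w_hat
```

In this sketch each loop iteration plays the role of one application of the shared layer in a looped model. With the identity in place of `P`, the same unit step size would not be expected to converge on this data, because the largest eigenvalue of `X.T @ X / n` is far above 2; the preconditioner is what rescales the problem so that a fixed number of steps suffices.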
