Per-example Gradients: A New Frontier for Understanding and Improving Optimizers
Abstract
When computing gradients, deep learning training algorithms typically treat the mini-batch as a fundamental unit, returning only batch-averaged gradients. Computing non-linear statistics of the mini-batch gradient distribution has traditionally been viewed as prohibitively expensive or as requiring complex, custom implementations. We challenge this view by demonstrating that sequence-level architectures offer a natural testbed for prototyping algorithms based on per-example gradients. We show that staged programming frameworks such as JAX enable generic manipulations of mini-batch gradient computations. We then build on Dangel et al. (2019) to derive implementations of specific per-example and per-token operations with negligible computational or memory overhead. Finally, we leverage our findings to re-examine two nonlinear optimization operations. First, we analyze signSGD, showing that the placement of the sign operation is critical to its success and that the optimal placement can be predicted via a simple signal-to-noise-ratio argument. Second, we investigate per-example variants of the Adam preconditioner and find that, contrary to conventional wisdom, optimization is best served when the preconditioner is dominated by the squared mean of the gradient distribution rather than by its variance. Overall, our work shows that accessible per-example gradient information unlocks new avenues for algorithm analysis and design.
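As an illustrative sketch (not the paper's implementation), per-example gradients of the kind discussed above can be exposed in JAX by composing `jax.grad` with `jax.vmap`; the loss function, parameter names, and shapes below are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example loss: squared error of a linear model.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Gradient w.r.t. params for one example, vectorized over the batch
# dimension of (x, y) with vmap instead of averaging over the batch.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
xs = jnp.ones((8, 3))   # batch of 8 examples
ys = jnp.zeros((8,))

grads = per_example_grads(params, xs, ys)
# grads["w"].shape == (8, 3): one gradient per example, from which
# non-linear statistics (sign, variance, ...) can be computed before
# any averaging takes place.
```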