Mitigating Per-Sample Harm in Stochastic Optimization
Abstract
Modern optimizers combine gradients from the current mini-batch with historical optimization state, such as momentum or adaptive moments. While effective for stability, this interaction can produce update directions that increase the loss of individual samples in the current batch. We formalize this effect as "harm" and cast the computation of an update as an optimization problem that explicitly minimizes the harmful impact of past optimization state on current data. To make this problem tractable, we first reduce its dimensionality from the number of parameters to the batch size, and further show that restricting the optimization to the last layer provides an effective and efficient proxy. The resulting subproblem can be solved with a small number of GPU-friendly iterations and integrated seamlessly into SGD with momentum and AdamW. Experiments on image classification benchmarks show reduced per-sample interference and improved generalization with moderate computational overhead.
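The notion of per-sample harm described above can be illustrated with a first-order sketch: a step along a momentum-augmented direction harms a sample when that direction is misaligned with the sample's own gradient, so the step increases its loss to first order. The sketch below is purely illustrative (the function name `per_sample_harm` and the sign convention are our assumptions, not the paper's API):

```python
import numpy as np

def per_sample_harm(per_sample_grads, update_direction):
    """First-order alignment between each sample's gradient and the update.

    With the update theta <- theta - lr * d, the loss change for sample i is
    approximately -lr * (g_i . d). A sample is "harmed" (its loss increases)
    when g_i . d < 0. Illustrative definition, not the paper's exact objective.
    """
    # Rows of per_sample_grads are the g_i; returns one alignment score per sample.
    return per_sample_grads @ update_direction

rng = np.random.default_rng(0)
G = rng.normal(size=(8, 5))            # per-sample gradients, batch of 8
momentum = rng.normal(size=5)          # stale optimizer state from past batches
d = G.mean(axis=0) + 0.9 * momentum    # momentum-augmented update direction

alignment = per_sample_harm(G, d)
harmed = alignment < 0                 # samples whose loss rises to first order
```

Under this view, the paper's subproblem amounts to choosing the update direction so that such negative alignments, induced by stale state like `momentum`, are minimized over the current batch.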