Poster in Workshop: HiLD: High-dimensional Learning Dynamics Workshop
Adapting to Gradual Distribution Shifts with Continual Weight Averaging
Jared Fernandez · Saujas Vaduguru · Sanket Vaibhav Mehta · Yonatan Bisk · Emma Strubell
Machine learning models are frequently deployed in settings where data distributions shift gradually over time. However, existing methods for adapting to distribution shift fail to leverage this implicit structure. We propose a simple method to continually adapt a model to distribution shift by maintaining an exponential moving average of model weights over discrete timesteps. We refer to our method as Continually and Stochastically Averaging Weights (CSAW). We show that CSAW achieves state-of-the-art performance on the Wild-Time benchmark of in-the-wild gradual temporal distribution shifts on a variety of datasets across vision, language, and medical domains, with improvements in both average and worst-case OOD performance: +2.23% accuracy on Yearbook, +2.96% on FMoW, +0.87% on HuffPost, +1.43% on ArXiv, and +0.75% ROC-AUC on MIMIC-Mortality. We analyze the loss landscapes of sequentially fine-tuned models and show that they exhibit favorable mode-connectivity properties that allow for weight averaging.
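A minimal sketch of the weight-averaging idea described in the abstract, assuming a PyTorch-style model: after fine-tuning on each timestep's data, an exponential moving average of the weights is updated. The helper name `update_ema_weights`, the decay value, the toy model, and the synthetic shifting data are illustrative assumptions, not the authors' implementation.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def update_ema_weights(ema_model: nn.Module, current_model: nn.Module, decay: float = 0.99) -> None:
    """Update an exponential moving average of weights in place:
    ema <- decay * ema + (1 - decay) * current. Decay value is illustrative."""
    for ema_param, param in zip(ema_model.parameters(), current_model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1.0 - decay)

# Toy continual-adaptation loop over discrete timesteps; the per-timestep
# fine-tuning and the gradually shifting synthetic data are placeholders.
model = nn.Linear(8, 2)
ema_model = copy.deepcopy(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for t in range(5):                              # discrete timesteps
    x = torch.randn(32, 8) + 0.1 * t            # stand-in for a gradually shifting distribution
    y = torch.randint(0, 2, (32,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                            # fine-tune on the current timestep
    update_ema_weights(ema_model, model)        # average weights across timesteps

# `ema_model` would then be used for prediction on future (shifted) data.
```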