

Poster

Lowering the Pre-training Tax for Gradient-based Subset Training: A Lightweight Distributed Pre-Training Toolkit

Yeonju Ro · Zhangyang “Atlas” Wang · Vijay Chidambaram · Aditya Akella

Exhibit Hall 1 #439

Abstract: Training data and model sizes are increasing exponentially. One way to reduce training time and resources is to train on a carefully selected subset of the full dataset. Prior work uses gradient signals obtained during a warm-up or "pre-training" phase over the full dataset to determine the core subset; if the pre-training phase is too short, the gradients obtained are chaotic and unreliable. As a result, the pre-training phase itself incurs significant time and resource overhead, and prior work has not gone beyond hyperparameter search to reduce pre-training time. Our work explicitly aims to reduce this $\textbf{pre-training tax}$ in gradient-based subset training. We develop a principled, scalable approach for pre-training in a distributed setup. Our approach is $\textit{lightweight}$ and $\textit{minimizes communication}$ between distributed worker nodes. It is the first to apply the concept of model-soup-based distributed training $\textit{at initialization}$. The key idea is to minimally train an ensemble of models on small, disjoint subsets of the data; we further employ data-driven sparsity and data augmentation during local worker training to boost ensemble diversity. The centralized model, obtained at the end of pre-training by merging the per-worker models, provides stabilized gradient signals for selecting subsets, on which the main model is then trained. We validate the effectiveness of our method through extensive experiments on CIFAR-10/100 and ImageNet, using ResNet and WideResNet models. For example, our approach achieves a $\textbf{15.4$\times$}$ pre-training speedup and a $\textbf{2.8$\times$}$ end-to-end speedup on CIFAR-10 with ResNet-18 without loss of accuracy. The code is at https://github.com/moonbucks/LiPT.git.
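To make the workflow in the abstract concrete, below is a minimal, hypothetical sketch of the three stages it describes: briefly training per-worker model copies on small disjoint shards, merging them into a "soup", and using the merged model's gradient signals to pick a core subset. It uses a toy MLP and random data in place of ResNet/CIFAR, omits the paper's data-driven sparsity and augmentation tricks, and all function names (e.g., train_local, merge_soup, gradient_scores) are illustrative, not from the LiPT codebase.

```python
# Hypothetical sketch of model-soup pre-training for gradient-based subset
# selection. Names and hyperparameters are illustrative, not from the paper.
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset

def train_local(model, loader, epochs=1, lr=0.1):
    """Briefly train one worker's model copy on its disjoint data shard."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model

def merge_soup(models):
    """Average the worker models' weights into one centralized model."""
    soup = copy.deepcopy(models[0])
    with torch.no_grad():
        for name, param in soup.state_dict().items():
            stacked = torch.stack([m.state_dict()[name].float() for m in models])
            param.copy_(stacked.mean(dim=0))
    return soup

def gradient_scores(model, dataset):
    """Score each example by the norm of its loss gradient w.r.t. the model."""
    loss_fn = nn.CrossEntropyLoss()
    scores = []
    for i in range(len(dataset)):
        x, y = dataset[i]
        model.zero_grad()
        loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0)).backward()
        g = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
        scores.append(g.norm().item())
    return torch.tensor(scores)

# Toy data and a shared initialization stand in for CIFAR and ResNet.
torch.manual_seed(0)
data = TensorDataset(torch.randn(256, 16), torch.randint(0, 10, (256,)))
base = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))
num_workers, subset_frac = 4, 0.3

# 1) Each worker briefly trains a copy of the same init on a disjoint shard.
shards = torch.arange(len(data)).chunk(num_workers)
workers = [
    train_local(copy.deepcopy(base), DataLoader(Subset(data, shard.tolist()), batch_size=32))
    for shard in shards
]

# 2) Merge the per-worker models into a soup with more stable gradients.
soup = merge_soup(workers)

# 3) Use the soup's gradient signals to select the core subset for main training.
scores = gradient_scores(soup, data)
keep = scores.topk(int(subset_frac * len(data))).indices
core_subset = Subset(data, keep.tolist())
print(f"Selected {len(core_subset)} of {len(data)} examples for main training.")
```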
