Hardware-Aware Dynamic Sparse Training for Large Output Spaces
Nasib Ullah ⋅ Jinbin Zhang ⋅ Jean Lucien Randrianantenaina ⋅ Erik Schultheis ⋅ Rohit Babbar
Abstract
Extreme multi-label classification (XMC) involves training deep learning models over large output spaces with millions of labels, making the output layer of the network a major bottleneck in memory and compute. While sparsity-based methods reduce arithmetic complexity, they often fail to yield proportional wall-clock gains due to irregular memory access, poor hardware utilization, or reliance on auxiliary architectural components in extreme long-tailed regimes. We introduce group-shared fixed fan-in sparsity, a semi-structured output-layer design in which groups of semantically related labels share a common sparse input pattern while retaining independent weights. This grouping introduces a task-aligned inductive bias---encouraging related labels to attend to similar feature subsets---while simultaneously reducing index memory overhead, increasing feature reuse across labels, and enabling efficient GPU execution via custom CUDA kernels that leverage modern accelerator primitives. As an alternative to auxiliary objectives, we exploit the long-tailed structure of XMC datasets by decomposing the output layer into a small dense head over frequent labels and a group-shared sparse tail over the remainder, providing an informative gradient pathway while preserving the memory benefits of sparsity. Through kernel-level microbenchmarking, we show that group-shared fixed fan-in converts reductions in arithmetic complexity into proportional wall-clock gains, achieving up to $4.4\times$ speedup in the forward pass and up to $25\times$ speedup in the backward pass compared to standard fixed fan-in sparsity, while operating within a few percent of a FLOPs-matched dense bottleneck. Across large-scale XMC benchmarks, our approach matches or improves precision@k compared to prior sparse baselines, while substantially narrowing the performance gap to dense models.
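The group-shared layout described in the abstract can be sketched in a few lines of NumPy. All dimensions, names, and the random grouping below are illustrative assumptions for the sketch, not details taken from the paper: each group of labels shares one fixed fan-in index set (the sparse input pattern), while every label keeps its own weight vector.

```python
import numpy as np

# Hypothetical, illustrative dimensions (not from the paper).
d, num_groups, labels_per_group, fan_in = 64, 8, 4, 16
rng = np.random.default_rng(0)

# One shared index set per group: the fixed fan-in sparse input pattern.
group_idx = np.stack([rng.choice(d, size=fan_in, replace=False)
                      for _ in range(num_groups)])               # (G, F)

# Independent weights for every label within each group.
W = rng.standard_normal((num_groups, labels_per_group, fan_in))  # (G, L, F)

def group_shared_forward(x):
    """Scores for all G*L labels: gather the shared features once per
    group, then reuse them for every label in that group (the feature
    reuse that reduces index memory overhead)."""
    gathered = x[group_idx]                  # (G, F): one gather per group
    return np.einsum('glf,gf->gl', W, gathered).ravel()  # (G*L,)

x = rng.standard_normal(d)
scores = group_shared_forward(x)
```

The sketch only shows why sharing the index set helps: each group's input features are gathered once and reused by all of its labels, rather than once per label as in standard fixed fan-in sparsity. The actual efficiency gains come from custom CUDA kernels exploiting this structure on GPU.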