Minibatch Selection for Language Models via Partition Matroid Constrained Gradient Matching
Abstract
Training Large Language Models (LLMs) on heterogeneous datasets requires optimizing domain representations to balance convergence speed and domain coverage. While recent methods reduce computational overhead by selecting high-quality data subsets, they typically perform selection independently per domain or rely on computationally expensive proxy models to determine continuous domain weights. In this paper, we propose a joint sample selection framework that directly learns better domain representation within each batch, formulating the selection objective as joint domain utility maximization subject to partition matroid constraints. This formulation enforces domain-specific budgets while maximizing a validation-guided gradient matching utility across all domains simultaneously. Theoretically, we establish that this objective is weakly submodular, which allows us to employ a computationally efficient orthogonal matching pursuit algorithm with provable guarantees. Empirically, we demonstrate that our method significantly outperforms baselines that select data independently per domain on mathematical reasoning and molecular generation benchmarks. Furthermore, our analysis shows that, compared to independent per-domain selection or domain-agnostic selection, our approach substantially reduces the number of conflicting training gradient pairs. Applied to Qwen2.5 and Llama-3 models trained on MetaMathQA and Mol-Instructions, our approach yields robust gains across multiple subset fractions, with improvements on nine mathematical reasoning and four molecule generation benchmarks, highlighting the cross-domain benefits of joint subset selection.
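To make the high-level procedure concrete, the following is a minimal sketch of orthogonal matching pursuit under a partition matroid constraint, not the paper's exact implementation: the function name, the use of raw per-sample gradient features, and the least-squares refit of coefficients are illustrative assumptions.

```python
import numpy as np

def matroid_constrained_omp(sample_grads, domains, val_grad, budgets):
    """OMP-style greedy selection under a partition matroid (illustrative sketch).

    sample_grads: (n, d) per-sample gradient features (hypothetical inputs).
    domains:      length-n array of domain ids, one per sample.
    val_grad:     (d,) validation gradient to match.
    budgets:      dict mapping domain id -> per-domain selection budget.
    Returns the list of selected sample indices.
    """
    n, _ = sample_grads.shape
    selected, counts = [], {k: 0 for k in budgets}
    residual = val_grad.copy()
    total_budget = sum(budgets.values())

    while len(selected) < total_budget:
        # Feasible candidates: unselected samples whose domain budget is not exhausted.
        feasible = [i for i in range(n)
                    if i not in selected and counts[domains[i]] < budgets[domains[i]]]
        if not feasible:
            break
        # OMP step: pick the candidate most correlated with the current residual.
        scores = sample_grads[feasible] @ residual
        if scores.max() <= 0:  # no remaining candidate improves the match
            break
        best = feasible[int(np.argmax(scores))]
        selected.append(best)
        counts[domains[best]] += 1
        # Refit coefficients on the selected set and update the residual.
        A = sample_grads[selected].T                      # (d, |S|)
        w, *_ = np.linalg.lstsq(A, val_grad, rcond=None)  # least-squares fit
        residual = val_grad - A @ w
    return selected
```

In this sketch, the partition matroid is enforced simply by restricting each greedy step to domains with remaining budget, so all domains compete for slots within the same utility rather than being selected independently.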