On the Role of Intermediate Representations in Knowledge Distillation for Robust Generalization
Abstract
Domain generalization (DG) aims to learn a model that generalizes to unseen, i.e., out-of-distribution (OOD), test domains. While large-capacity networks trained with sophisticated DG algorithms tend to achieve high robustness, they are often impractical to deploy. Knowledge distillation (KD) can alleviate this by efficiently transferring knowledge from a robust teacher to a smaller student network. Throughout our experiments, we find that vanilla KD already provides strong OOD performance, often outperforming standalone DG algorithms. Motivated by this observation, we propose an adaptive distillation strategy that uses early-layer predictions and uncertainty measures to learn a meta network that rebalances the supervised and distillation losses according to per-sample difficulty. Our method adds no inference overhead and consistently outperforms canonical empirical risk minimization (ERM), vanilla KD, and competing DG algorithms across vision and text OOD generalization benchmarks.
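
To make the idea of per-sample rebalancing concrete, the following is a minimal sketch and not the method's actual implementation. It assumes the two losses are combined convexly, that the distillation term is the standard temperature-scaled KL divergence between teacher and student, and that a per-sample weight in [0, 1] is available. The names `rebalanced_loss`, `kd_divergence`, `weight`, and `temperature` are hypothetical; `weight` merely stands in for the output of the meta network, which is not modeled here.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(student_logits, label):
    """Supervised loss: negative log-probability of the true label."""
    return -math.log(softmax(student_logits)[label])


def kd_divergence(student_logits, teacher_logits, temperature=2.0):
    """Distillation loss: T^2 * KL(teacher || student) on softened outputs.

    The T^2 factor is the usual scaling for temperature-softened targets;
    its use here is an assumption, not something stated in the abstract.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0.0
    )
    return temperature ** 2 * kl


def rebalanced_loss(student_logits, teacher_logits, label, weight,
                    temperature=2.0):
    """Convex combination of supervised and distillation losses.

    `weight` is a hypothetical per-sample value in [0, 1]; in an adaptive
    scheme it would be predicted per sample rather than fixed by hand.
    """
    if not 0.0 <= weight <= 1.0:
        raise ValueError("weight must lie in [0, 1]")
    supervised = cross_entropy(student_logits, label)
    distilled = kd_divergence(student_logits, teacher_logits, temperature)
    return (1.0 - weight) * supervised + weight * distilled
```

Under these assumptions, `weight = 0` recovers plain supervised training and `weight = 1` recovers pure distillation; an adaptive scheme would choose intermediate values sample by sample. Because the weighting enters only the training loss, the student's forward pass at test time is unchanged.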