Causal Fine-Tuning under Latent Confounded Shift
Abstract
Adapting to latent confounded shift remains a core challenge in modern AI. In this setting, hidden variables induce spurious correlations between inputs and outputs during training, leading models to rely on non-causal shortcuts. For example, a model may learn to treat metadata (e.g., a data source such as "Amazon") as a proxy for positive sentiment, and then fail when that source becomes predominantly negative at deployment. To address latent confounded shift, we introduce Causal Fine-Tuning (CFT). Using a structural causal model as an inductive bias, we derive sufficient conditions under which the causal effect of inputs on outputs is identifiable despite latent confounding, and translate these insights into a fine-tuning objective that decomposes representations into high-level causal and low-level spurious components. Instantiating this framework in BERT, we show that learning such causal and spurious representations and adjusting for the spurious component yields a more robust predictor. Experiments on spurious-correlation injection attacks in text show that our method outperforms black-box domain generalization baselines, highlighting the benefits of explicitly modeling causal structure.
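To make the representation decomposition concrete, the sketch below fine-tunes a BERT encoder with two projection heads: one intended to carry the causal signal used for label prediction, and one to absorb confounder-induced (spurious) signal. This is a minimal illustrative sketch, not the paper's implementation: the head dimensions, the use of an observed confounder proxy (e.g., data source) as an auxiliary target, and the class names are all assumptions not specified in the abstract.

```python
# Minimal sketch (illustrative only, not the released CFT code): a BERT
# encoder whose pooled representation is split into a causal head and a
# spurious head. Dimensions and the confounder-proxy head are assumptions.
import torch.nn as nn
from transformers import BertModel


class CausalFineTuner(nn.Module):
    def __init__(self, num_labels: int, causal_dim: int = 512,
                 spurious_dim: int = 256, num_confounder_proxies: int = 2):
        super().__init__()
        self.encoder = BertModel.from_pretrained("bert-base-uncased")
        h = self.encoder.config.hidden_size  # 768 for bert-base

        # Decompose the pooled representation into two subspaces: one meant
        # to capture the causal signal, one to absorb spurious signal.
        self.causal_proj = nn.Linear(h, causal_dim)
        self.spurious_proj = nn.Linear(h, spurious_dim)

        # The label is predicted from the causal part only; an auxiliary head
        # predicts an observed confounder proxy (e.g., data source) from the
        # spurious part, encouraging the split to separate the two signals.
        self.label_head = nn.Linear(causal_dim, num_labels)
        self.confounder_head = nn.Linear(spurious_dim, num_confounder_proxies)

    def forward(self, input_ids, attention_mask):
        pooled = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).pooler_output
        z_causal = self.causal_proj(pooled)
        z_spurious = self.spurious_proj(pooled)
        return self.label_head(z_causal), self.confounder_head(z_spurious)
```

Under this reading, training would combine a cross-entropy loss on the label head with a weighted auxiliary loss on the confounder head, while only the causal head is used for prediction at deployment; the paper's actual objective is derived from its structural causal model.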