Spiked-CFR: Causal Representation Learning from LLMs via Wasserstein Projection Pursuit
Fan Wang ⋅ Hengyu Yue ⋅ Yu Bowen ⋅ Weiming Liu ⋅ Zongxin Yang ⋅ Xuyun Zhang ⋅ Xiaolin Zheng ⋅ Chaochao Chen ⋅ Shuiguang Deng
Abstract
Estimating treatment effects from observational text is increasingly practical with Large Language Models (LLMs). However, applying causal representation learning directly to high-dimensional LLM embeddings faces a fundamental barrier: empirical Wasserstein matching suffers from the curse of dimensionality, rendering standard generalization guarantees effectively vacuous. We propose SPIKED-CFR, a framework bridging this gap by assuming a Spiked Confounding Structure, where treatment selection bias concentrates in a low-dimensional subspace of the semantic representation. We develop Wasserstein Projection Pursuit, a minimax objective that adversarially learns an orthogonal projection on the Stiefel manifold to identify and balance only this confounding subspace while preserving prognostic information. Under a spiked confounding structure, we show the projected discrepancy can be estimated at a rate governed by the intrinsic dimension $k \ll D$, and we derive a tighter PEHE generalization bound that depends on $k$ rather than the ambient embedding dimension. Experiments on four semi-synthetic benchmarks and four real-world clinical benchmarks demonstrate improved accuracy and robustness over strong baselines. Code is available at \url{https://anonymous.4open.science/r/SpikedCFR-7E13}.
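The core quantity in the abstract, a Wasserstein discrepancy computed after projecting embeddings onto a low-dimensional orthonormal frame (a point on the Stiefel manifold), can be sketched as follows. This is an illustrative toy, not the paper's implementation: the projection `Q` here is random rather than adversarially learned, the data are synthetic Gaussians with a confounding shift planted in the first `k` coordinates, and the coordinate-wise sum of 1D Wasserstein distances is one simple surrogate for a k-dimensional discrepancy.

```python
import numpy as np

rng = np.random.default_rng(0)

D, k = 128, 2          # ambient embedding dim and assumed spike dim (illustrative)
n_t, n_c = 200, 200    # treated / control sample sizes (equal, for the sort trick below)

# Synthetic "LLM embeddings": treatment-selection shift lives in the first k coordinates.
X_t = rng.normal(size=(n_t, D)); X_t[:, :k] += 1.0
X_c = rng.normal(size=(n_c, D))

# A point on the Stiefel manifold St(D, k): orthonormal columns via QR.
Q, _ = np.linalg.qr(rng.normal(size=(D, k)))   # Q.T @ Q == I_k

def wasserstein_1d(u, v):
    """Empirical 1-Wasserstein distance between equal-size 1D samples
    (mean absolute difference of sorted samples)."""
    return np.abs(np.sort(u) - np.sort(v)).mean()

def projected_discrepancy(Q, X_t, X_c):
    """Sum of coordinate-wise W1 distances after projecting onto Q's columns."""
    Z_t, Z_c = X_t @ Q, X_c @ Q
    return sum(wasserstein_1d(Z_t[:, j], Z_c[:, j]) for j in range(Q.shape[1]))

print(projected_discrepancy(Q, X_t, X_c))
```

In the method as described, the inner player of the minimax objective would ascend this discrepancy over `Q` (with retraction back onto the Stiefel manifold) to locate the confounding subspace, while the representation is trained to shrink it; estimating the discrepancy in k dimensions rather than D is what yields the $k \ll D$ rate.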