Paper ID: 1348
Title: Learning Representations for Counterfactual Inference

===== Review #1 =====

Summary of the paper (Summarize the main claims/contributions of the paper.): This paper is concerned with learning representations for counterfactual reasoning, a classically important problem in causal inference. The authors drew a connection between this problem and domain adaptation, and made use of domain adaptation techniques to solve their problem.

Clarity - Justification: The framework is technically interesting. However, it is not conceptually clear to me. The authors did not explain why counterfactual inference can be understood precisely as a domain adaptation problem, nor which domain adaptation setting they actually refer to. In addition, some statements are very confusing. For example, at the beginning of Section 3, the authors stated: “We propose to perform counterfactual inference by amending the direct modeling approach, taking into account…. making this a type of domain adaptation problem (Schölkopf et al., 2012).” Actually, as far as I know, Schölkopf et al. (2012) did not talk about domain adaptation at all.

Significance - Justification: The considered problem is important, and the authors seem to see a nice connection between their problem and domain adaptation, although this connection is not rigorous or clear to me.

Detailed comments. (Explain the basis for your ratings while providing constructive feedback.): This paper is concerned with learning representations for counterfactual reasoning, a classically important problem in causal inference. The framework is technically interesting. However, it is not conceptually clear to me. The authors did not explain why counterfactual inference can be understood as a domain adaptation problem, nor which domain adaptation problem they precisely refer to. Can the authors explain the setting of domain adaptation and its specific goal here? Is it covariate shift? If so, can the authors justify why the covariate shift assumptions hold here? Did the authors just aim at maximizing the performance in the "test" domain?

Regarding the third objective, which is to make the distributions of the treatment and control groups similar, I failed to see the precise reason why this should be preferred. In some domain adaptation methods such as transfer component analysis (TCA), this was preferred just because it was ASSUMED that for representations with identical marginal distributions, the conditional of Y given such representations is the same. This assumption was not justified at all. Or was it inspired by the learning bounds in Mansour et al. (2009)? Please make this conceptually clear.

Furthermore, as seen from the experimental results, in the linear case the proposed method performs roughly on par with Lasso + ridge regression, suggesting that the objective function the authors propose might be overly complicated or redundant. An explanation of this would be appreciated.

Several statements are very confusing. For example, at the beginning of Section 3, the authors stated: “We propose to perform counterfactual inference by amending the direct modeling approach, taking into account…. making this a type of domain adaptation problem (Schölkopf et al., 2012).” Actually, as far as I know, Schölkopf et al. (2012) did not talk about domain adaptation at all.

The structure of this paper is very loose.
For instance, in the Introduction, the authors first summarized their method and then explained its connection to previous work on representation learning and the problem of counterfactual reasoning in machine learning. I am wondering whether it would be more natural if the authors presented the problem and previous work first and then showed how they motivated the proposed method.

Minor point:
- Page 5: It measures how we COULD WE in…

===== Review #2 =====

Summary of the paper (Summarize the main claims/contributions of the paper.): The paper frames counterfactual prediction as a transfer learning problem. This is sensible because in observational studies 1) the factual distribution P and the counterfactual distribution Q are related, and 2) one obtains labeled data from P and unlabeled data from Q.

Clarity - Justification: The paper is well written, easy to follow, and mathematically sound.

Significance - Justification: The paper tackles an important problem in a principled way. However, the theory and practice in the manuscript remain disconnected, and the experimental results provide no insight into situations where the algorithm may fail.

Detailed comments. (Explain the basis for your ratings while providing constructive feedback.): My major concern with the submission is that Theorem 1 does not hold when $\Phi$, $h$, or $h \circ \Phi$ are deep neural networks. This is because the objective function is then non-convex, and the argument of (Cortes & Mohri, 2014) falls apart in this case. The theory would only hold for fixed (non-learnable) representations $\Phi$ and linear predictors $h$. Since the representation $\Phi$ is always learned in the numerical simulations (and the classifier $h$ is learned as well in the best-performing architecture), the theory and practice of this paper remain disconnected. Furthermore, the bound does not take into account the hyper-parameters $(\alpha, \gamma)$.

Another concern is that the numerical simulations do not analyze situations (perhaps synthetic) where the algorithm breaks: in particular, it may be the case that the dataset does not contain near counterfactual neighbours for the factual datapoints. This would in turn increase the minimal distances between pairs $(i, j(i))$ and, according to the bounds, yield poor counterfactual estimation. Is such degradation graceful? Can the algorithm detect it and report it to the user? These questions are of paramount importance in counterfactual prediction (e.g. treatment assignment to patients).

===== Review #3 =====

Summary of the paper (Summarize the main claims/contributions of the paper.): This paper presents a transfer learning approach to causal inference from purely observational data. The core of causal inference is reasoning about counterfactuals and the interactions between context (X), treatment (T), and outcomes (Y). The strongest causal reasoning frameworks employ randomization (to ensure that T is chosen independently of X) or an explicit and known model of P(T|X), as in reinforcement learning and bandit problems. Otherwise, they fall back on approaches that require leveraging domain knowledge or making assumptions (propensity scoring, parametric models, etc.). This paper proposes an alternative approach that implicitly (and automatically) performs propensity-style matching within a latent space that is learned jointly with the prediction function.
It leverages the transfer learning theoretical framework and results of Cortes and Mohri on "discrepancy distance" (a hypothesis-class-dependent distance between marginal data distributions, built upon the older d_A distance of Kifer, Ben-David, and Gehrke, VLDB 2004). In short, they optimize an objective function over [a] hypotheses (i.e., a class of outcome prediction functions, e.g., linear regressions or neural nets) and [b] representations of the input (i.e., feature selection and re-weighting, OR a non-linear mapping via a neural net), with three terms:
(1) prediction error for the actual observed outcomes;
(2) the discrepancy distance between the empirical distributions over "factual" (observed) data and "counterfactual" data (i.e., factual examples with the opposite treatment, implicitly matched within the learned representation space);
(3) "counterfactual" prediction error, i.e., the error between the prediction for an observed X_i under the treatment opposite to the one received (1 - T_i) and the outcome of the nearest X_j within the learned representation space.

In addition to proposing the idea and providing two practical instantiations of the algorithm (one theoretically grounded, based on linear regression; the other based on neural networks), the authors provide a theorem (Theorem 1, lines 425-490) that places an upper bound on the squared difference between the losses of two estimated outcome predictors (fit to the empirical factual and counterfactual samples, respectively), evaluated over either the factual or the counterfactual true distribution. This theorem appears to be a direct consequence of Cortes and Mohri, 2014.

Finally, they provide modest empirical results on two different simulated data sets (IHDP and News). In their experiments, they test three different flavors of the proposed model: fully linear, neural-net representation + linear outcome prediction, and fully neural net. On both data sets, the neural net is the clear winner on most metrics, while the linear outcome models (BLR and BNN-4-0) are -- at best -- competitive with strong baselines, including LASSO and a standard neural net.

Clarity - Justification: The paper is surprisingly readable given its density, but it is not perfect. On the more minor side, the descriptions of the baselines are sometimes hard to understand. For example, the language used to describe the BNN-4-0 and BNN-2-2 layers makes it unclear which layers do not include nonlinearities. I assume that BNN-4-0 applies four ReLU representation-only layers, followed by a single linear output layer, and that BNN-2-2 has two ReLU representation-only layers, then two ReLU layers after the treatment is added. If that is accurate, then some variant of the above wording might provide a clearer explanation.

The paper seems to be lacking in detail related to practical implementation, in particular with respect to optimizing the objective in Equation (2), lines 287-291. How is the third term (the counterfactual prediction error) minimized? Is it necessary to explicitly choose a nearest neighbor j(i)? If so, how is this search integrated into the optimization in an efficient manner? The authors might consider putting clear pseudo-code for the full algorithm in the appendix; a rough sketch of the sort of thing I have in mind is given below.
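To make that request concrete, here is a minimal sketch (Python/numpy) of how I read the three-term objective from the summary above. The explicit nearest-neighbour search, the squared losses, and the mean-difference stand-in for the discrepancy distance are my own assumptions for illustration, not the authors' implementation.

    import numpy as np

    def nearest_counterfactual(R, t):
        # For each unit i, find the index j(i) of the nearest unit in the
        # learned representation space R that received the opposite treatment.
        # Hypothetical helper, for illustration only.
        n = len(t)
        j = np.empty(n, dtype=int)
        for i in range(n):
            opposite = np.flatnonzero(t != t[i])
            dists = np.linalg.norm(R[opposite] - R[i], axis=1)
            j[i] = opposite[np.argmin(dists)]
        return j

    def objective(Phi, h, x, t, y, alpha=1.0, gamma=1.0):
        # Phi: callable mapping covariates (n x d) to representations (n x k).
        # h:   callable mapping (representations, treatment vector) to predicted outcomes.
        R = Phi(x)
        j = nearest_counterfactual(R, t)
        factual_err = np.mean((h(R, t) - y) ** 2)                  # term (1)
        disc = np.linalg.norm(R[t == 1].mean(axis=0)               # term (2): crude stand-in
                              - R[t == 0].mean(axis=0))            # for the discrepancy distance
        counterfactual_err = np.mean((h(R, 1 - t) - y[j]) ** 2)    # term (3)
        return factual_err + alpha * disc + gamma * counterfactual_err

With a fixed linear Phi (e.g., feature re-weighting) and a linear h this would correspond to the variable-selection flavor; with neural networks for Phi and h it would correspond to the BNN flavor, modulo the true hypothesis-class-dependent discrepancy term.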
Theorem 1 is very difficult to understand and contextualize. What are the implications of an upper bound on the squared difference between the losses of separate estimated predictors of factual and counterfactual outcomes, over samples drawn from EITHER the factual or the counterfactual distribution?

The empirical results merit further discussion to help the reader fully appreciate them. In particular, some further discussion of the ranges of the ITE, ATE, and HETE metrics would help, along with what sort of performance would indicate that a model could be used to make reliable causal inferences about the real world.

Beyond that, the major concern is that the breadth of advanced machine learning topics this paper covers (transfer learning and discrepancy distance, which is a theoretically challenging line of research; bandit problems and doubly robust estimators; potential-outcomes causal inference; representation learning; etc.) makes it quite inaccessible to folks unfamiliar with one or more of those topics. This need not necessarily be addressed in this manuscript (as it is targeted at a specific machine learning community), but the authors should find other ways to communicate the core ideas and intuitions of this research to a broader audience.

Significance - Justification: This seems like a very innovative approach to causal inference from observational data -- a topic of growing importance and interest in a variety of domains, as the authors rightly point out. I think it will be of great interest to the ICML community and many others, especially since it touches upon so many areas of active inquiry in machine learning research. I think other researchers will be inspired to work on this and related problems and to apply and build upon this work. However, my enthusiasm is tempered by the limited empirical investigation and the very modest results.

Detailed comments. (Explain the basis for your ratings while providing constructive feedback.): A very interesting paper. I would consider raising my scores and possibly giving the paper a strong accept and endorsement, but I would first like to hear author responses to several critiques:

* In future work, the authors mention "deriving better optimization" algorithms. Just how difficult, in practice, is the optimization problem here? How is the nearest neighbor j(i) chosen? How big a danger are local optima, given the use of an "alternating stochastic sub-gradient descent" -- especially when using a neural net? I would recommend the authors include a detailed description of the algorithm in an extended appendix (and of course, make code available, if accepted).

* I will venture a guess at a take-home message for Theorem 1. Please correct me if I'm wrong and otherwise expand upon it (and I would suggest adding this sort of discussion to any future revisions). I believe that Theorem 1 loosely guarantees that we will learn reasonably consistent predictors for factual and counterfactual outcomes. Is that correct, and if so, does that seem like enough? It seems to say little about the "correctness" of the estimators...

* Suppose a clinical researcher wanted to apply this to real observational medical data, draw some conclusions, and maybe list potential limitations. Is there a way to apply Theorem 1 to a real problem and produce some kind of quantity that measures how trustworthy the results are?

* How might this algorithm be expected to perform in pathological cases, i.e., where the respective contexts of the treatment and control groups are completely different?

* The results for the linear outcome models suggest two conclusions, one not surprising, one a little disappointing. Would the authors agree with them?
(1) Not surprising: the feature selection + re-weighting scheme is simply too weak a representation learning scheme to be useful for many, if not most, problems.
(2) Disappointing: much of the benefit conveyed by introducing nonlinearities appears to come from capturing nonlinear interactions between treatment and context (and not from learning new representations of context)?

* Could the poor performance of BNN-4-0 be due to difficulty in optimization?

* Can these results be extended to other loss functions and hypothesis classes (e.g., logistic regression)? (A short illustrative sketch of the kind of substitution I have in mind follows the review.)

=====
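Regarding the last question, and purely as an illustration of the kind of substitution I have in mind (not something taken from the paper): for binary outcomes, the factual term of the objective could be swapped for a logistic loss while leaving the other two terms unchanged, along the lines of the sketch below (reusing the hypothetical Phi and h from the earlier sketch).

    import numpy as np

    def logistic_factual_loss(Phi, h, x, t, y):
        # Binary-outcome variant of the factual term: h now returns a logit
        # and y is in {0, 1}. Illustrative only.
        logits = h(Phi(x), t)
        p = 1.0 / (1.0 + np.exp(-logits))    # predicted P(Y = 1 | Phi(x), t)
        eps = 1e-12                          # numerical safety
        return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))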