Causal Representation Learning

Jan 15 · 10min

Introduction & Motivation

What is causal representation learning? Broadly, it covers methods that aim not merely to extract low-dimensional statistical features, but to recover the causal factors of variation underlying observed data. To move beyond distributional robustness, we seek models that can reason under interventions and answer counterfactual queries. This requires mathematical tools from causal inference and practical algorithms for discovering disentangled, causally meaningful representations.

Generative Model & SCM Formalization

Let $x \in \mathbb{R}^d$ be an observed datapoint. Assume that $x$ is generated from underlying causal variables $z = (z_1, \ldots, z_n)$ through a mixing function $g$, so that $x = g(z)$. Under a Structural Causal Model (SCM):

$$z_i := f_i(\mathrm{pa}(z_i), \epsilon_i), \qquad i = 1, \ldots, n,$$

where $\mathrm{pa}(z_i)$ denotes the causal parents of $z_i$ and the $\epsilon_i$ are jointly independent noise variables.

Goal: Find an encoder $E$ such that $\hat{z} = E(x)$ is (approximately) causally correct and disentangled. We would like $\hat{z} \approx z$ up to a permutation and element-wise invertible reparameterization of the coordinates.
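
One common way to check this numerically when ground-truth latents are available (as in the identifiability literature) is the mean correlation coefficient between true and recovered latents after an optimal one-to-one matching of dimensions. A minimal sketch, where `z_true` and `z_hat` are hypothetical (n, d) arrays:

# Sketch: mean correlation coefficient (MCC) between true and recovered latents
import numpy as np
from scipy.optimize import linear_sum_assignment

def mcc(z_true, z_hat):
    d = z_true.shape[1]
    # Absolute correlation between every (true, recovered) dimension pair
    corr = np.abs(np.corrcoef(z_true.T, z_hat.T)[:d, d:])
    # Best one-to-one matching of dimensions (Hungarian algorithm)
    rows, cols = linear_sum_assignment(-corr)
    return corr[rows, cols].mean()

A score near 1 means every true factor is recovered by some latent dimension, up to sign and scale.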

Common Approaches & Algorithms

  • FactorVAE, β-VAE: Encourage disentanglement by penalizing the total correlation of the latent code $z$.
  • CausalVAE: Explicit structural priors enforce relations between the latents $z_i$ that reflect a causal graph.
  • Invariant Risk Minimization (IRM): Learn features $\Phi(x)$ such that the classifier on top of $\Phi$ yields an invariant predictor across environments (a penalty sketch follows the FactorVAE example below).
# Example: FactorVAE loss (PyTorch; simplified sketch)
import torch
import torch.nn.functional as F

def factor_vae_loss(x, encoder, decoder, discriminator, gamma=10.0):
    z_mu, z_logvar = encoder(x)
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    z = z_mu + torch.exp(0.5 * z_logvar) * torch.randn_like(z_mu)
    recon_loss = F.mse_loss(decoder(z), x)
    # Gaussian KL term of the standard ELBO
    kl = -0.5 * torch.mean(1 + z_logvar - z_mu**2 - z_logvar.exp())
    # Density-ratio estimate of total correlation: the discriminator outputs
    # two logits, for "true joint" vs. "permuted (factorized)" samples of z
    logits = discriminator(z)
    tc = (logits[:, 0] - logits[:, 1]).mean()
    return recon_loss + kl + gamma * tc
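
For IRM specifically, the commonly used IRMv1 relaxation (Arjovsky et al., 2019) penalizes the gradient of each environment's risk with respect to a fixed scalar "dummy" classifier. A minimal sketch, assuming binary classification; `logits` and `y` are per-environment placeholders:

# Sketch: IRMv1 penalty for one environment (binary classification assumed)
import torch
import torch.nn.functional as F

def irm_penalty(logits, y):
    # Dummy multiplicative classifier fixed at 1.0; a non-zero gradient of
    # the risk w.r.t. this scale signals that the features are not invariant
    w = torch.tensor(1.0, requires_grad=True)
    loss = F.binary_cross_entropy_with_logits(logits * w, y.float())
    grad = torch.autograd.grad(loss, [w], create_graph=True)[0]
    return grad.pow(2)

# Training objective: sum over environments of (risk_e + lambda * penalty_e)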

Key Concepts & Equations

  • Total Correlation: $\mathrm{TC}(z) = D_{\mathrm{KL}}\big(q(z) \,\|\, \prod_j q(z_j)\big)$, which is zero exactly when the latent dimensions are independent.
  • Identifiability: Under certain conditions (e.g., multiple environments or known interventions), the causal features are provably recoverable.
  • Structural Hamming Distance (SHD): For causal graph evaluation, counts the edge additions, deletions, and reversals needed to turn one graph into the other (see the sketch after this list).
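
As a concrete illustration of SHD, here is a minimal numpy sketch over binary adjacency matrices; `a_true` and `a_pred` are hypothetical inputs with A[i, j] = 1 iff there is an edge i -> j, and a reversed edge is counted as one mistake (conventions vary):

# Sketch: Structural Hamming Distance between two DAG adjacency matrices
import numpy as np

def shd(a_true, a_pred):
    # One unit of distance per node pair whose edge status differs
    # (status: no edge, i -> j, or j -> i)
    n = a_true.shape[0]
    dist = 0
    for i in range(n):
        for j in range(i + 1, n):
            s_true = a_true[i, j] + 2 * a_true[j, i]
            s_pred = a_pred[i, j] + 2 * a_pred[j, i]
            dist += int(s_true != s_pred)
    return dist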

Mathematical Example:

# SCM causal simulation
import numpy as np

def scm_sim():
    x1 = np.random.normal()              # exogenous root cause
    x2 = 2 * x1 + np.random.normal()     # x1 -> x2
    x3 = x2 - x1 + np.random.normal()    # x1 -> x3 and x2 -> x3
    return np.stack([x1, x2, x3])

This code simulates a simple, linear SCM. Changing ("intervening on") x1 propagates through the system to its descendants x2 and x3.
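
As a sketch, a hard intervention do(x1 = c) replaces the assignment for x1 while leaving the downstream mechanisms untouched:

# Sketch: hard intervention do(x1 = c) on the SCM above
def scm_do_x1(c):
    x1 = c                               # x1 is set by the intervention, not sampled
    x2 = 2 * x1 + np.random.normal()     # downstream mechanisms are unchanged
    x3 = x2 - x1 + np.random.normal()
    return np.stack([x1, x2, x3])

Comparing samples from scm_sim() and scm_do_x1(c) shows the interventional distribution of x2 and x3 shifting while the mechanisms themselves stay fixed.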

Evaluation Metrics

  • Modularity: Measures whether each latent dimension is sensitive to changes in only a single underlying factor or intervention.
  • Disentanglement: Metrics such as DCI and the Mutual Information Gap (MIG), sketched below.
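
A minimal MIG sketch, assuming discrete ground-truth factors and binned latent codes; `factors` of shape (n, k) and `latents` of shape (n, d) are hypothetical inputs:

# Sketch: Mutual Information Gap (MIG) with binned latents
import numpy as np
from sklearn.metrics import mutual_info_score

def mig(factors, latents, n_bins=20):
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        # Mutual information between this factor and each (binned) latent
        mis = []
        for j in range(latents.shape[1]):
            edges = np.histogram_bin_edges(latents[:, j], bins=n_bins)
            mis.append(mutual_info_score(v, np.digitize(latents[:, j], edges)))
        mis = np.sort(mis)[::-1]
        entropy = mutual_info_score(v, v)  # H(v) = I(v; v)
        gaps.append((mis[0] - mis[1]) / entropy)
    return float(np.mean(gaps))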

Open Challenges

  • Proving identifiability with finite data.
  • Efficient algorithms for high-dimensional, non-linear SCMs.
  • Bridging deep learning and graphical model causal theory.
