This repo provides a Python implementation of the context module and experimental results from the paper Intervening to learn and compose causally disentangled representations (arXiv:2507.04754 [stat.ML]).
The context module implementation is provided in `src/conceptualizer/conceptualizer.py`. The function `dec_conceptualizer()` constructs it from PyTorch `nn.Module`s, and the `VAE` class in the same file shows how to integrate it into a simple VAE; this is the "lightweight-VAE" described in the paper.
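As a rough, framework-free illustration of the wrapping pattern, the sketch below composes a decoder with a pre-processing module. The names `ContextModule` and `Decoder`, and the masking behavior, are purely hypothetical stand-ins; the actual context module and its API are defined in `src/conceptualizer/conceptualizer.py`.

```python
# Hypothetical sketch only: illustrates composing a module with a decoder,
# NOT the repo's actual context-module mechanism.

class Decoder:
    """Stand-in for a decoder nn.Module: maps a latent vector to an output."""
    def __call__(self, z):
        return [2.0 * v for v in z]

class ContextModule:
    """Stand-in module that gates latent coordinates with a fixed mask."""
    def __init__(self, mask):
        self.mask = mask
    def __call__(self, z):
        return [v * m for v, m in zip(z, self.mask)]

def compose(context, decoder):
    """Apply the decoder after the context module, as in a VAE decoder path."""
    return lambda z: decoder(context(z))

dec = compose(ContextModule([1.0, 0.0]), Decoder())
print(dec([3.0, 5.0]))  # -> [6.0, 0.0]: the second coordinate is masked out
```

In the repo, the analogous composition is built from real PyTorch modules by `dec_conceptualizer()`.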
Experiments are organized into a Snakemake workflow, with `src/expt/workflow/Snakefile` as the entry point.
Dependencies, versioning, and installation can all be handled by `uv`; the included `pyproject.toml` and `uv.lock` contain all necessary information.
A pinned `requirements.txt` with hashes is provided for users of other package managers.
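For reference, a typical setup might look like the following (a sketch assuming a POSIX shell; the installer URL is `uv`'s official standalone installer):

```shell
# Install uv, then create the project environment from pyproject.toml / uv.lock.
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync

# Alternatively, with pip, install from the pinned, hashed requirements:
pip install --require-hashes -r requirements.txt
```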
After installing uv:
the command

```
uv run snakemake all --forceall --cores 8
```

can in principle be run from `src/expt/` to run all experiments and reproduce all tables in the paper. However, this takes a few hundred GPU hours, so in practice it is best to use the Snakemake executor plugin to run it on a Slurm-managed GPU cluster.

Adding the dry-run flag `-n` instructs Snakemake to list all jobs without actually running them, showing jobs for the 350 models to be trained (train_model + train_disent + fine_tune), along with jobs for evaluating them, preparing data, and formatting results:

```
Job stats:
job                                     count
------------------------------------  -------
all                                         1
collect_3dident_disent_results              1
collect_beta_results                        1
collect_expressivity_2_results              1
collect_expressivity_4_results              1
collect_groupnorm_results                   1
collect_l2norm_results                      1
collect_mnist_beta_results                  1
collect_mnist_expressivity_2_results        1
collect_mnist_expressivity_4_results        1
collect_mnist_groupnorm_results             1
collect_mnist_l2norm_results                1
collect_mnist_mnist_ablation_results        1
collect_quad_ablation_results               1
collect_quad_causal_ablation_results        1
eval_disent                                60
eval_model                                290
fine_tune                                  30
format_result                             290
format_summary_table                        1
format_table                               14
prepare_data                                4
train_disent                               60
train_model                               260
total                                    1024
```
The command

```
uv run snakemake test --cores 8
```

can be run from `src/expt/` to test that everything works. It will download the `quad` dataset and train one model; this requires a GPU and takes around 90 minutes on an H100. The trained model and various evaluation metrics and plots can then be found at `src/expt/results/dataset=quad/arch=simple64bn/method=2/expressivity=2.32.8/beta=1.0/sparsity=0.0.0.0/seed=0/`; in particular, a plot of original vs. reconstructed images is there (`reconstructions.png`), and generated samples are in the `random_samples` subdirectory, for example `random_samples/quad2-quad3.png`.
If you find the code helpful, please cite it using the following BibTeX:
```bibtex
@InProceedings{markham2026,
  title     = {Intervening to learn and compose causally disentangled representations},
  author    = {Markham, Alex and Hirsch, Isaac and Chang, Jeri A. and Solus, Liam and Aragam, Bryon},
  booktitle = {Proceedings of the Fifth Conference on Causal Learning and Reasoning},
  year      = {2026},
  editor    = {Bijan Mazaheri and Niels Richard Hansen},
  series    = {Proceedings of Machine Learning Research},
  month     = {Apr},
  publisher = {PMLR},
}
```
Feel free to raise an issue or email me with questions about reproducing the experimental results, modifying the context module, or integrating it into your deep generative model!