Context module

This repo provides a Python implementation of the context module and experimental results from the paper Intervening to learn and compose causally disentangled representations (arXiv:2507.04754 [stat.ML]).

The context module implementation is provided in src/conceptualizer/conceptualizer.py. The function dec_conceptualizer() constructs it from PyTorch nn.Modules, and the VAE class in the same file shows how to integrate it into a simple VAE; this is the "lightweight-VAE" described in the paper.
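As a rough illustration of the integration pattern, the sketch below wraps a toy decoder with a hypothetical context module that gates the latent code with a learned per-concept context vector. Note that ContextModule and its gating scheme are placeholders invented for this sketch, not the repo's actual dec_conceptualizer() API:

```python
# Hypothetical sketch only -- not the repo's dec_conceptualizer() API.
import torch
import torch.nn as nn

class ContextModule(nn.Module):
    """Toy stand-in: one learned context vector per concept, selected by index."""
    def __init__(self, n_concepts: int, latent_dim: int):
        super().__init__()
        self.contexts = nn.Embedding(n_concepts, latent_dim)

    def forward(self, z: torch.Tensor, concept: torch.Tensor) -> torch.Tensor:
        # Modulate the latent code with the chosen concept's context vector.
        return z * torch.sigmoid(self.contexts(concept))

# A minimal decoder standing in for the VAE's decoder network.
decoder = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 784))
ctx = ContextModule(n_concepts=4, latent_dim=8)

z = torch.randn(16, 8)                # batch of latent codes
concept = torch.randint(0, 4, (16,))  # concept index per sample
x_hat = decoder(ctx(z, concept))      # decode the contextualized latent
print(x_hat.shape)                    # torch.Size([16, 784])
```

The key point is only the composition: the context module sits between the latent code and the decoder, so interventions can be expressed by swapping the concept index.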

Experiments are organized into a Snakemake workflow, with the src/expt/workflow/Snakefile entry point.

Dependencies, versioning, and installation are all handled by uv; the included pyproject.toml and uv.lock contain all the necessary information. A pinned requirements.txt with hashes is also provided for users of other package managers.
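For instance, assuming a Unix shell, setup might look like the following (the installer script is uv's documented one; with uv installed, any `uv run` command will also resolve the environment automatically, so the explicit sync is optional):

```shell
# Install uv (see https://docs.astral.sh/uv/ for other install methods)
curl -LsSf https://astral.sh/uv/install.sh | sh

# From the repository root: create the environment pinned by uv.lock
uv sync
```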

After installing uv:

  • uv run snakemake all --forceall --cores 8, run from src/expt/, can in principle reproduce all experiments and tables in the paper; however, it takes a few hundred GPU hours, so in practice it is best run on a Slurm-managed GPU cluster via the Snakemake executor plugin.
    • Adding the dry-run flag -n instructs Snakemake to list all jobs without actually running them, showing jobs for the 350 (train_model + train_disent + fine_tune) models to be trained, along with jobs for evaluating them, preparing data, and formatting results, producing the following:
      Job stats:
      job                                     count
      ------------------------------------  -------
      all                                         1
      collect_3dident_disent_results              1
      collect_beta_results                        1
      collect_expressivity_2_results              1
      collect_expressivity_4_results              1
      collect_groupnorm_results                   1
      collect_l2norm_results                      1
      collect_mnist_beta_results                  1
      collect_mnist_expressivity_2_results        1
      collect_mnist_expressivity_4_results        1
      collect_mnist_groupnorm_results             1
      collect_mnist_l2norm_results                1
      collect_mnist_mnist_ablation_results        1
      collect_quad_ablation_results               1
      collect_quad_causal_ablation_results        1
      eval_disent                                60
      eval_model                                290
      fine_tune                                  30
      format_result                             290
      format_summary_table                        1
      format_table                               14
      prepare_data                                4
      train_disent                               60
      train_model                               260
      total                                    1024
      
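A cluster run like the one above is typically configured through a Snakemake profile. The sketch below is a hypothetical profile for the snakemake-executor-plugin-slurm; the partition, account, and resource values are placeholders, not settings from this repo, and will differ per cluster:

```yaml
# profiles/slurm/config.yaml (all values hypothetical)
executor: slurm
jobs: 50                           # max concurrent Slurm jobs
default-resources:
  slurm_partition: gpu             # placeholder partition name
  slurm_account: myaccount         # placeholder account
  runtime: 720                     # minutes per job
  mem_mb: 32000
  slurm_extra: "'--gres=gpu:1'"    # cluster-specific GPU request
```

With such a profile in place, the experiments could then be launched as uv run snakemake all --forceall --workflow-profile profiles/slurm.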
  • uv run snakemake test --cores 8 can be run from src/expt/ to test that everything works.
    • It will download the quad dataset and train one model.
    • This requires a GPU and takes around 90 mins on an H100.
    • The trained model, evaluation metrics, and plots can then be found under src/expt/results/dataset=quad/arch=simple64bn/method=2/expressivity=2.32.8/beta=1.0/sparsity=0.0.0.0/seed=0/; in particular, that directory contains a plot of original vs. reconstructed images (reconstructions.png), and generated samples are in its random_samples subdirectory, for example random_samples/quad2-quad3.png.

Citing

If you find the code helpful, please cite it using the following BibTeX entry:

@InProceedings{markham2026,
  title     = {Intervening to learn and compose causally disentangled representations},
  author    = {Markham, Alex and Hirsch, Isaac and Chang, Jeri A. and Solus, Liam and Aragam, Bryon},
  booktitle = {Proceedings of the Fifth Conference on Causal Learning and Reasoning},
  year      = {2026},
  editor    = {Bijan Mazaheri and Niels Richard Hansen},
  series    = {Proceedings of Machine Learning Research},
  month     = {Apr},
  publisher = {PMLR},
}

Contact

Feel free to raise an issue or email me with questions about reproducing the experimental results, modifying the context module, or integrating it into your deep generative model!

About

Code Supplement to the paper "Intervening to Learn and Compose Causally Disentangled Representations" at CLeaR'26
