CausalAILab/ProjectedCausalAbstractions

Causal Abstraction Inference under Lossy Representations

This repository contains the code for the paper "Causal Abstraction Inference under Lossy Representations" by Kevin Xia and Elias Bareinboim, published at ICML 2025.

Please cite our work if you find this code useful.

Setup

Run the following command to install the Python requirements.

python -m pip install -r requirements.txt

To run the ColoredMNIST experiments, place the MNIST data files in dat/mnist.

Running the code

All experiments are run through main.py or sampler.py (invoked as the modules src.main and src.sampler) by passing the desired command-line arguments. The hyperparameters in any of the commands below can be changed; see main.py and sampler.py for the available arguments.

Experiment 1 (Estimating Causal Effects)

The following commands run the G-constrained NCM, the CDAG-constrained NCM, and the Projected CDAG-constrained NCM, respectively, for the estimation experiment.

python -m src.main MNIST_Est_Baseline estimate mnist gan --max-epochs 400 --h-size 32 --u-size 4 --scale-u-size --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 32 -t 10 -n -1 --gpu 0
python -m src.main MNIST_Est_CDAG estimate mnist gan --max-epochs 400 --h-size 32 --u-size 4 --scale-u-size --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 32 -t 10 -n -1 --use-tau --gpu 0
python -m src.main MNIST_Est_ProjCDAG estimate mnist gan --max-epochs 400 --h-size 32 --u-size 4 --scale-u-size --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 32 -t 10 -n -1 --use-tau --use-projected-cdag --gpu 0
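As a sketch of how the three estimation runs relate, the hypothetical Python helper below (not part of the repository) rebuilds the commands above from one shared set of flags: the CDAG-constrained NCM adds --use-tau, and the projected variant additionally adds --use-projected-cdag. The names SHARED, VARIANTS, and exp1_command are illustrative assumptions; editing SHARED is one way to change a hyperparameter consistently across all three runs.

```python
import shlex

# Flags shared by all three estimation runs, copied from the commands above.
SHARED = (
    "--max-epochs 400 --h-size 32 --u-size 4 --scale-u-size --batch-norm "
    "--gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 32 "
    "-t 10 -n -1"
)

# Extra flags that distinguish the three models (per the README commands).
VARIANTS = {
    "MNIST_Est_Baseline": [],                                     # G-constrained NCM
    "MNIST_Est_CDAG": ["--use-tau"],                              # CDAG-constrained NCM
    "MNIST_Est_ProjCDAG": ["--use-tau", "--use-projected-cdag"],  # Projected CDAG-constrained NCM
}


def exp1_command(name, gpu=0):
    """Hypothetical helper: return the argv list for one Experiment 1 run."""
    return (
        ["python", "-m", "src.main", name, "estimate", "mnist", "gan"]
        + shlex.split(SHARED)
        + VARIANTS[name]
        + ["--gpu", str(gpu)]
    )


if __name__ == "__main__":
    for name in VARIANTS:
        print(shlex.join(exp1_command(name)))
        # To launch a run instead of printing it:
        # subprocess.run(exp1_command(name), check=True)
```

Passing the argument list directly to subprocess.run avoids shell-quoting issues when a flag value is changed.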

Experiment 2 (Generating Colored MNIST Digits)

Use the following command to run the representation NCM (RNCM).

python -m src.main MNIST_BD_RNCM_Sampler sampling mnistbd gan --max-epochs 800 --h-size 32 --u-size 4 --scale-u-size --batch-norm --repr auto_enc_conditional --rep-size 16 --rep-image-only --gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 64 -n 20000 --gpu 0

For the other models, first train a base model for sampling with the following command. This model is an NCM trained at the representation level.

python -m src.main MNIST_BD_Sampler_Base sampling mnistbd gan --max-epochs 400 --h-size 32 --u-size 4 --scale-u-size --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --img-size 64 -n 20000 --use-tau --use-projected-cdag --gpu 0

The following commands train the basic and projected samplers, respectively, which generate image samples given the representation. The basic sampler trains only a decoder, while the projected sampler trains a GAN that outputs the conditional distribution of images given the representation.

python -m src.sampler MNIST_BD_Basic_Sampler basic mnistbd gan <PATH> --max-epochs 800 --h-layers 4 --h-size 128 --u-size 64 --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --disc-h-size 512 --img-size 64 -n 20000 --gpu 0
python -m src.sampler MNIST_BD_Projected_Sampler projected mnistbd gan <PATH> --max-epochs 800 --h-layers 4 --h-size 128 --u-size 64 --batch-norm --gan-mode wgan --gan-arch biggan --disc-type biggan --disc-h-size 512 --img-size 64 -n 20000 --gpu 0

<PATH> is the path to the base sampling model trained above. A base model must be paired with a sampler to generate samples of the original data (images).
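To illustrate why the choice of sampler matters when the representation is lossy, here is a toy, stdlib-only Python sketch. It is not the repository's implementation: the "images" are (digit, color) pairs, the representation keeps only the digit, the basic sampler's decoder is stood in for by a deterministic lookup table, and the projected sampler draws from the empirical conditional distribution of images given the representation instead of a trained GAN. All function names are hypothetical.

```python
import random
from collections import Counter, defaultdict


def represent(x):
    """Lossy toy representation: keep the digit, drop the color."""
    digit, _color = x
    return digit


def fit_basic_decoder(data):
    """Deterministic stand-in for the basic sampler's decoder: each
    representation maps to one fixed image (the most frequent one).
    A decoder trained with a reconstruction loss would output an average
    instead, but it would still be a deterministic function of r."""
    by_rep = defaultdict(Counter)
    for x in data:
        by_rep[represent(x)][x] += 1
    return {r: counts.most_common(1)[0][0] for r, counts in by_rep.items()}


def fit_projected_sampler(data):
    """Stand-in for the projected sampler: draw images from the empirical
    conditional distribution p(x | r) rather than training a GAN."""
    by_rep = defaultdict(list)
    for x in data:
        by_rep[represent(x)].append(x)
    return lambda r, rng: rng.choice(by_rep[r])


# Toy training data: digit 7 is red 70% of the time and green 30% of the time.
data = [(7, "red")] * 70 + [(7, "green")] * 30

decoder = fit_basic_decoder(data)
sampler = fit_projected_sampler(data)
rng = random.Random(0)

# Stand-in for representations sampled from the base NCM.
reps = [7] * 1000

basic_out = Counter(decoder[r] for r in reps)
projected_out = Counter(sampler(r, rng) for r in reps)
print(basic_out)      # all 1000 samples are (7, 'red'); green never appears
print(projected_out)  # about 70% (7, 'red') and 30% (7, 'green')
```

Because the decoder is a deterministic function of the representation, it cannot reproduce the variation that the representation discards, whereas sampling from the conditional distribution given the representation can.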

Scaling Representation Experiment

For the scaling representation experiment, a new base NCM and sampler must be trained for each representation dimensionality.

Use the following command to train the base NCM for sampling.

python -m src.main MNIST_Sampler_Base_Scale_<d> sampling mnistbd gan --max-epochs 2000 --h-size 4 --u-size 4 --scale-u-size --scale-h-size --batch-norm --gan-mode wgan --gan-arch dcgan --disc-type standard --img-size 32 --data-repr-size <d> -n 20000 --use-tau --use-projected-cdag --gpu 0

Use the following commands to train the basic and projected samplers, respectively.

python -m src.sampler MNIST_Basic_Sampler_Scale_<d> basic mnistbd gan <PATH> --max-epochs 800 --h-size 256 --u-size 128 --batch-norm --gan-mode wgan --gan-arch dcgan --disc-type standard --disc-h-size 512 --img-size 32 --data-repr-size <d> -n 20000 --gpu 0
python -m src.sampler MNIST_Projected_Sampler_Scale_<d> projected mnistbd gan <PATH> --max-epochs 800 --h-size 256 --u-size 128 --batch-norm --gan-mode wgan --gan-arch dcgan --disc-type standard --disc-h-size 512 --img-size 32 --data-repr-size <d> -n 20000 --gpu 0

Replace all instances of <d> with the dimensionality of the representation and replace <PATH> with the path to the sampling base NCM.
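Since the base NCM and both samplers must be retrained for every representation size, a small script can fill in <d> and <PATH> for each run. The sketch below is hypothetical (not part of the repository): the templates are copied from the commands above, while the list of sizes and the base-model paths in the usage block are placeholders to replace with your own values. Each base NCM must finish training before its samplers are launched.

```python
import shlex

# Templates copied from the scaling-experiment commands; {d} stands for <d>
# and {path} for <PATH>.
BASE_TMPL = (
    "python -m src.main MNIST_Sampler_Base_Scale_{d} sampling mnistbd gan "
    "--max-epochs 2000 --h-size 4 --u-size 4 --scale-u-size --scale-h-size "
    "--batch-norm --gan-mode wgan --gan-arch dcgan --disc-type standard "
    "--img-size 32 --data-repr-size {d} -n 20000 --use-tau --use-projected-cdag --gpu 0"
)
SAMPLER_TMPL = (
    "python -m src.sampler MNIST_{kind_cap}_Sampler_Scale_{d} {kind} mnistbd gan {path} "
    "--max-epochs 800 --h-size 256 --u-size 128 --batch-norm --gan-mode wgan "
    "--gan-arch dcgan --disc-type standard --disc-h-size 512 --img-size 32 "
    "--data-repr-size {d} -n 20000 --gpu 0"
)


def scaling_commands(d, base_path):
    """Hypothetical helper: the base NCM command followed by the basic and
    projected sampler commands for representation size d. base_path must
    point to the base NCM produced by the first command."""
    cmds = [BASE_TMPL.format(d=d)]
    for kind in ("basic", "projected"):
        cmds.append(SAMPLER_TMPL.format(
            kind=kind, kind_cap=kind.capitalize(), d=d,
            path=shlex.quote(base_path)))
    return cmds


if __name__ == "__main__":
    for d in (2, 4, 8):  # placeholder list of representation sizes
        # Placeholder path: substitute wherever the base NCM for this d was saved.
        for cmd in scaling_commands(d, f"/path/to/base_ncm_d{d}"):
            print(cmd)
```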

Visualizing Experiments

Experiments can be visualized by running the files in src/experiment.

About

Implementation of the "Causal Abstraction Inference under Lossy Representations" paper by the authors.
