# Latent Diffusion Counterfactual Explanations

This is the official code of the paper [Latent Diffusion Counterfactual Explanations](https://arxiv.org/abs/2310.06668).

If this work is useful to you, please consider citing our paper:

@misc{farid2023latent,
    title={Latent Diffusion Counterfactual Explanations}, 
    author={Karim Farid and Simon Schrodi and Max Argus and Thomas Brox},
    year={2023},
    eprint={2310.06668},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

## Requirements

A suitable conda environment named ldm can be created and activated with:

conda env create -f environment.yaml
conda activate ldm

Download the following model weights:

## Counterfactual generation with LDCE

Before generating counterfactuals, you need to adjust the config files in `configs/ldce/*.yaml`, e.g., set the dataset paths; a command-line alternative is sketched below.
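The commands in this README already pass Hydra-style overrides (e.g., `data.batch_size=5`), so individual values can presumably also be set on the command line instead of editing the YAML. A minimal sketch, assuming a dataset-path key such as `data.data_dir` exists (the exact key name may differ in the actual configs; `output_dir` is taken from the Flowers command further below):

```bash
# Hypothetical override of dataset and output paths via Hydra-style arguments;
# data.data_dir is an illustrative key name and may differ from the real config.
python -m scripts.ldce --config-name=v1_wider \
    data.data_dir=/path/to/imagenet \
    output_dir=results/imagenet_example
```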

Below we provide the commands to reproduce the results from our paper.

### ImageNet

#### All classes (Table 1)

For the class-conditional diffusion model:

python -m scripts.ldce --config-name=v1_wider \
    data.batch_size=5 \
    strength=0.382 \
    data.start_sample=$id data.end_sample=$((id+1)) > logs/imagenet_sd_${id}.log 
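In the commands above and below, `$id` is a shell variable selecting which slice (or shard) of the dataset a job processes. A minimal usage sketch, assuming sample indices 0-9 (the range and log naming are illustrative only):

```bash
# Illustrative driver loop: each iteration processes one sample index $id.
# The number of iterations (here 10) is arbitrary and only for demonstration.
mkdir -p logs
for id in $(seq 0 9); do
    python -m scripts.ldce --config-name=v1_wider \
        data.batch_size=5 \
        strength=0.382 \
        data.start_sample=$id data.end_sample=$((id+1)) > logs/imagenet_sd_${id}.log
done
```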

For the text-conditional diffusion model:

python -m scripts.ldce --config-name=v1_stable_diffusion \
    data.batch_size=4 \
    sampler.classifier_lambda=3.95 \
    sampler.dist_lambda=1.2 \
    sampler.deg_cone_projection=50. \
    data.start_sample=$id data.end_sample=$((id+1)) > logs/imagenet_sd_${id}.log 

#### Only pairs (Table 2; shown here for the zebra-sorrel pair)

For the class-conditional diffusion model:

python -m scripts.ldce --config-name=v1_zs \
    data.batch_size=4 > logs/zs_cls.log 

For the text-conditional diffusion model:

python -m scripts.ldce --config-name=v1_zs \
    data.batch_size=4 \
    strength=0.382 \
    sampler.classifier_lambda=3.95 \
    sampler.dist_lambda=1.2 \
    sampler.deg_cone_projection=50. \
    diffusion_model.cfg_path="configs/stable-diffusion/v1-inference.yaml" \
    diffusion_model.ckpt_path="/path/to/miniSD.ckpt" > logs/zs_sd.log 

### CelebA HQ (Table 6)

python -m scripts.ldce --config-name=v1_celebAHQ \
    data.batch_size=4 \
    sampler.classifier_lambda=4.0 \
    sampler.dist_lambda=3.3 \
    data.num_shards=7 \
    sampler.deg_cone_projection=55. \
    data.shard=$id \
    strength=$strength > logs/celeb_smile.log 
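Here the data is split into `data.num_shards=7` shards, with `$id` selecting the shard and `$strength` setting the `strength` hyperparameter (given explicitly, e.g., as 0.382, in the ImageNet commands above). A minimal sketch running all shards, with an illustrative strength value and per-shard log names:

```bash
# Illustrative sharded run over all 7 CelebA HQ shards; the strength value and
# the per-shard log file names are placeholders, not settings from the paper.
mkdir -p logs
strength=0.382   # placeholder value
for id in $(seq 0 6); do
    python -m scripts.ldce --config-name=v1_celebAHQ \
        data.batch_size=4 \
        sampler.classifier_lambda=4.0 \
        sampler.dist_lambda=3.3 \
        data.num_shards=7 \
        sampler.deg_cone_projection=55. \
        data.shard=$id \
        strength=$strength > logs/celeb_smile_${id}.log
done
```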

### Flowers 102

python -m scripts.ldce --config-name=v1_flowers \
    data.batch_size=4 \
    strength=0.5 \
    sampler.classifier_lambda=3.4 \
    sampler.dist_lambda=1.2 \
    output_dir=results/flowers \
    data.num_shards=7 \
    data.shard=${id} \
     > logs/flowers_${id}.log 

### Oxford-IIIT Pets

python -m scripts.ldce --config-name=v1_pets \
    data.batch_size=4 \
    sampler.classifier_lambda=4.2 \
    sampler.dist_lambda=2.4 \
    data.num_shards=7 \
    data.shard=$id \
     > logs/pets_${id}.log 

## Acknowledgements

We thank the following GitHub users/researchers/groups: