Code library, written in TensorFlow 2, for training causal inference deep learning models with automatic hyperparameter optimization.

causal-lab-miism/deep_causal_inference_ite_library

Review of Causal Inference Deep Learning Methods for ITE Estimation with Automatic Hyperparameter Optimization

TensorFlow implementation of the methods presented in:

Andrei Sirazitdinov, Marcus Buchwald, Jürgen Hesser, and Vincent Heuveline, "Review of Deep Learning Methods for Individual Treatment Effect Estimation with Automatic Hyperparameter Optimization", 2022. Submitted to IEEE Transactions on Neural Networks and Learning Systems.

Contact: andrei.sirazitdinov@medma.uni-heidelberg.de, marcus.buchwald@medma.uni-heidelberg.de

The library currently includes the following causal inference methods: TLearner, SLearner, RLearner, XLearner, TARNet, CFR-Wasserstein, CFR-Weighted, CFR-MMD, DragonNet, TEDVAE, CEVAE, GANITE, DKLITE.
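
For readers new to meta-learners, the two simplest entries in this list differ only in how the outcome model sees the treatment. The sketch below illustrates the idea with plain linear outcome models fit by NumPy least squares; it is an assumption-laden toy, not the library's TLearner or SLearner code, and the function names are hypothetical. The T-learner fits one model per treatment arm and subtracts their predictions; the S-learner fits a single model with the treatment as an extra input and contrasts its predictions at t=1 and t=0.

```python
import numpy as np


def fit_linear(X, y):
    """Least-squares linear regression with an intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ coef


def t_learner_ite(X, t, y):
    """T-learner: separate outcome models for control (t=0) and treated (t=1)
    units; the ITE estimate is mu1(x) - mu0(x) for every unit."""
    mu0 = fit_linear(X[t == 0], y[t == 0])
    mu1 = fit_linear(X[t == 1], y[t == 1])
    return mu1(X) - mu0(X)


def s_learner_ite(X, t, y):
    """S-learner: one model on [X, t]; contrast its predictions at t=1 and t=0."""
    model = fit_linear(np.column_stack([X, t]), y)
    ones, zeros = np.ones(len(X)), np.zeros(len(X))
    return model(np.column_stack([X, ones])) - model(np.column_stack([X, zeros]))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 3))
    t = rng.integers(0, 2, size=200)
    # Noise-free toy outcome whose true treatment effect is 2 for every unit.
    y = X @ np.array([1.0, -0.5, 2.0]) + 2.0 * t
    print(t_learner_ite(X, t, y)[:3], s_learner_ite(X, t, y)[:3])
```

With a constant effect and correctly specified linear models both learners agree; they diverge once the effect varies with x, which is where the neural-network variants in this library come in.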

Requirements:

  1. Python 3.9
  2. TensorFlow 2.11
  3. TensorFlow Probability 0.19.0
  4. keras-tuner 1.1.3
  5. protobuf<=3.20.x
  6. pandas
  7. scikit-learn
  8. NumPy
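
Because the pinned versions matter (for example, protobuf must stay at or below 3.20.x), it can help to check an environment before running anything. The following is a minimal sketch using only the standard library's `importlib.metadata`; the PyPI distribution names used as keys are an assumption about how the packages were installed, and the helper names are hypothetical.

```python
import sys
from importlib import metadata

# Pins taken from the requirements list above. The PyPI distribution names
# used as keys are an assumption about how the packages were installed.
PINNED = {
    "tensorflow": "2.11",
    "tensorflow-probability": "0.19.0",
    "keras-tuner": "1.1.3",
}


def matches_pin(installed, expected):
    """True if the installed version begins with the pinned components,
    e.g. "2.11.1" satisfies the pin "2.11" but "2.12.0" does not."""
    want = expected.split(".")
    return installed.split(".")[: len(want)] == want


def protobuf_ok(installed):
    """protobuf<=3.20.x: major.minor must be at most 3.20."""
    major, minor = (int(part) for part in installed.split(".")[:2])
    return (major, minor) <= (3, 20)


def check_versions(pinned=PINNED):
    """Map each pinned package to (expected, installed or None, ok)."""
    report = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        ok = installed is not None and matches_pin(installed, expected)
        report[name] = (expected, installed, ok)
    return report


if __name__ == "__main__":
    print("Python 3.9:", sys.version_info[:2] == (3, 9))
    for name, (expected, installed, ok) in check_versions().items():
        print(f"{name}: expected {expected}, found {installed}, ok={ok}")
    try:
        print("protobuf<=3.20.x:", protobuf_ok(metadata.version("protobuf")))
    except metadata.PackageNotFoundError:
        print("protobuf: not installed")
```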

To run the code, use:

python main.py [-h] [--model-name MODEL_NAME] [--ipm-type IPM_TYPE] [--dataset-name DATASET_NAME] [--num NUM]

Example:
python main.py --model-name TARnet --dataset-name ihdp_b --num 100 --ipm-type None
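
To queue several runs from a script, the command line can be assembled programmatically. The sketch below uses only the four flags documented above; the helper names are hypothetical, it must be run from the repository root where main.py lives, and the accepted values for each flag should be taken from main.py rather than guessed.

```python
import shlex
import subprocess
import sys


def build_command(model_name, dataset_name, num, ipm_type="None"):
    """Assemble the main.py invocation using only the documented flags."""
    return [
        sys.executable, "main.py",
        "--model-name", model_name,
        "--dataset-name", dataset_name,
        "--num", str(num),
        "--ipm-type", ipm_type,
    ]


def run(model_name, dataset_name, num, ipm_type="None", dry_run=True):
    """Print the command; launch it only when dry_run is False."""
    cmd = build_command(model_name, dataset_name, num, ipm_type)
    print(shlex.join(cmd))
    if not dry_run:
        # Requires the repository's main.py in the current working directory.
        subprocess.run(cmd, check=True)
    return cmd


if __name__ == "__main__":
    # Mirrors the example above; pass dry_run=False to actually start training.
    run("TARnet", "ihdp_b", 100)
```

Defaulting to a dry run keeps a typo in a model or dataset name from silently starting a long training and tuning job.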

Alternatively, use the Jupyter notebook.

See the full list of available models and datasets in main.py.

The file hyperparameters.py contains hyperparameters, such as batch size and learning rate, for the presented models.

Note that on the first execution of each method, our code performs a hyperparameter search to find the remaining hyperparameters.
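
The "search once, reuse afterwards" behaviour can be pictured as a cache keyed by method name. The sketch below is a hypothetical illustration, not the library's mechanism: it swaps in a toy random search from the standard library in place of keras-tuner, and the cache directory, file layout, and function names are all assumptions.

```python
import json
import random
from pathlib import Path


def random_search(objective, space, n_trials=10, seed=0):
    """Toy random search: sample n_trials configurations from `space`
    (a dict of name -> list of candidate values) and keep the one with
    the lowest objective value."""
    rng = random.Random(seed)
    best, best_score = None, float("inf")
    for _ in range(n_trials):
        config = {name: rng.choice(values) for name, values in space.items()}
        score = objective(config)
        if score < best_score:
            best, best_score = config, score
    return best


def get_hyperparameters(method, objective, space, cache_dir="tuning_cache"):
    """Search only when no stored result exists for `method`; later calls
    load the stored configuration and skip tuning entirely."""
    path = Path(cache_dir) / f"{method}.json"
    if path.exists():
        return json.loads(path.read_text())
    best = random_search(objective, space)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(best))
    return best
```

In practice the objective would train a model and return a validation loss, which is why caching the result matters: the expensive search runs once per method, and deleting the cached file forces a fresh search.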

We output the PEHE or policy risk for each sub-dataset and, after training on all sub-datasets, the average PEHE or policy risk with a 95% confidence interval.
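
The reported metrics can be sketched as follows. PEHE here is the root-mean-squared error between true and predicted individual effects; the confidence interval assumes a normal approximation (mean ± 1.96 · sample std / √n over sub-datasets), and the policy risk follows one common definition for the policy "treat when the predicted effect is positive", estimated on randomized data. The library's exact formulas, threshold, and subset selection may differ; the function names are hypothetical.

```python
import numpy as np


def pehe(tau_true, tau_pred):
    """sqrt(mean((tau_pred - tau_true)^2)) over the units of one sub-dataset."""
    tau_true = np.asarray(tau_true, float)
    tau_pred = np.asarray(tau_pred, float)
    return float(np.sqrt(np.mean((tau_pred - tau_true) ** 2)))


def mean_with_ci(values, z=1.96):
    """Mean over sub-datasets with a normal-approximation 95% CI:
    mean +/- z * sample_std / sqrt(n)."""
    values = np.asarray(values, float)
    mean = values.mean()
    half = z * values.std(ddof=1) / np.sqrt(len(values))
    return float(mean), (float(mean - half), float(mean + half))


def policy_risk(tau_pred, t, y):
    """1 - (E[Y | T=1, pi=1] P(pi=1) + E[Y | T=0, pi=0] P(pi=0)),
    where pi treats a unit iff its predicted effect is positive."""
    tau_pred, t, y = np.asarray(tau_pred), np.asarray(t), np.asarray(y, float)
    pi = tau_pred > 0
    p_treat = pi.mean()
    treated = (t == 1) & pi    # units whose observed treatment matches pi=1
    control = (t == 0) & ~pi   # units whose observed treatment matches pi=0
    y_treated = y[treated].mean() if treated.any() else 0.0
    y_control = y[control].mean() if control.any() else 0.0
    return float(1.0 - (y_treated * p_treat + y_control * (1.0 - p_treat)))
```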