Description
Hello,
I am writing to ask a couple of questions about your code base. I would like to use it to estimate the average treatment effect (ATE) in a confounded dataset, with treatment X and outcome Y, both confounded by Z1 to Z4. My SEM is defined as follows:
```python
import torch
import torch.nn.functional as F
from causal_nf.sem_equations.sem_base import SEM


class TestModel(SEM):
    def __init__(self):
        # No structural equations are given; only the causal graph is specified
        functions = None
        inverses = None
        super().__init__(functions, inverses, None)

    def adjacency(self, add_diag=False):
        # Row i marks the parents of node i (node order: Z1, Z2, Z3, Z4, X, Y)
        adj = torch.zeros((6, 6))
        adj[0, :] = torch.tensor([0, 0, 0, 0, 0, 0])  # Z1
        adj[1, :] = torch.tensor([1, 0, 0, 0, 0, 0])  # Z2
        adj[2, :] = torch.tensor([1, 1, 0, 0, 0, 0])  # Z3
        adj[3, :] = torch.tensor([1, 1, 1, 0, 0, 0])  # Z4
        adj[4, :] = torch.tensor([1, 1, 1, 1, 0, 0])  # X
        adj[5, :] = torch.tensor([1, 1, 1, 1, 1, 0])  # Y
        if add_diag:
            adj += torch.eye(6)
        return adj

    def intervention_index_list(self):
        # Nodes that may be intervened on: Z1 (index 0) and X (index 4)
        return [0, 4]
```
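As a quick sanity check on the graph, here is a small plain-Python sketch. It assumes that row i of the adjacency matrix marks the parents of node i, as the row comments suggest; `ADJ`, `is_strictly_lower_triangular` and `parents` are hypothetical helpers, not part of causal_nf. A strictly lower-triangular matrix means the order Z1, Z2, Z3, Z4, X, Y is a topological order, so the graph is acyclic, and the parents of X are exactly Z1 to Z4. Because the parents of X satisfy the backdoor criterion, adjusting for {Z1, Z2, Z3, Z4} identifies the ATE of X on Y (assuming no hidden confounders).

```python
# Same matrix as TestModel.adjacency(); assumption: row i marks the parents of node i.
NAMES = ["Z1", "Z2", "Z3", "Z4", "X", "Y"]
ADJ = [
    [0, 0, 0, 0, 0, 0],  # Z1
    [1, 0, 0, 0, 0, 0],  # Z2
    [1, 1, 0, 0, 0, 0],  # Z3
    [1, 1, 1, 0, 0, 0],  # Z4
    [1, 1, 1, 1, 0, 0],  # X
    [1, 1, 1, 1, 1, 0],  # Y
]


def is_strictly_lower_triangular(adj):
    """True if every edge goes from an earlier node to a later one in this order,
    which guarantees the graph is acyclic (diagonal must be zero too)."""
    n = len(adj)
    return all(adj[i][j] == 0 for i in range(n) for j in range(i, n))


def parents(adj, names, node):
    """Names of the nodes marked as parents of `node` in its row."""
    i = names.index(node)
    return [names[j] for j, v in enumerate(adj[i]) if v]


print(is_strictly_lower_triangular(ADJ))  # True
print(parents(ADJ, NAMES, "X"))  # ['Z1', 'Z2', 'Z3', 'Z4']
print(parents(ADJ, NAMES, "Y"))  # ['Z1', 'Z2', 'Z3', 'Z4', 'X']
```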
I have written custom Preparator and DataLoader classes as well as a config file. The model has already been trained, and the code below loads it from the last checkpoint. It then estimates the ATE five times, each time from a different set of samples drawn from the fitted model (one random seed per round). However, the estimates are not consistently close to the true value in my benchmark dataset, and I wonder if you could point out any possible issues in the code below:
```python
import numpy as np
import torch
from yacs.config import CfgNode

import causal_nf.config as causal_nf_config
import causal_nf.utils.io as causal_nf_io
import causal_nf.utils.training as causal_nf_train
from causal_nf.config import cfg
from causal_nf.preparators.MY_preparator import MYPreparator

# NOTE: `folder`, `ckpt_code` and `check_file` must be defined before this point.
args_list = []
args = CfgNode({
    "config_file": f"{folder}/{ckpt_code}/wandb_local/config_local.yaml",
    "config_default_file": f"{folder}/{ckpt_code}/wandb_local/default_config.yaml",
    "project": None,
    "wandb_mode": "disabled",
    "wandb_group": None,
    "load_model": f"{folder}/{ckpt_code}",
    "delete_ckpt": False,
})

# Build the config and check that the global `cfg` matches it
config = causal_nf_config.build_config(
    config_file=args.config_file,
    args_list=args_list,
    config_default_file=args.config_default_file,
)
causal_nf_config.assert_cfg_and_config(cfg, config)

# Prepare the data and load the trained model from the checkpoint
preparator = MYPreparator.loader(cfg.dataset)
preparator.prepare_data()
model_lightning = causal_nf_train.load_model(cfg=cfg, preparator=preparator, ckpt_file=check_file)
model = model_lightning.model
model.eval()

loaders = preparator.get_dataloaders(
    batch_size=cfg.train.batch_size, num_workers=cfg.train.num_workers
)
batch = next(iter(loaders[-1]))

# Intervene on X (index 4) and contrast do(X = 1) with do(X = 0)
int_dict = {"name": "1_0", "a": 1.0, "b": 0.0, "index": 4}
name = int_dict["name"]
a = int_dict["a"]
b = int_dict["b"]
index = int_dict["index"]

n_rounds = 5
ates = []
seeds = np.arange(n_rounds)
for i, seed in enumerate(seeds):
    torch.random.manual_seed(seed)  # a different set of samples in each round
    ate = model.compute_ate(
        index,
        a=a,
        b=b,
        num_samples=10000,
        scaler=preparator.scaler_transform,
    )
    ates.append(ate.detach().numpy())
    print(ates[-1])
```
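To check whether the spread I see across seeds is only Monte Carlo noise, here is a self-contained sketch that does not use causal_nf at all. The quantity being estimated is ATE = E[Y | do(X = a)] − E[Y | do(X = b)], with a = 1 and b = 0; I am assuming `compute_ate` follows the same a-minus-b convention. The linear-Gaussian SEM, its coefficients, and the helpers `sample_y`, `mc_ate` and `summarize` are hypothetical, chosen only so that the true ATE is known (it equals `BETA_XY * (a - b)`). Each seed draws `num_samples` interventional samples per arm, mirroring the loop above.

```python
import math
import random
import statistics

# Hypothetical linear-Gaussian SEM over the same graph as TestModel:
# each Z depends on all earlier Zs, every Z -> X and Z -> Y, and X -> Y.
# All coefficients are made up; the true ATE of X on Y is BETA_XY * (a - b).
BETA_XY = 2.0


def sample_y(rng, do_x=None):
    """Draw one Y, optionally under the intervention do(X = do_x)."""
    z = []
    for _ in range(4):
        z.append(0.3 * sum(z) + rng.gauss(0.0, 1.0))
    x = 0.5 * sum(z) + rng.gauss(0.0, 1.0) if do_x is None else do_x
    return BETA_XY * x + 0.5 * sum(z) + rng.gauss(0.0, 1.0)


def mc_ate(a, b, num_samples, seed):
    """Monte Carlo estimate of E[Y | do(X=a)] - E[Y | do(X=b)], independent draws per arm."""
    rng = random.Random(seed)
    mean_a = statistics.fmean(sample_y(rng, a) for _ in range(num_samples))
    mean_b = statistics.fmean(sample_y(rng, b) for _ in range(num_samples))
    return mean_a - mean_b


def summarize(estimates, true_ate):
    """Mean over seeds, its standard error, and the gap to the truth in standard errors."""
    mean = statistics.fmean(estimates)
    se = statistics.stdev(estimates) / math.sqrt(len(estimates))
    if se > 0:
        z = (mean - true_ate) / se
    else:
        z = 0.0 if mean == true_ate else math.inf
    return mean, se, z


a, b = 1.0, 0.0
true_ate = BETA_XY * (a - b)
estimates = [mc_ate(a, b, num_samples=10_000, seed=s) for s in range(5)]
mean, se, z = summarize(estimates, true_ate)
print(f"true={true_ate:.3f} mean={mean:.3f} se={se:.3f} z={z:.2f}")
```

If the standard error across seeds is small but the mean sits many standard errors away from the true value, more samples will not close the gap; the error is systematic (for example the fitted flow itself, or an estimate and a true value expressed on different scales). If instead the standard error is large, increasing `num_samples` shrinks it roughly like 1/sqrt(num_samples). If `compute_ate` returns one value per variable rather than a scalar, I would pass only the outcome's entry (index 5 in my adjacency) to `summarize`; that return shape is an assumption on my part.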
Thanks!