
Estimating the ATE #4

@12kleingordon34


Hello,

I am writing to ask a couple of questions about your code base. I would like to use it to estimate the average treatment effect (ATE) in a confounded dataset with treatment $X$, outcome $Y$, and four pretreatment covariates $Z_1, Z_2, Z_3, Z_4$, whose joint distribution factorises as $P(Z_1)\,P(Z_2 \mid Z_1)\,P(Z_3 \mid Z_1, Z_2)\,P(Z_4 \mid Z_1, Z_2, Z_3)$. Our aim is to infer the ATE of $X$ on $Y$. We define the following class:

```python
import torch
import torch.nn.functional as F

from causal_nf.sem_equations.sem_base import SEM


class TestModel(SEM):
    def __init__(self):
        # No analytic structural equations are specified here.
        functions = None
        inverses = None

        super().__init__(functions, inverses, None)

    def adjacency(self, add_diag=False):
        # Row i marks the parents of node i; node order is Z1, Z2, Z3, Z4, X, Y.
        adj = torch.zeros((6, 6))
        adj[0, :] = torch.tensor([0, 0, 0, 0, 0, 0])  # Z1
        adj[1, :] = torch.tensor([1, 0, 0, 0, 0, 0])  # Z2
        adj[2, :] = torch.tensor([1, 1, 0, 0, 0, 0])  # Z3
        adj[3, :] = torch.tensor([1, 1, 1, 0, 0, 0])  # Z4
        adj[4, :] = torch.tensor([1, 1, 1, 1, 0, 0])  # X
        adj[5, :] = torch.tensor([1, 1, 1, 1, 1, 0])  # Y

        if add_diag:
            adj += torch.eye(6)

        return adj

    def intervention_index_list(self):
        # Indices of the nodes to intervene on: Z1 (0) and X (4).
        return [0, 4]
```
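A quick way to check that `adjacency` encodes the intended factorisation is to read each row back as a parent set. The plain-Python sketch below copies the matrix as nested lists (so it runs without torch) and assumes, as the row comments suggest, that row $i$ lists the parents of node $i$ in the order $Z_1, Z_2, Z_3, Z_4, X, Y$. It also checks that every edge points forward in that order, which makes the graph acyclic.

```python
# Adjacency copied from TestModel.adjacency(); row i is assumed to list the parents of node i.
NAMES = ["Z1", "Z2", "Z3", "Z4", "X", "Y"]
ADJ = [
    [0, 0, 0, 0, 0, 0],  # Z1
    [1, 0, 0, 0, 0, 0],  # Z2
    [1, 1, 0, 0, 0, 0],  # Z3
    [1, 1, 1, 0, 0, 0],  # Z4
    [1, 1, 1, 1, 0, 0],  # X
    [1, 1, 1, 1, 1, 0],  # Y
]


def parents(adj, names):
    """Map each node name to the names of its parents (nonzero columns in its row)."""
    return {names[i]: [names[j] for j, v in enumerate(row) if v] for i, row in enumerate(adj)}


def is_strictly_lower_triangular(adj):
    """True if every edge goes from an earlier to a later node, so this order is topological."""
    n = len(adj)
    return all(adj[i][j] == 0 for i in range(n) for j in range(i, n))


pa = parents(ADJ, NAMES)
for node, ps in pa.items():
    cond = f" | {', '.join(ps)}" if ps else ""
    print(f"P({node}{cond})")
print("acyclic in this order:", is_strictly_lower_triangular(ADJ))
print("parents of X:", pa["X"])
```

This prints the factorisation above extended by $P(X \mid Z_1, Z_2, Z_3, Z_4)$ and $P(Y \mid Z_1, Z_2, Z_3, Z_4, X)$. The parent set of $X$ is then a valid backdoor adjustment set for the effect of $X$ on $Y$, provided there are no unobserved confounders.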

I have written a custom Preparator, custom DataLoader classes, and a config file. The model has already been fitted, and the code below loads it from the last checkpoint. It then estimates the ATE five times, each with a different seed, by sampling from the fitted model. However, the ATE estimates are not consistently close to the true value in my benchmark dataset, and I wonder if you could point out any possible issues in the code below:

```python
import numpy as np
import torch
from yacs.config import CfgNode

import causal_nf.config as causal_nf_config
import causal_nf.utils.io as causal_nf_io
import causal_nf.utils.training as causal_nf_train
from causal_nf.config import cfg
from causal_nf.preparators.MY_preparator import MYPreparator

# `folder`, `ckpt_code` and `check_file` are defined earlier (run folder and checkpoint path).
args_list = []
args = CfgNode({
    'config_file': f'{folder}/{ckpt_code}/wandb_local/config_local.yaml',
    'config_default_file': f'{folder}/{ckpt_code}/wandb_local/default_config.yaml',
    'project': None, 'wandb_mode': 'disabled', 'wandb_group': None,
    'load_model': f'{folder}/{ckpt_code}', 'delete_ckpt': False,
})
config = causal_nf_config.build_config(
    config_file=args.config_file,
    args_list=args_list,
    config_default_file=args.config_default_file,
)
causal_nf_config.assert_cfg_and_config(cfg, config)

preparator = MYPreparator.loader(cfg.dataset)
preparator.prepare_data()
model_lightning = causal_nf_train.load_model(cfg=cfg, preparator=preparator, ckpt_file=check_file)
model = model_lightning.model
model.eval()
loaders = preparator.get_dataloaders(
    batch_size=cfg.train.batch_size, num_workers=cfg.train.num_workers
)

# Intervene on node 4 (X) and contrast do(X=a) with do(X=b).
int_dict = {'name': '1_0', 'a': 1., 'b': 0., 'index': 4}
name = int_dict['name']
a = int_dict['a']
b = int_dict['b']
index = int_dict['index']

n_rounds = 5
ates = []
batch = next(iter(loaders[-1]))
for seed in range(n_rounds):
    torch.random.manual_seed(seed)
    ate = model.compute_ate(
        index,
        a=a,
        b=b,
        num_samples=10000,
        scaler=preparator.scaler_transform,
    )
    ates.append(ate.detach().cpu().numpy())

# Report every round (not only the last one) to see the seed-to-seed spread.
ates = np.stack(ates)
print(ates)
print('mean:', ates.mean(axis=0), 'std:', ates.std(axis=0))
```
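To tell Monte Carlo noise apart from modelling error, it can help to run the same kind of estimate on a toy SEM whose ATE is known. The sketch below does not use causal_nf: it is a hypothetical linear-Gaussian SEM with the same DAG as `TestModel` and made-up coefficients, and it estimates $\mathbb{E}[Y \mid \mathrm{do}(X=a)] - \mathbb{E}[Y \mid \mathrm{do}(X=b)]$ for $a=1$, $b=0$ (the contrast that `int_dict` appears to request) over five seeds. Because $Y$ is linear in $X$ with coefficient `BETA_X`, that coefficient is the true ATE.

```python
import random
import statistics

# Hypothetical linear-Gaussian SEM with the same DAG as TestModel:
# Z1 -> Z2 -> Z3 -> Z4 (plus the skip edges), every Z_j -> X, every Z_j -> Y, and X -> Y.
# The coefficient on X in the Y equation is, by construction, the true ATE.
BETA_X = 2.0


def sample_y(rng, x_value=None):
    """Draw one Y; if x_value is given, X is set by the intervention do(X = x_value)."""
    z1 = rng.gauss(0.0, 1.0)
    z2 = 0.5 * z1 + rng.gauss(0.0, 1.0)
    z3 = 0.3 * z1 - 0.4 * z2 + rng.gauss(0.0, 1.0)
    z4 = 0.2 * (z1 + z2 + z3) + rng.gauss(0.0, 1.0)
    if x_value is None:
        x = 0.5 * (z1 + z2 + z3 + z4) + rng.gauss(0.0, 1.0)  # observational mechanism
    else:
        x = x_value  # do(X = x_value): X no longer depends on Z1..Z4
    return BETA_X * x + z1 - z2 + 0.5 * z3 + 0.5 * z4 + rng.gauss(0.0, 1.0)


def mc_ate(a, b, num_samples, seed):
    """Monte Carlo estimate of E[Y | do(X=a)] - E[Y | do(X=b)] with independent samples per arm."""
    rng = random.Random(seed)
    mean_a = statistics.fmean(sample_y(rng, a) for _ in range(num_samples))
    mean_b = statistics.fmean(sample_y(rng, b) for _ in range(num_samples))
    return mean_a - mean_b


ates = [mc_ate(a=1.0, b=0.0, num_samples=10_000, seed=s) for s in range(5)]
mean = statistics.fmean(ates)
sd = statistics.stdev(ates)
print("estimates:", [round(v, 3) for v in ates])
print(f"mean = {mean:.3f}, sd across seeds = {sd:.3f}, true ATE = {BETA_X}")
```

With independent samples in each arm, the seed-to-seed spread is roughly $\sqrt{2\,\mathrm{Var}(Y \mid \mathrm{do}(X))/n}$, a few hundredths here for $n = 10{,}000$. If the fitted model's estimates scatter much more than its own outcome variance would suggest, or sit consistently away from the benchmark value, sampling noise alone is unlikely to explain the gap.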

Thanks!
