diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 9acac13..1d09061 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -134,6 +134,7 @@ class SEQopts: subgroup_colname: str = None treatment_level: List[int] = field(default_factory=lambda: [0, 1]) trial_include: bool = True + visit_colname: str = None weight_eligible_colnames: List[str] = field(default_factory=lambda: []) weight_min: float = 0.0 weight_max: float = None diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index b068dc9..e7588a6 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -17,9 +17,16 @@ from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, - _weight_bind, _weight_predict, _weight_setup, - _weight_stats) +from .weighting import ( + _fit_denominator, + _fit_LTFU, + _fit_numerator, + _fit_visit, + _weight_bind, + _weight_predict, + _weight_setup, + _weight_stats, +) class SEQuential: @@ -93,7 +100,7 @@ def __init__( if self.denominator is None: self.denominator = _denominator(self) - if self.cense_colname is not None: + if self.cense_colname is not None or self.visit_colname is not None: if self.cense_numerator is None: self.cense_numerator = _cense_numerator(self) @@ -112,6 +119,7 @@ def expand(self) -> None: self.cense_colname, self.cense_eligible_colname, self.compevent_colname, + self.visit_colname, *self.weight_eligible_colnames, *self.excused_colnames, ] @@ -212,6 +220,7 @@ def fit(self) -> None: WDT[col] = WDT[col].astype("category") _fit_LTFU(self, WDT) + _fit_visit(self, WDT) _fit_numerator(self, WDT) _fit_denominator(self, WDT) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index 617fc2c..5e557a8 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -163,9 +163,9 @@ def _calculate_risk(self, data, idx=None, val=None): .group_by("followup") .agg( [ - pl.col("risk").std().alias("SE"), - pl.col("risk").quantile(lci).alias("LCI"), - pl.col("risk").quantile(uci).alias("UCI"), + pl.col("risk").std().cast(pl.Float64).alias("SE"), + pl.col("risk").quantile(lci).cast(pl.Float64).alias("LCI"), + pl.col("risk").quantile(uci).cast(pl.Float64).alias("UCI"), ] ) .join(TxDT.select(["followup", main_col]), on="followup") @@ -198,9 +198,9 @@ def _calculate_risk(self, data, idx=None, val=None): .group_by("followup") .agg( [ - pl.col("inc_val").std().alias("inc_SE"), - pl.col("inc_val").quantile(lci).alias("inc_LCI"), - pl.col("inc_val").quantile(uci).alias("inc_UCI"), + pl.col("inc_val").std().cast(pl.Float64).alias("inc_SE"), + pl.col("inc_val").quantile(lci).cast(pl.Float64).alias("inc_LCI"), + pl.col("inc_val").quantile(uci).cast(pl.Float64).alias("inc_UCI"), ] ) .join(TxDT.select(["followup", "inc"]), on="followup") diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index 3a96448..c048583 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -31,8 +31,15 @@ def _param_checker(self): "Only one of followup_class or followup_include can be set to True." ) - if self.weighted and self.method == "ITT" and self.cense_colname is None: - raise ValueError("For weighted ITT analyses, cense_colname must be provided.") + if ( + self.weighted + and self.method == "ITT" + and self.cense_colname is None + and self.visit_colname is None + ): + raise ValueError( + "For weighted ITT analyses, cense_colname or visit_colname must be provided." + ) if self.excused: _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames) diff --git a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py index 4874865..65e5ca7 100644 --- a/pySEQTarget/weighting/__init__.py +++ b/pySEQTarget/weighting/__init__.py @@ -3,5 +3,6 @@ from ._weight_fit import _fit_denominator as _fit_denominator from ._weight_fit import _fit_LTFU as _fit_LTFU from ._weight_fit import _fit_numerator as _fit_numerator +from ._weight_fit import _fit_visit as _fit_visit from ._weight_pred import _weight_predict as _weight_predict from ._weight_stats import _weight_stats as _weight_stats diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index 8f8f4a8..91e50c6 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -12,6 +12,11 @@ def _weight_bind(self, WDT): WDT = self.DT.join(WDT, on=on, how=join) + if self.visit_colname is not None: + visit = pl.col(self.visit_colname) == 0 + else: + visit = pl.lit(False) + if self.weight_preexpansion and self.excused: trial = (pl.col("trial") == 0) & (pl.col("period") == 0) excused = ( @@ -21,6 +26,7 @@ def _weight_bind(self, WDT): override = ( trial | excused + | visit | pl.col(self.outcome_col).is_null() | (pl.col("denominator") < 1e-7) ) @@ -33,6 +39,7 @@ def _weight_bind(self, WDT): override = ( trial | excused + | visit | pl.col(self.outcome_col).is_null() | (pl.col("denominator") < 1e-7) | (pl.col("numerator") < 1e-7) @@ -45,24 +52,35 @@ def _weight_bind(self, WDT): override = ( trial | excused + | visit | pl.col(self.outcome_col).is_null() | (pl.col("denominator") < 1e-15) | pl.col("numerator").is_null() ) self.DT = ( - WDT.with_columns( - pl.when(override) - .then(pl.lit(1.0)) - .otherwise(pl.col("numerator") / pl.col("denominator")) - .alias("wt") + ( + WDT.with_columns( + pl.when(override) + .then(pl.lit(1.0)) + .otherwise(pl.col("numerator") / pl.col("denominator")) + .alias("wt") + ) + .sort([self.id_col, "trial", "followup"]) + .with_columns( + pl.col("wt") + .fill_null(1.0) + .cum_prod() + .over([self.id_col, "trial"]) + .alias("weight") + ) ) - .sort([self.id_col, "trial", "followup"]) .with_columns( - pl.col("wt") - .fill_null(1.0) - .cum_prod() - .over([self.id_col, "trial"]) - .alias("weight") + ( + pl.col("weight") + * pl.col("_cense").fill_null(1.0) + * pl.col("_visit").fill_null(1.0) + ).alias("weight") ) + .drop(["_cense", "_visit"]) ) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index c6495ba..b2e5f38 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -2,22 +2,45 @@ import statsmodels.formula.api as smf +def _fit_pair( + self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None +): + outcome = getattr(self, outcome_attr) + + if eligible_colname_attr is not None: + _eligible_col = getattr(self, eligible_colname_attr) + if _eligible_col is not None: + WDT = WDT[WDT[_eligible_col] == 1] + + for rhs, out in zip(formula_attr, output_attrs): + formula = f"{outcome}~{rhs}" + model = smf.glm(formula, WDT, family=sm.families.Binomial()) + setattr(self, out, model.fit(disp=0)) + + def _fit_LTFU(self, WDT): if self.cense_colname is None: return - else: - fits = [] - if self.cense_eligible_colname is not None: - WDT = WDT[WDT[self.cense_eligible_colname] == 1] + _fit_pair( + self, + WDT, + "cense_colname", + [self.cense_numerator, self.cense_denominator], + ["cense_numerator", "cense_denominator"], + "cense_eligible_colname", + ) - for i in [self.cense_numerator, self.cense_denominator]: - formula = f"{self.cense_colname}~{i}" - model = smf.glm(formula, WDT, family=sm.families.Binomial()) - model_fit = model.fit(disp=0) - fits.append(model_fit) - self.cense_numerator = fits[0] - self.cense_denominator = fits[1] +def _fit_visit(self, WDT): + if self.visit_colname is None: + return + _fit_pair( + self, + WDT, + "visit_colname", + [self.cense_numerator, self.cense_denominator], + ["visit_numerator", "visit_denominator"], + ) def _fit_numerator(self, WDT): diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 7e77c4d..f28ba72 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -149,20 +149,44 @@ def _weight_predict(self, WDT): .otherwise(pl.col("numerator")) .alias("numerator") ) - if self.cense_colname is not None: - p_num = _predict_model(self, self.cense_numerator, WDT).flatten() - p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() - WDT = WDT.with_columns( - [ - pl.Series("cense_numerator", p_num), - pl.Series("cense_denominator", p_denom), - ] - ).with_columns( - (pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense") - ) - else: - WDT = WDT.with_columns(pl.lit(1.0).alias("cense")) + if self.cense_colname is not None: + p_num = _predict_model(self, self.cense_numerator, WDT).flatten() + p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() + WDT = WDT.with_columns( + [ + pl.Series("cense_numerator", p_num), + pl.Series("cense_denominator", p_denom), + ] + ).with_columns( + (pl.col("cense_numerator") / pl.col("cense_denominator")).alias("_cense") + ) + else: + WDT = WDT.with_columns(pl.lit(1.0).alias("_cense")) - kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"] + if self.visit_colname is not None: + p_num = _predict_model(self, self.visit_numerator, WDT).flatten() + p_denom = _predict_model(self, self.visit_denominator, WDT).flatten() + + WDT = WDT.with_columns( + [ + pl.Series("visit_numerator", p_num), + pl.Series("visit_denominator", p_denom), + ] + ).with_columns( + (pl.col("visit_numerator") / pl.col("visit_denominator")).alias("_visit") + ) + else: + WDT = WDT.with_columns(pl.lit(1.0).alias("_visit")) + + kept = [ + "numerator", + "denominator", + "_cense", + "_visit", + self.id_col, + "trial", + time, + "tx_lag", + ] exists = [col for col in kept if col in WDT.columns] return WDT.select(exists).sort(grouping + [time]) diff --git a/pyproject.toml b/pyproject.toml index b35eb1d..dea94bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.9.2" +version = "0.10.0" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index aa80120..9debc39 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,5 +1,4 @@ -import sys -from pathlib import Path +from pytest import approx from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data @@ -35,7 +34,7 @@ def test_ITT_coefs(): -0.01339242049205771, 0.20072409918428052, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PreE_dose_response_coefs(): @@ -66,7 +65,7 @@ def test_PreE_dose_response_coefs(): 0.010537967151467553, 0.0007772316818101141, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PostE_dose_response_coefs(): @@ -101,7 +100,7 @@ def test_PostE_dose_response_coefs(): -0.02106338184559446, 0.14867250693140854, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PreE_censoring_coefs(): @@ -132,7 +131,7 @@ def test_PreE_censoring_coefs(): 0.0011281734101133744, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PostE_censoring_coefs(): @@ -165,7 +164,7 @@ def test_PostE_censoring_coefs(): 0.013503198983327514, 0.4466573801510379, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PreE_censoring_excused_coefs(): @@ -199,7 +198,7 @@ def test_PreE_censoring_excused_coefs(): 0.08365564531281737, -0.0006220464783614585, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PostE_censoring_excused_coefs(): @@ -237,7 +236,7 @@ def test_PostE_censoring_excused_coefs(): 0.0014679503081325516, 0.3008769969502361, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PreE_LTFU_ITT(): @@ -261,18 +260,18 @@ def test_PreE_LTFU_ITT(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -21.640523091572796, - 0.0685235184372898, - -0.19006360662228572, - 0.028750950193838918, - -0.0005762057433736666, - 0.28554312978583757, - -0.001373044229623057, - 0.006589141394458155, - -0.44898959259422394, - 1.3875089788036237, + -21.636346991788276, + 0.06813705852786496, + -0.1939555961858531, + 0.02874152772603635, + -0.0005734047013500563, + 0.2854740212699898, + -0.0013729662310668182, + 0.006501915963316852, + -0.4467079969655381, + 1.3870473474960576, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_PostE_LTFU_ITT(): @@ -294,18 +293,18 @@ def test_PostE_LTFU_ITT(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -21.640523091572796, - 0.0685235184372898, - -0.19006360662228572, - 0.028750950193838918, - -0.0005762057433736666, - 0.28554312978583757, - -0.001373044229623057, - 0.006589141394458155, - -0.44898959259422394, - 1.3875089788036237, + -21.847198431385877, + 0.07786703138967718, + -0.15461370944416225, + 0.030140057462437704, + -0.0006287338029348562, + 0.287393206037481, + -0.0013719595115633126, + 0.007295485861066434, + -0.42797049565882755, + 1.4082102322835948, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_ITT_multinomial(): @@ -338,7 +337,7 @@ def test_ITT_multinomial(): 0.7847862691929901, 1.4703411759229423, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) def test_weighted_multinomial(): @@ -370,4 +369,39 @@ def test_weighted_multinomial(): 5.743984176710672, -0.08478678955657822, ] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + assert matrix == approx(expected, rel=1e-3) + + +def test_ITT_visit(): + data = load_data("SEQdata_LTFU") + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + weighted=True, weight_preexpansion=True, visit_colname="LTFU" + ), + ) + s.expand() + s.fit() + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -21.636346991788276, + 0.06813705852786496, + -0.1939555961858531, + 0.02874152772603635, + -0.0005734047013500563, + 0.2854740212699898, + -0.0013729662310668182, + 0.006501915963316852, + -0.4467079969655381, + 1.3870473474960576, + ] + assert matrix == approx(expected, rel=1e-3)