From 90fc87f185613b0d7e552974f66ee7773e955b73 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:51:07 +0100 Subject: [PATCH 1/2] matched excused conditions --- docs/conf.py | 1 + pySEQTarget/SEQuential.py | 31 +++-- pySEQTarget/analysis/__init__.py | 3 +- pySEQTarget/expansion/_dynamic.py | 48 ++++--- pySEQTarget/weighting/_weight_fit.py | 2 +- pySEQTarget/weighting/_weight_pred.py | 172 +++++++++++++++++++------- tests/test_coefficients.py | 51 +++++--- 7 files changed, 219 insertions(+), 89 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0fe672d..a7f536b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,6 +4,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html import importlib.metadata + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index b068dc9..d13685f 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -7,19 +7,36 @@ import numpy as np import polars as pl -from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit, - _pred_risk, _risk_estimates, _subgroup_fit) +from .analysis import ( + _calculate_hazard, + _calculate_survival, + _outcome_fit, + _pred_risk, + _risk_estimates, + _subgroup_fit, +) from .error import _datachecker, _param_checker from .expansion import _binder, _diagnostics, _dynamic, _random_selection from .helpers import _col_string, _format_time, bootstrap_loop -from .initialization import (_cense_denominator, _cense_numerator, - _denominator, _numerator, _outcome) +from .initialization import ( + _cense_denominator, + _cense_numerator, + _denominator, + _numerator, + _outcome, +) from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, - _weight_bind, _weight_predict, _weight_setup, - _weight_stats) +from .weighting import ( + _fit_denominator, + _fit_LTFU, + _fit_numerator, + _weight_bind, + _weight_predict, + _weight_setup, + _weight_stats, +) class SEQuential: diff --git a/pySEQTarget/analysis/__init__.py b/pySEQTarget/analysis/__init__.py index 6799dfd..e35ceb7 100644 --- a/pySEQTarget/analysis/__init__.py +++ b/pySEQTarget/analysis/__init__.py @@ -3,6 +3,5 @@ from ._risk_estimates import _risk_estimates as _risk_estimates from ._subgroup_fit import _subgroup_fit as _subgroup_fit from ._survival_pred import _calculate_survival as _calculate_survival -from ._survival_pred import \ - _get_outcome_predictions as _get_outcome_predictions +from ._survival_pred import _get_outcome_predictions as _get_outcome_predictions from ._survival_pred import _pred_risk as _pred_risk diff --git a/pySEQTarget/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py index d2a99c5..932ba75 100644 --- a/pySEQTarget/expansion/_dynamic.py +++ b/pySEQTarget/expansion/_dynamic.py @@ -25,37 +25,49 @@ def _dynamic(self): switch = ( pl.when(pl.col("followup") == 0) .then(pl.lit(False)) - .otherwise( - (pl.col("tx_lag").is_not_null()) - & (pl.col("tx_lag") != pl.col(self.treatment_col)) - ) + .otherwise(pl.col("tx_lag") != pl.col(self.treatment_col)) ) is_excused = pl.lit(False) if self.excused: conditions = [] - for i in range(len(self.treatment_level)): + for i, val in enumerate(self.treatment_level): colname = self.excused_colnames[i] if colname is not None: conditions.append( - (pl.col(colname) == 1) - & (pl.col(self.treatment_col) == self.treatment_level[i]) + (pl.col(colname) == 1) & (pl.col(self.treatment_col) == val) ) if conditions: excused = pl.any_horizontal(conditions) is_excused = switch & excused - switch = pl.when(excused).then(pl.lit(False)).otherwise(switch) - - DT = ( - DT.with_columns([switch.alias("switch"), is_excused.alias("isExcused")]) - .sort([self.id_col, "trial", "followup"]) - .filter( - (pl.col("switch").cum_max().shift(1, fill_value=False)).over( - [self.id_col, "trial"] + + DT = DT.with_columns( + [switch.alias("switch"), is_excused.alias("isExcused")] + ).sort([self.id_col, "trial", "followup"]) + + if self.excused: + DT = ( + DT.with_columns( + pl.col("isExcused") + .cast(pl.Int8) + .cum_sum() + .over([self.id_col, "trial"]) + .alias("_excused_tmp") ) - == 0 + .with_columns( + pl.when(pl.col("_excused_tmp") > 0) + .then(pl.lit(False)) + .otherwise(pl.col("switch")) + .alias("switch") + ) + .drop("_excused_tmp") ) - .with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) - ) + + DT = DT.filter( + (pl.col("switch").cum_max().shift(1, fill_value=False)).over( + [self.id_col, "trial"] + ) + == 0 + ).with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) self.DT = DT.drop(["tx_lag"]) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index 2996764..f9b08f9 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -40,7 +40,7 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - + DT_subset.to_csv("fml.csv") model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 5a858a8..7e77c4d 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -18,53 +18,137 @@ def _weight_predict(self, WDT): [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] ) - for i, level in enumerate(self.treatment_level): - mask = pl.col("tx_lag") == level - lag_mask = (WDT["tx_lag"] == level).to_numpy() - - if self.denominator_model[i] is not None: - pred_denom = np.ones(WDT.height) - if lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.denominator_model[i], subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] - switched_treatment = ( - subset[self.treatment_col] != subset["tx_lag"] - ).to_numpy() - pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) - else: - pred_denom = np.ones(WDT.height) - - if hasattr(self, "numerator_model") and self.numerator_model[i] is not None: - pred_num = np.ones(WDT.height) - if lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.numerator_model[i], subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] - switched_treatment = ( - subset[self.treatment_col] != subset["tx_lag"] - ).to_numpy() - pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + if not self.excused: + for i, level in enumerate(self.treatment_level): + mask = pl.col("tx_lag") == level + lag_mask = (WDT["tx_lag"] == level).to_numpy() + + if self.denominator_model[i] is not None: + pred_denom = np.ones(WDT.height) + if lag_mask.sum() > 0: + subset = WDT.filter(pl.Series(lag_mask)) + p = _predict_model(self, self.denominator_model[i], subset) + if p.ndim == 1: + p = p.reshape(-1, 1) + p = p[:, i] + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + else: + pred_denom = np.ones(WDT.height) + + if self.numerator_model[i] is not None: + pred_num = np.ones(WDT.height) + if lag_mask.sum() > 0: + subset = WDT.filter(pl.Series(lag_mask)) + p = _predict_model(self, self.numerator_model[i], subset) + if p.ndim == 1: + p = p.reshape(-1, 1) + p = p[:, i] + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + else: + pred_num = np.ones(WDT.height) + + WDT = WDT.with_columns( + [ + pl.when(mask) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator"), + pl.when(mask) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator"), + ] + ) + + else: + for i, level in enumerate(self.treatment_level): + col = self.excused_colnames[i] + + if col is not None: + denom_mask = ((WDT["tx_lag"] == level) & (WDT[col] != 1)).to_numpy() + + if self.denominator_model[i] is not None and denom_mask.sum() > 0: + pred_denom = np.ones(WDT.height) + subset = WDT.filter(pl.Series(denom_mask)) + p = _predict_model(self, self.denominator_model[i], subset) + + if p.ndim == 1: + prob_switch = p + else: + prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten() + + pred_denom[denom_mask] = prob_switch + + WDT = WDT.with_columns( + pl.when(pl.Series(denom_mask)) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator") + ) + + if i == 0: + flip_mask = ( + (WDT["tx_lag"] == level) + & (WDT[col] == 0) + & (WDT[self.treatment_col] == level) + ).to_numpy() + else: + flip_mask = ( + (WDT["tx_lag"] == level) + & (WDT[col] == 0) + & (WDT[self.treatment_col] != level) + ).to_numpy() + + WDT = WDT.with_columns( + pl.when(pl.Series(flip_mask)) + .then(1.0 - pl.col("denominator")) + .otherwise(pl.col("denominator")) + .alias("denominator") + ) + + if self.weight_preexpansion: + WDT = WDT.with_columns(pl.lit(1.0).alias("numerator")) else: - pred_num = np.ones(WDT.height) + for i, level in enumerate(self.treatment_level): + col = self.excused_colnames[i] - WDT = WDT.with_columns( - [ - pl.when(mask) - .then(pl.Series(pred_num)) - .otherwise(pl.col("numerator")) - .alias("numerator"), - pl.when(mask) - .then(pl.Series(pred_denom)) - .otherwise(pl.col("denominator")) - .alias("denominator"), - ] - ) + if col is not None: + num_mask = ( + (WDT[self.treatment_col] == level) & (WDT[col] == 0) + ).to_numpy() + + if self.numerator_model[i] is not None and num_mask.sum() > 0: + pred_num = np.ones(WDT.height) + subset = WDT.filter(pl.Series(num_mask)) + p = _predict_model(self, self.numerator_model[i], subset) + + if p.ndim == 1: + prob_switch = p + else: + prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten() + pred_num[num_mask] = prob_switch + + WDT = WDT.with_columns( + pl.when(pl.Series(num_mask)) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator") + ) + + first_level = self.treatment_level[0] + WDT = WDT.with_columns( + pl.when(pl.col(self.treatment_col) == first_level) + .then(1.0 - pl.col("numerator")) + .otherwise(pl.col("numerator")) + .alias("numerator") + ) if self.cense_colname is not None: p_num = _predict_model(self, self.cense_numerator, WDT).flatten() p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 948ba11..5d9da28 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,3 +1,10 @@ +import sys +from pathlib import Path + +# Add package root to path +package_root = Path(__file__).parent.parent +sys.path.insert(0, str(package_root)) + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data @@ -189,14 +196,13 @@ def test_PreE_censoring_excused_coefs(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.460912082691973, - 1.309708035546933, - 0.10853511682679658, - -0.0038913520688693823, - 0.08849129909709463, - -0.000647578869153453, + -6.383862385859305, + 1.335378710346093, + 0.029611514864625224, + -0.0017854432467035867, + 0.12888028673936663, + -0.0013855918917791584, ] - print(matrix) assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] @@ -222,22 +228,33 @@ def test_PostE_censoring_excused_coefs(): ) s.expand() s.fit() + s.DT.write_csv("weightdata.csv") + print(s.weight_stats) + print(s.outcome_model[0]["outcome"].summary()) + print(s.numerator_model[0].summary()) + print(s.numerator_model[1].summary()) + print(s.denominator_model[0].summary()) + print(s.denominator_model[1].summary()) + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.782732929102242, - 0.26371172100905477, - 0.13625528598217598, - 0.040580427030886, - -0.000343018323531494, - 0.031185150775465315, - 0.000784356550754563, - 0.004338417236024277, - -0.013052187516528172, - 0.20402680950820007, + -6.354881630599161, + 0.26437059880109814, + 0.11052945840253169, + 0.0359033269938446, + -0.00016819836476915145, + 0.029330229450150295, + 0.0006840606030679354, + 0.007740880717871542, + -0.010288544399887233, + 0.1520398124246858, ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] +test_PostE_censoring_excused_coefs() + + def test_PreE_LTFU_ITT(): data = load_data("SEQdata_LTFU") From b52a1aa294438ebc22d84786610fc8ee03cb9fc1 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:36:14 +0100 Subject: [PATCH 2/2] fixed tests --- pySEQTarget/weighting/_weight_fit.py | 1 - tests/test_coefficients.py | 47 ++++++++++------------------ 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index f9b08f9..c6495ba 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -40,7 +40,6 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - DT_subset.to_csv("fml.csv") model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 5d9da28..aa80120 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,10 +1,6 @@ import sys from pathlib import Path -# Add package root to path -package_root = Path(__file__).parent.parent -sys.path.insert(0, str(package_root)) - from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data @@ -196,12 +192,12 @@ def test_PreE_censoring_excused_coefs(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.383862385859305, - 1.335378710346093, - 0.029611514864625224, - -0.0017854432467035867, - 0.12888028673936663, - -0.0013855918917791584, + -5.028261715903588, + 0.09661040854758277, + -0.029861423750765226, + 0.0014186936955145387, + 0.08365564531281737, + -0.0006220464783614585, ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] @@ -228,33 +224,22 @@ def test_PostE_censoring_excused_coefs(): ) s.expand() s.fit() - s.DT.write_csv("weightdata.csv") - print(s.weight_stats) - print(s.outcome_model[0]["outcome"].summary()) - print(s.numerator_model[0].summary()) - print(s.numerator_model[1].summary()) - print(s.denominator_model[0].summary()) - print(s.denominator_model[1].summary()) - matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.354881630599161, - 0.26437059880109814, - 0.11052945840253169, - 0.0359033269938446, - -0.00016819836476915145, - 0.029330229450150295, - 0.0006840606030679354, - 0.007740880717871542, - -0.010288544399887233, - 0.1520398124246858, + -7.722441228318476, + 0.25040421685828396, + 0.08370244078073162, + 0.03644249151697272, + -0.00019169394285363785, + 0.053677366381589396, + 0.0005643189202781975, + 0.005250478928581509, + 0.0014679503081325516, + 0.3008769969502361, ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] -test_PostE_censoring_excused_coefs() - - def test_PreE_LTFU_ITT(): data = load_data("SEQdata_LTFU")