diff --git a/docs/conf.py b/docs/conf.py index 0fe672d..a7f536b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,6 +4,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html import importlib.metadata + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index b068dc9..d13685f 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -7,19 +7,36 @@ import numpy as np import polars as pl -from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit, - _pred_risk, _risk_estimates, _subgroup_fit) +from .analysis import ( + _calculate_hazard, + _calculate_survival, + _outcome_fit, + _pred_risk, + _risk_estimates, + _subgroup_fit, +) from .error import _datachecker, _param_checker from .expansion import _binder, _diagnostics, _dynamic, _random_selection from .helpers import _col_string, _format_time, bootstrap_loop -from .initialization import (_cense_denominator, _cense_numerator, - _denominator, _numerator, _outcome) +from .initialization import ( + _cense_denominator, + _cense_numerator, + _denominator, + _numerator, + _outcome, +) from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, - _weight_bind, _weight_predict, _weight_setup, - _weight_stats) +from .weighting import ( + _fit_denominator, + _fit_LTFU, + _fit_numerator, + _weight_bind, + _weight_predict, + _weight_setup, + _weight_stats, +) class SEQuential: diff --git a/pySEQTarget/analysis/__init__.py b/pySEQTarget/analysis/__init__.py index 6799dfd..e35ceb7 100644 --- a/pySEQTarget/analysis/__init__.py +++ b/pySEQTarget/analysis/__init__.py @@ -3,6 +3,5 @@ from ._risk_estimates import _risk_estimates as _risk_estimates from ._subgroup_fit import _subgroup_fit as _subgroup_fit from ._survival_pred import _calculate_survival as _calculate_survival -from ._survival_pred import \ - _get_outcome_predictions as _get_outcome_predictions +from ._survival_pred import _get_outcome_predictions as _get_outcome_predictions from ._survival_pred import _pred_risk as _pred_risk diff --git a/pySEQTarget/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py index d2a99c5..932ba75 100644 --- a/pySEQTarget/expansion/_dynamic.py +++ b/pySEQTarget/expansion/_dynamic.py @@ -25,37 +25,49 @@ def _dynamic(self): switch = ( pl.when(pl.col("followup") == 0) .then(pl.lit(False)) - .otherwise( - (pl.col("tx_lag").is_not_null()) - & (pl.col("tx_lag") != pl.col(self.treatment_col)) - ) + .otherwise(pl.col("tx_lag") != pl.col(self.treatment_col)) ) is_excused = pl.lit(False) if self.excused: conditions = [] - for i in range(len(self.treatment_level)): + for i, val in enumerate(self.treatment_level): colname = self.excused_colnames[i] if colname is not None: conditions.append( - (pl.col(colname) == 1) - & (pl.col(self.treatment_col) == self.treatment_level[i]) + (pl.col(colname) == 1) & (pl.col(self.treatment_col) == val) ) if conditions: excused = pl.any_horizontal(conditions) is_excused = switch & excused - switch = pl.when(excused).then(pl.lit(False)).otherwise(switch) - - DT = ( - DT.with_columns([switch.alias("switch"), is_excused.alias("isExcused")]) - .sort([self.id_col, "trial", "followup"]) - .filter( - (pl.col("switch").cum_max().shift(1, fill_value=False)).over( - [self.id_col, "trial"] + + DT = DT.with_columns( + [switch.alias("switch"), is_excused.alias("isExcused")] + ).sort([self.id_col, "trial", "followup"]) + + if self.excused: + DT = ( + DT.with_columns( + pl.col("isExcused") + .cast(pl.Int8) + .cum_sum() + .over([self.id_col, "trial"]) + .alias("_excused_tmp") ) - == 0 + .with_columns( + pl.when(pl.col("_excused_tmp") > 0) + .then(pl.lit(False)) + .otherwise(pl.col("switch")) + .alias("switch") + ) + .drop("_excused_tmp") ) - .with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) - ) + + DT = DT.filter( + (pl.col("switch").cum_max().shift(1, fill_value=False)).over( + [self.id_col, "trial"] + ) + == 0 + ).with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) self.DT = DT.drop(["tx_lag"]) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index 2996764..c6495ba 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -40,7 +40,6 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 5a858a8..7e77c4d 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -18,53 +18,137 @@ def _weight_predict(self, WDT): [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] ) - for i, level in enumerate(self.treatment_level): - mask = pl.col("tx_lag") == level - lag_mask = (WDT["tx_lag"] == level).to_numpy() - - if self.denominator_model[i] is not None: - pred_denom = np.ones(WDT.height) - if lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.denominator_model[i], subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] - switched_treatment = ( - subset[self.treatment_col] != subset["tx_lag"] - ).to_numpy() - pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) - else: - pred_denom = np.ones(WDT.height) - - if hasattr(self, "numerator_model") and self.numerator_model[i] is not None: - pred_num = np.ones(WDT.height) - if lag_mask.sum() > 0: - subset = WDT.filter(pl.Series(lag_mask)) - p = _predict_model(self, self.numerator_model[i], subset) - if p.ndim == 1: - p = p.reshape(-1, 1) - p = p[:, i] - switched_treatment = ( - subset[self.treatment_col] != subset["tx_lag"] - ).to_numpy() - pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + if not self.excused: + for i, level in enumerate(self.treatment_level): + mask = pl.col("tx_lag") == level + lag_mask = (WDT["tx_lag"] == level).to_numpy() + + if self.denominator_model[i] is not None: + pred_denom = np.ones(WDT.height) + if lag_mask.sum() > 0: + subset = WDT.filter(pl.Series(lag_mask)) + p = _predict_model(self, self.denominator_model[i], subset) + if p.ndim == 1: + p = p.reshape(-1, 1) + p = p[:, i] + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + else: + pred_denom = np.ones(WDT.height) + + if self.numerator_model[i] is not None: + pred_num = np.ones(WDT.height) + if lag_mask.sum() > 0: + subset = WDT.filter(pl.Series(lag_mask)) + p = _predict_model(self, self.numerator_model[i], subset) + if p.ndim == 1: + p = p.reshape(-1, 1) + p = p[:, i] + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) + else: + pred_num = np.ones(WDT.height) + + WDT = WDT.with_columns( + [ + pl.when(mask) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator"), + pl.when(mask) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator"), + ] + ) + + else: + for i, level in enumerate(self.treatment_level): + col = self.excused_colnames[i] + + if col is not None: + denom_mask = ((WDT["tx_lag"] == level) & (WDT[col] != 1)).to_numpy() + + if self.denominator_model[i] is not None and denom_mask.sum() > 0: + pred_denom = np.ones(WDT.height) + subset = WDT.filter(pl.Series(denom_mask)) + p = _predict_model(self, self.denominator_model[i], subset) + + if p.ndim == 1: + prob_switch = p + else: + prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten() + + pred_denom[denom_mask] = prob_switch + + WDT = WDT.with_columns( + pl.when(pl.Series(denom_mask)) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator") + ) + + if i == 0: + flip_mask = ( + (WDT["tx_lag"] == level) + & (WDT[col] == 0) + & (WDT[self.treatment_col] == level) + ).to_numpy() + else: + flip_mask = ( + (WDT["tx_lag"] == level) + & (WDT[col] == 0) + & (WDT[self.treatment_col] != level) + ).to_numpy() + + WDT = WDT.with_columns( + pl.when(pl.Series(flip_mask)) + .then(1.0 - pl.col("denominator")) + .otherwise(pl.col("denominator")) + .alias("denominator") + ) + + if self.weight_preexpansion: + WDT = WDT.with_columns(pl.lit(1.0).alias("numerator")) else: - pred_num = np.ones(WDT.height) + for i, level in enumerate(self.treatment_level): + col = self.excused_colnames[i] - WDT = WDT.with_columns( - [ - pl.when(mask) - .then(pl.Series(pred_num)) - .otherwise(pl.col("numerator")) - .alias("numerator"), - pl.when(mask) - .then(pl.Series(pred_denom)) - .otherwise(pl.col("denominator")) - .alias("denominator"), - ] - ) + if col is not None: + num_mask = ( + (WDT[self.treatment_col] == level) & (WDT[col] == 0) + ).to_numpy() + + if self.numerator_model[i] is not None and num_mask.sum() > 0: + pred_num = np.ones(WDT.height) + subset = WDT.filter(pl.Series(num_mask)) + p = _predict_model(self, self.numerator_model[i], subset) + + if p.ndim == 1: + prob_switch = p + else: + prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten() + pred_num[num_mask] = prob_switch + + WDT = WDT.with_columns( + pl.when(pl.Series(num_mask)) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator") + ) + + first_level = self.treatment_level[0] + WDT = WDT.with_columns( + pl.when(pl.col(self.treatment_col) == first_level) + .then(1.0 - pl.col("numerator")) + .otherwise(pl.col("numerator")) + .alias("numerator") + ) if self.cense_colname is not None: p_num = _predict_model(self, self.cense_numerator, WDT).flatten() p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 948ba11..aa80120 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,3 +1,6 @@ +import sys +from pathlib import Path + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data @@ -189,14 +192,13 @@ def test_PreE_censoring_excused_coefs(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.460912082691973, - 1.309708035546933, - 0.10853511682679658, - -0.0038913520688693823, - 0.08849129909709463, - -0.000647578869153453, + -5.028261715903588, + 0.09661040854758277, + -0.029861423750765226, + 0.0014186936955145387, + 0.08365564531281737, + -0.0006220464783614585, ] - print(matrix) assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] @@ -224,16 +226,16 @@ def test_PostE_censoring_excused_coefs(): s.fit() matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() expected = [ - -6.782732929102242, - 0.26371172100905477, - 0.13625528598217598, - 0.040580427030886, - -0.000343018323531494, - 0.031185150775465315, - 0.000784356550754563, - 0.004338417236024277, - -0.013052187516528172, - 0.20402680950820007, + -7.722441228318476, + 0.25040421685828396, + 0.08370244078073162, + 0.03644249151697272, + -0.00019169394285363785, + 0.053677366381589396, + 0.0005643189202781975, + 0.005250478928581509, + 0.0014679503081325516, + 0.3008769969502361, ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]