Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import importlib.metadata

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
Expand Down
31 changes: 24 additions & 7 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,36 @@
import numpy as np
import polars as pl

from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
_pred_risk, _risk_estimates, _subgroup_fit)
from .analysis import (
_calculate_hazard,
_calculate_survival,
_outcome_fit,
_pred_risk,
_risk_estimates,
_subgroup_fit,
)
from .error import _datachecker, _param_checker
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
from .helpers import _col_string, _format_time, bootstrap_loop
from .initialization import (_cense_denominator, _cense_numerator,
_denominator, _numerator, _outcome)
from .initialization import (
_cense_denominator,
_cense_numerator,
_denominator,
_numerator,
_outcome,
)
from .plot import _survival_plot
from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
_weight_bind, _weight_predict, _weight_setup,
_weight_stats)
from .weighting import (
_fit_denominator,
_fit_LTFU,
_fit_numerator,
_weight_bind,
_weight_predict,
_weight_setup,
_weight_stats,
)


class SEQuential:
Expand Down
3 changes: 1 addition & 2 deletions pySEQTarget/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@
from ._risk_estimates import _risk_estimates as _risk_estimates
from ._subgroup_fit import _subgroup_fit as _subgroup_fit
from ._survival_pred import _calculate_survival as _calculate_survival
from ._survival_pred import \
_get_outcome_predictions as _get_outcome_predictions
from ._survival_pred import _get_outcome_predictions as _get_outcome_predictions
from ._survival_pred import _pred_risk as _pred_risk
48 changes: 30 additions & 18 deletions pySEQTarget/expansion/_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,49 @@ def _dynamic(self):
switch = (
pl.when(pl.col("followup") == 0)
.then(pl.lit(False))
.otherwise(
(pl.col("tx_lag").is_not_null())
& (pl.col("tx_lag") != pl.col(self.treatment_col))
)
.otherwise(pl.col("tx_lag") != pl.col(self.treatment_col))
)
is_excused = pl.lit(False)
if self.excused:
conditions = []
for i in range(len(self.treatment_level)):
for i, val in enumerate(self.treatment_level):
colname = self.excused_colnames[i]
if colname is not None:
conditions.append(
(pl.col(colname) == 1)
& (pl.col(self.treatment_col) == self.treatment_level[i])
(pl.col(colname) == 1) & (pl.col(self.treatment_col) == val)
)

if conditions:
excused = pl.any_horizontal(conditions)
is_excused = switch & excused
switch = pl.when(excused).then(pl.lit(False)).otherwise(switch)

DT = (
DT.with_columns([switch.alias("switch"), is_excused.alias("isExcused")])
.sort([self.id_col, "trial", "followup"])
.filter(
(pl.col("switch").cum_max().shift(1, fill_value=False)).over(
[self.id_col, "trial"]

DT = DT.with_columns(
[switch.alias("switch"), is_excused.alias("isExcused")]
).sort([self.id_col, "trial", "followup"])

if self.excused:
DT = (
DT.with_columns(
pl.col("isExcused")
.cast(pl.Int8)
.cum_sum()
.over([self.id_col, "trial"])
.alias("_excused_tmp")
)
== 0
.with_columns(
pl.when(pl.col("_excused_tmp") > 0)
.then(pl.lit(False))
.otherwise(pl.col("switch"))
.alias("switch")
)
.drop("_excused_tmp")
)
.with_columns(pl.col("switch").cast(pl.Int8).alias("switch"))
)

DT = DT.filter(
(pl.col("switch").cum_max().shift(1, fill_value=False)).over(
[self.id_col, "trial"]
)
== 0
).with_columns(pl.col("switch").cast(pl.Int8).alias("switch"))

self.DT = DT.drop(["tx_lag"])
1 change: 0 additions & 1 deletion pySEQTarget/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def _fit_numerator(self, WDT):
DT_subset = DT_subset[DT_subset[tx_bas] == level]
if self.weight_eligible_colnames[i] is not None:
DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1]

model = smf.mnlogit(formula, DT_subset)
model_fit = model.fit(disp=0)
fits.append(model_fit)
Expand Down
172 changes: 128 additions & 44 deletions pySEQTarget/weighting/_weight_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,137 @@ def _weight_predict(self, WDT):
[pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")]
)

for i, level in enumerate(self.treatment_level):
mask = pl.col("tx_lag") == level
lag_mask = (WDT["tx_lag"] == level).to_numpy()

if self.denominator_model[i] is not None:
pred_denom = np.ones(WDT.height)
if lag_mask.sum() > 0:
subset = WDT.filter(pl.Series(lag_mask))
p = _predict_model(self, self.denominator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (
subset[self.treatment_col] != subset["tx_lag"]
).to_numpy()
pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p)
else:
pred_denom = np.ones(WDT.height)

if hasattr(self, "numerator_model") and self.numerator_model[i] is not None:
pred_num = np.ones(WDT.height)
if lag_mask.sum() > 0:
subset = WDT.filter(pl.Series(lag_mask))
p = _predict_model(self, self.numerator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (
subset[self.treatment_col] != subset["tx_lag"]
).to_numpy()
pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p)
if not self.excused:
for i, level in enumerate(self.treatment_level):
mask = pl.col("tx_lag") == level
lag_mask = (WDT["tx_lag"] == level).to_numpy()

if self.denominator_model[i] is not None:
pred_denom = np.ones(WDT.height)
if lag_mask.sum() > 0:
subset = WDT.filter(pl.Series(lag_mask))
p = _predict_model(self, self.denominator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (
subset[self.treatment_col] != subset["tx_lag"]
).to_numpy()
pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p)
else:
pred_denom = np.ones(WDT.height)

if self.numerator_model[i] is not None:
pred_num = np.ones(WDT.height)
if lag_mask.sum() > 0:
subset = WDT.filter(pl.Series(lag_mask))
p = _predict_model(self, self.numerator_model[i], subset)
if p.ndim == 1:
p = p.reshape(-1, 1)
p = p[:, i]
switched_treatment = (
subset[self.treatment_col] != subset["tx_lag"]
).to_numpy()
pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p)
else:
pred_num = np.ones(WDT.height)

WDT = WDT.with_columns(
[
pl.when(mask)
.then(pl.Series(pred_num))
.otherwise(pl.col("numerator"))
.alias("numerator"),
pl.when(mask)
.then(pl.Series(pred_denom))
.otherwise(pl.col("denominator"))
.alias("denominator"),
]
)

else:
for i, level in enumerate(self.treatment_level):
col = self.excused_colnames[i]

if col is not None:
denom_mask = ((WDT["tx_lag"] == level) & (WDT[col] != 1)).to_numpy()

if self.denominator_model[i] is not None and denom_mask.sum() > 0:
pred_denom = np.ones(WDT.height)
subset = WDT.filter(pl.Series(denom_mask))
p = _predict_model(self, self.denominator_model[i], subset)

if p.ndim == 1:
prob_switch = p
else:
prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten()

pred_denom[denom_mask] = prob_switch

WDT = WDT.with_columns(
pl.when(pl.Series(denom_mask))
.then(pl.Series(pred_denom))
.otherwise(pl.col("denominator"))
.alias("denominator")
)

if i == 0:
flip_mask = (
(WDT["tx_lag"] == level)
& (WDT[col] == 0)
& (WDT[self.treatment_col] == level)
).to_numpy()
else:
flip_mask = (
(WDT["tx_lag"] == level)
& (WDT[col] == 0)
& (WDT[self.treatment_col] != level)
).to_numpy()

WDT = WDT.with_columns(
pl.when(pl.Series(flip_mask))
.then(1.0 - pl.col("denominator"))
.otherwise(pl.col("denominator"))
.alias("denominator")
)

if self.weight_preexpansion:
WDT = WDT.with_columns(pl.lit(1.0).alias("numerator"))
else:
pred_num = np.ones(WDT.height)
for i, level in enumerate(self.treatment_level):
col = self.excused_colnames[i]

WDT = WDT.with_columns(
[
pl.when(mask)
.then(pl.Series(pred_num))
.otherwise(pl.col("numerator"))
.alias("numerator"),
pl.when(mask)
.then(pl.Series(pred_denom))
.otherwise(pl.col("denominator"))
.alias("denominator"),
]
)
if col is not None:
num_mask = (
(WDT[self.treatment_col] == level) & (WDT[col] == 0)
).to_numpy()

if self.numerator_model[i] is not None and num_mask.sum() > 0:
pred_num = np.ones(WDT.height)
subset = WDT.filter(pl.Series(num_mask))
p = _predict_model(self, self.numerator_model[i], subset)

if p.ndim == 1:
prob_switch = p
else:
prob_switch = p[:, 1] if p.shape[1] > 1 else p.flatten()

pred_num[num_mask] = prob_switch

WDT = WDT.with_columns(
pl.when(pl.Series(num_mask))
.then(pl.Series(pred_num))
.otherwise(pl.col("numerator"))
.alias("numerator")
)

first_level = self.treatment_level[0]
WDT = WDT.with_columns(
pl.when(pl.col(self.treatment_col) == first_level)
.then(1.0 - pl.col("numerator"))
.otherwise(pl.col("numerator"))
.alias("numerator")
)
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
Expand Down
36 changes: 19 additions & 17 deletions tests/test_coefficients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys
from pathlib import Path

from pySEQTarget import SEQopts, SEQuential
from pySEQTarget.data import load_data

Expand Down Expand Up @@ -189,14 +192,13 @@ def test_PreE_censoring_excused_coefs():
s.fit()
matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
expected = [
-6.460912082691973,
1.309708035546933,
0.10853511682679658,
-0.0038913520688693823,
0.08849129909709463,
-0.000647578869153453,
-5.028261715903588,
0.09661040854758277,
-0.029861423750765226,
0.0014186936955145387,
0.08365564531281737,
-0.0006220464783614585,
]
print(matrix)
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]


Expand Down Expand Up @@ -224,16 +226,16 @@ def test_PostE_censoring_excused_coefs():
s.fit()
matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list()
expected = [
-6.782732929102242,
0.26371172100905477,
0.13625528598217598,
0.040580427030886,
-0.000343018323531494,
0.031185150775465315,
0.000784356550754563,
0.004338417236024277,
-0.013052187516528172,
0.20402680950820007,
-7.722441228318476,
0.25040421685828396,
0.08370244078073162,
0.03644249151697272,
-0.00019169394285363785,
0.053677366381589396,
0.0005643189202781975,
0.005250478928581509,
0.0014679503081325516,
0.3008769969502361,
]
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected]

Expand Down