Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class SEQopts:
subgroup_colname: str = None
treatment_level: List[int] = field(default_factory=lambda: [0, 1])
trial_include: bool = True
visit_colname: str = None
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
weight_min: float = 0.0
weight_max: float = None
Expand Down
17 changes: 13 additions & 4 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
from .plot import _survival_plot
from .SEQopts import SEQopts
from .SEQoutput import SEQoutput
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
_weight_bind, _weight_predict, _weight_setup,
_weight_stats)
from .weighting import (
_fit_denominator,
_fit_LTFU,
_fit_numerator,
_fit_visit,
_weight_bind,
_weight_predict,
_weight_setup,
_weight_stats,
)


class SEQuential:
Expand Down Expand Up @@ -93,7 +100,7 @@ def __init__(
if self.denominator is None:
self.denominator = _denominator(self)

if self.cense_colname is not None:
if self.cense_colname is not None or self.visit_colname is not None:
if self.cense_numerator is None:
self.cense_numerator = _cense_numerator(self)

Expand All @@ -112,6 +119,7 @@ def expand(self) -> None:
self.cense_colname,
self.cense_eligible_colname,
self.compevent_colname,
self.visit_colname,
*self.weight_eligible_colnames,
*self.excused_colnames,
]
Expand Down Expand Up @@ -212,6 +220,7 @@ def fit(self) -> None:
WDT[col] = WDT[col].astype("category")

_fit_LTFU(self, WDT)
_fit_visit(self, WDT)
_fit_numerator(self, WDT)
_fit_denominator(self, WDT)

Expand Down
12 changes: 6 additions & 6 deletions pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def _calculate_risk(self, data, idx=None, val=None):
.group_by("followup")
.agg(
[
pl.col("risk").std().alias("SE"),
pl.col("risk").quantile(lci).alias("LCI"),
pl.col("risk").quantile(uci).alias("UCI"),
pl.col("risk").std().cast(pl.Float64).alias("SE"),
pl.col("risk").quantile(lci).cast(pl.Float64).alias("LCI"),
pl.col("risk").quantile(uci).cast(pl.Float64).alias("UCI"),
]
)
.join(TxDT.select(["followup", main_col]), on="followup")
Expand Down Expand Up @@ -198,9 +198,9 @@ def _calculate_risk(self, data, idx=None, val=None):
.group_by("followup")
.agg(
[
pl.col("inc_val").std().alias("inc_SE"),
pl.col("inc_val").quantile(lci).alias("inc_LCI"),
pl.col("inc_val").quantile(uci).alias("inc_UCI"),
pl.col("inc_val").std().cast(pl.Float64).alias("inc_SE"),
pl.col("inc_val").quantile(lci).cast(pl.Float64).alias("inc_LCI"),
pl.col("inc_val").quantile(uci).cast(pl.Float64).alias("inc_UCI"),
]
)
.join(TxDT.select(["followup", "inc"]), on="followup")
Expand Down
11 changes: 9 additions & 2 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ def _param_checker(self):
"Only one of followup_class or followup_include can be set to True."
)

if self.weighted and self.method == "ITT" and self.cense_colname is None:
raise ValueError("For weighted ITT analyses, cense_colname must be provided.")
if (
self.weighted
and self.method == "ITT"
and self.cense_colname is None
and self.visit_colname is None
):
raise ValueError(
"For weighted ITT analyses, cense_colname or visit_colname must be provided."
)

if self.excused:
_, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
Expand Down
1 change: 1 addition & 0 deletions pySEQTarget/weighting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from ._weight_fit import _fit_denominator as _fit_denominator
from ._weight_fit import _fit_LTFU as _fit_LTFU
from ._weight_fit import _fit_numerator as _fit_numerator
from ._weight_fit import _fit_visit as _fit_visit
from ._weight_pred import _weight_predict as _weight_predict
from ._weight_stats import _weight_stats as _weight_stats
40 changes: 29 additions & 11 deletions pySEQTarget/weighting/_weight_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def _weight_bind(self, WDT):

WDT = self.DT.join(WDT, on=on, how=join)

if self.visit_colname is not None:
visit = pl.col(self.visit_colname) == 0
else:
visit = pl.lit(False)

if self.weight_preexpansion and self.excused:
trial = (pl.col("trial") == 0) & (pl.col("period") == 0)
excused = (
Expand All @@ -21,6 +26,7 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-7)
)
Expand All @@ -33,6 +39,7 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-7)
| (pl.col("numerator") < 1e-7)
Expand All @@ -45,24 +52,35 @@ def _weight_bind(self, WDT):
override = (
trial
| excused
| visit
| pl.col(self.outcome_col).is_null()
| (pl.col("denominator") < 1e-15)
| pl.col("numerator").is_null()
)

self.DT = (
WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
(
WDT.with_columns(
pl.when(override)
.then(pl.lit(1.0))
.otherwise(pl.col("numerator") / pl.col("denominator"))
.alias("wt")
)
.sort([self.id_col, "trial", "followup"])
.with_columns(
pl.col("wt")
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
)
)
.sort([self.id_col, "trial", "followup"])
.with_columns(
pl.col("wt")
.fill_null(1.0)
.cum_prod()
.over([self.id_col, "trial"])
.alias("weight")
(
pl.col("weight")
* pl.col("_cense").fill_null(1.0)
* pl.col("_visit").fill_null(1.0)
).alias("weight")
)
.drop(["_cense", "_visit"])
)
45 changes: 34 additions & 11 deletions pySEQTarget/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,45 @@
import statsmodels.formula.api as smf


def _fit_pair(
self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None
):
outcome = getattr(self, outcome_attr)

if eligible_colname_attr is not None:
_eligible_col = getattr(self, eligible_colname_attr)
if _eligible_col is not None:
WDT = WDT[WDT[_eligible_col] == 1]

for rhs, out in zip(formula_attr, output_attrs):
formula = f"{outcome}~{rhs}"
model = smf.glm(formula, WDT, family=sm.families.Binomial())
setattr(self, out, model.fit(disp=0))


def _fit_LTFU(self, WDT):
if self.cense_colname is None:
return
else:
fits = []
if self.cense_eligible_colname is not None:
WDT = WDT[WDT[self.cense_eligible_colname] == 1]
_fit_pair(
self,
WDT,
"cense_colname",
[self.cense_numerator, self.cense_denominator],
["cense_numerator", "cense_denominator"],
"cense_eligible_colname",
)

for i in [self.cense_numerator, self.cense_denominator]:
formula = f"{self.cense_colname}~{i}"
model = smf.glm(formula, WDT, family=sm.families.Binomial())
model_fit = model.fit(disp=0)
fits.append(model_fit)

self.cense_numerator = fits[0]
self.cense_denominator = fits[1]
def _fit_visit(self, WDT):
if self.visit_colname is None:
return
_fit_pair(
self,
WDT,
"visit_colname",
[self.cense_numerator, self.cense_denominator],
["visit_numerator", "visit_denominator"],
)


def _fit_numerator(self, WDT):
Expand Down
52 changes: 38 additions & 14 deletions pySEQTarget/weighting/_weight_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,44 @@ def _weight_predict(self, WDT):
.otherwise(pl.col("numerator"))
.alias("numerator")
)
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("cense"))
if self.cense_colname is not None:
p_num = _predict_model(self, self.cense_numerator, WDT).flatten()
p_denom = _predict_model(self, self.cense_denominator, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("_cense")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_cense"))

kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"]
if self.visit_colname is not None:
p_num = _predict_model(self, self.visit_numerator, WDT).flatten()
p_denom = _predict_model(self, self.visit_denominator, WDT).flatten()

WDT = WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
).with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias("_visit")
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_visit"))

kept = [
"numerator",
"denominator",
"_cense",
"_visit",
self.id_col,
"trial",
time,
"tx_lag",
]
exists = [col for col in kept if col in WDT.columns]
return WDT.select(exists).sort(grouping + [time])
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.9.2"
version = "0.10.0"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
Loading