diff --git a/.github/workflows/autoformat.yml b/.github/workflows/autoformat.yml
new file mode 100644
index 0000000..2dcf15c
--- /dev/null
+++ b/.github/workflows/autoformat.yml
@@ -0,0 +1,45 @@
+name: Auto-format Code
+
+on:
+  push:
+    branches: [main, develop]
+
+# The format job commits and pushes its changes, so the workflow token
+# needs write access to repository contents.
+permissions:
+  contents: write
+
+jobs:
+  format:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install formatters
+        run: pip install black isort
+
+      # isort runs first, with the Black profile, so Black has the final say
+      # and the two tools do not rewrite each other's output on every push.
+      - name: Sort imports with isort
+        run: isort --profile black .
+
+      - name: Format with Black
+        run: black .
+
+      - name: Commit changes
+        run: |
+          git config --local user.email "github-actions[bot]@users.noreply.github.com"
+          git config --local user.name "github-actions[bot]"
+          git add -A
+          # Commit and push only when the formatters changed something.
+          if ! git diff --cached --quiet; then
+            git commit -m "Auto-format code"
+            git push
+          fi
\ No newline at end of file
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 0000000..d6e089f
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,42 @@
+name: Build and Deploy Docs
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+
+jobs:
+  build-docs:
+    runs-on: ubuntu-latest
+    # The deploy step pushes to the gh-pages branch with GITHUB_TOKEN,
+    # which requires write access to repository contents.
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: |
+          pip install sphinx sphinx-rtd-theme sphinx-autodoc-typehints
+          pip install -e .  # Install your package
+
+      - name: Build documentation
+        run: |
+          cd docs
+          make html
+
+      - name: Deploy to GitHub Pages
+        if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+        uses: peaceiris/actions-gh-pages@v3
+        with:
+          github_token: ${{ secrets.GITHUB_TOKEN }}
+          # NOTE(review): conf.py lives in docs/source, which suggests a
+          # separate build directory; confirm BUILDDIR in docs/Makefile. If it
+          # is "build", this path should be ./docs/build/html.
+          publish_dir: ./docs/_build/html
\ No newline at end of file
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..1fa8189
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,70 @@
+# This workflow will upload a Python Package to PyPI when a release is created
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.x"
+
+      - name: Build release distributions
+        run: |
+          # NOTE: put your own distribution build steps here.
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    runs-on: ubuntu-latest
+    needs:
+      - release-build
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
+    environment:
+      name: pypi
+      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+      # url: https://pypi.org/p/YOURPROJECT
+      #
+      # ALTERNATIVE: if your GitHub Release name is the PyPI project version string
+      # ALTERNATIVE: exactly, uncomment the following line instead:
+      # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: dist/
\ No newline at end of file
diff --git a/README.md b/README.md index 0a48102..7e20d3d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# pySEQ - Sequentially Nested Target Trial Emulation +# pySEQTarget - Sequentially Nested Target Trial Emulation Implementation of sequential trial emulation for the analysis of observational databases. The ‘SEQTaRget’ software accommodates @@ -9,13 +9,13 @@ intention-to-treat and per-protocol effects, and can adjust for potential selection bias. 
## Installation -You can install the development version of pySEQ from github with: +You can install the development version of pySEQTarget from GitHub with: ```shell -pip install git+https://github.com/CausalInference/pySEQ +pip install git+https://github.com/CausalInference/pySEQTarget ``` -Or from pypi iwth +Or from PyPI with ```shell -pip install pySEQ +pip install pySEQTarget ``` ## Setting up your Analysis @@ -25,7 +25,7 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe ```python import polars as pl -from pySEQ import SEQuential, SEQopts +from pySEQTarget import SEQuential, SEQopts data = pl.from_pandas(SEQdata) options = SEQopts(km_curves = True) diff --git a/pySEQ/docs/Makefile b/docs/Makefile similarity index 100% rename from pySEQ/docs/Makefile rename to docs/Makefile diff --git a/pySEQ/docs/make.bat b/docs/make.bat similarity index 100% rename from pySEQ/docs/make.bat rename to docs/make.bat diff --git a/pySEQ/docs/source/conf.py b/docs/source/conf.py similarity index 81% rename from pySEQ/docs/source/conf.py rename to docs/source/conf.py index 994268a..c0630db 100644 --- a/pySEQ/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,10 +6,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'SEQuential' -copyright = "2024, Ryan O'Dea, Alejandro Szmulewicz" -author = "Ryan O'Dea, Alejandro Szmulewicz" -release = '0.1.0' +project = 'pySEQTarget' +copyright = "2025, Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" +author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" +release = '0.9.0' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..cee549f --- /dev/null +++ b/docs/source/index.rst 
@@ -0,0 +1,17 @@ +.. pySEQTarget documentation master file, created by + sphinx-quickstart on Mon Nov 24 20:43:34 2025. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +pySEQTarget documentation +========================= + +Add your content using ``reStructuredText`` syntax. See the +`reStructuredText <https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html>`_ +documentation for details. + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + diff --git a/pySEQ/analysis/__init__.py b/pySEQ/analysis/__init__.py deleted file mode 100644 index ec9e032..0000000 --- a/pySEQ/analysis/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._outcome_fit import _outcome_fit -from ._survival_pred import _get_outcome_predictions, _pred_risk, _calculate_survival -from ._subgroup_fit import _subgroup_fit -from ._hazard import _calculate_hazard -from ._risk_estimates import _risk_estimates diff --git a/pySEQ/analysis/_risk_estimates.py b/pySEQ/analysis/_risk_estimates.py deleted file mode 100644 index 2303fa7..0000000 --- a/pySEQ/analysis/_risk_estimates.py +++ /dev/null @@ -1,107 +0,0 @@ -import polars as pl -from scipy import stats - -def _risk_estimates(self): - last_followup = self.km_data['followup'].max() - risk = self.km_data.filter( - (pl.col('followup') == last_followup) & - (pl.col('estimate') == 'risk') - ) - - group_cols = [self.subgroup_colname] if self.subgroup_colname else [] - rd_comparisons = [] - rr_comparisons = [] - - if self.bootstrap_nboot > 0: - alpha = 1 - self.bootstrap_CI - z = stats.norm.ppf(1 - alpha / 2) - - for tx_x in self.treatment_level: - for tx_y in self.treatment_level: - if tx_x == tx_y: - continue - - risk_x = risk.filter(pl.col('tx_init') == tx_x).select( - group_cols + ['pred'] - ).rename({'pred': 'risk_x'}) - - risk_y = risk.filter(pl.col('tx_init') == tx_y).select( - group_cols + ['pred'] - ).rename({'pred': 'risk_y'}) - - if group_cols: - comp = risk_x.join(risk_y, on=group_cols, how='left') - else: - comp = risk_x.join(risk_y, 
how='cross') - - comp = comp.with_columns([ - pl.lit(tx_x).alias('A_x'), - pl.lit(tx_y).alias('A_y') - ]) - - if self.bootstrap_nboot > 0: - se_x = risk.filter(pl.col('tx_init') == tx_x).select( - group_cols + ['SE'] - ).rename({'SE': 'se_x'}) - - se_y = risk.filter(pl.col('tx_init') == tx_y).select( - group_cols + ['SE'] - ).rename({'SE': 'se_y'}) - - if group_cols: - comp = comp.join(se_x, on=group_cols, how='left') - comp = comp.join(se_y, on=group_cols, how='left') - else: - comp = comp.join(se_x, how='cross') - comp = comp.join(se_y, how='cross') - - rd_se = (pl.col('se_x').pow(2) + pl.col('se_y').pow(2)).sqrt() - rd_comp = comp.with_columns([ - (pl.col('risk_x') - pl.col('risk_y')).alias('Risk Difference'), - (pl.col('risk_x') - pl.col('risk_y') - z * rd_se).alias('RD 95% LCI'), - (pl.col('risk_x') - pl.col('risk_y') + z * rd_se).alias('RD 95% UCI') - ]) - rd_comp = rd_comp.drop(['risk_x', 'risk_y', 'se_x', 'se_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Difference', 'RD 95% LCI', 'RD 95% UCI'] - rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns]) - rd_comparisons.append(rd_comp) - - rr_log_se = ( - (pl.col('se_x') / pl.col('risk_x')).pow(2) + - (pl.col('se_y') / pl.col('risk_y')).pow(2) - ).sqrt() - rr_comp = comp.with_columns([ - (pl.col('risk_x') / pl.col('risk_y')).alias('Risk Ratio'), - ((pl.col('risk_x') / pl.col('risk_y')) * (-z * rr_log_se).exp()).alias('RR 95% LCI'), - ((pl.col('risk_x') / pl.col('risk_y')) * (z * rr_log_se).exp()).alias('RR 95% UCI') - ]) - rr_comp = rr_comp.drop(['risk_x', 'risk_y', 'se_x', 'se_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Ratio', 'RR 95% LCI', 'RR 95% UCI'] - rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns]) - rr_comparisons.append(rr_comp) - - else: - rd_comp = comp.with_columns( - (pl.col('risk_x') - pl.col('risk_y')).alias('Risk Difference') - ) - rd_comp = rd_comp.drop(['risk_x', 'risk_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Difference'] 
- rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns]) - rd_comparisons.append(rd_comp) - - rr_comp = comp.with_columns( - (pl.col('risk_x') / pl.col('risk_y')).alias('Risk Ratio') - ) - rr_comp = rr_comp.drop(['risk_x', 'risk_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Ratio'] - rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns]) - rr_comparisons.append(rr_comp) - - risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame() - risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame() - - return { - 'risk_difference': risk_difference, - 'risk_ratio': risk_ratio - } - \ No newline at end of file diff --git a/pySEQ/analysis/_subgroup_fit.py b/pySEQ/analysis/_subgroup_fit.py deleted file mode 100644 index 10ab448..0000000 --- a/pySEQ/analysis/_subgroup_fit.py +++ /dev/null @@ -1,25 +0,0 @@ -import polars as pl -from ._outcome_fit import _outcome_fit - -def _subgroup_fit(self): - subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list()) - self._unique_subgroups = subgroups - - models_list = [] - for val in subgroups: - subDT = self.DT.filter(pl.col(self.subgroup_colname) == val) - - models = {'outcome': _outcome_fit(self, subDT, - self.outcome_col, - self.covariates, - self.weighted, - "weight")} - - if self.compevent_colname is not None: - models['compevent'] = _outcome_fit(self, subDT, - self.compevent_colname, - self.covariates, - self.weighted, - "weight") - models_list.append(models) - return models_list diff --git a/pySEQ/analysis/_survival_pred.py b/pySEQ/analysis/_survival_pred.py deleted file mode 100644 index ce8b0d5..0000000 --- a/pySEQ/analysis/_survival_pred.py +++ /dev/null @@ -1,276 +0,0 @@ -import polars as pl - -def _get_outcome_predictions(self, TxDT, idx=None): - data = TxDT.to_pandas() - predictions = {"outcome": []} - if self.compevent_colname is not None: - predictions["compevent"] = [] - - for boot_model in self.outcome_model: - model_dict = 
boot_model[idx] if idx is not None else boot_model - predictions["outcome"].append(model_dict["outcome"].predict(data)) - if self.compevent_colname is not None: - predictions["compevent"].append(model_dict["compevent"].predict(data)) - - return predictions - -def _pred_risk(self): - has_subgroups = (isinstance(self.outcome_model[0], list) if self.outcome_model else False) - - if not has_subgroups: - return _calculate_risk(self, self.DT, idx=None, val=None) - - all_risks = [] - original_DT = self.DT - - for i, val in enumerate(self._unique_subgroups): - subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val) - risk = _calculate_risk(self, subgroup_DT, i, val) - all_risks.append(risk) - - self.DT = original_DT - return pl.concat(all_risks) - -def _calculate_risk(self, data, idx=None, val=None): - a = 1 - self.bootstrap_CI - lci = a / 2 - uci = 1 - lci - - SDT = ( - data - .with_columns([ - (pl.col(self.id_col) - .cast(pl.Utf8) + - pl.col("trial") - .cast(pl.Utf8)) - .alias("TID") - ]) - .group_by("TID").first() - .drop(["followup", f"followup{self.indicator_squared}"]) - .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")]) - .explode("followup") - .with_columns([ - (pl.col("followup") + 1).alias("followup"), - (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}") - ]) - ).sort([self.id_col, "trial", "followup"]) - - risks = [] - for treatment_val in self.treatment_level: - TxDT = SDT.with_columns([ - pl.lit(treatment_val) - .alias(f"{self.treatment_col}{self.indicator_baseline}") - ]) - - if self.method == "dose-response": - if treatment_val == self.treatment_level[0]: - TxDT = TxDT.with_columns([ - pl.lit(0.0).alias("dose"), - pl.lit(0.0).alias("dose_sq") - ]) - else: - TxDT = TxDT.with_columns([ - pl.col("followup").alias("dose"), - pl.col(f"followup{self.indicator_squared}").alias("dose_sq") - ]) - - preds = _get_outcome_predictions(self, TxDT, idx=idx) - pred_series = [pl.Series("pred_outcome", 
preds["outcome"][0])] - - if self.bootstrap_nboot > 0: - for boot_idx, pred in enumerate(preds["outcome"][1:], start=1): - pred_series.append(pl.Series(f"pred_outcome_{boot_idx}", pred)) - - if self.compevent_colname is not None: - pred_series.append(pl.Series("pred_ce", preds["compevent"][0])) - if self.bootstrap_nboot > 0: - for boot_idx, pred in enumerate(preds["compevent"][1:], start=1): - pred_series.append(pl.Series(f"pred_ce_{boot_idx}", pred)) - - outcome_names = [s.name for s in pred_series if "outcome" in s.name] - ce_names = [s.name for s in pred_series if "ce" in s.name] - - TxDT = TxDT.with_columns(pred_series) - - if self.compevent_colname is not None: - for out_col, ce_col in zip(outcome_names, ce_names): - surv_col = out_col.replace("pred_outcome", "surv") - cce_col = out_col.replace("pred_outcome", "cce") - inc_col = out_col.replace("pred_outcome", "inc") - - TxDT = ( - TxDT.with_columns([ - (1 - pl.col(out_col)) - .cum_prod() - .over("TID") - .alias(surv_col), - ((1 - pl.col(out_col)) * (1 - pl.col(ce_col))) - .cum_prod().over("TID") - .alias(cce_col) - ]) - .with_columns([ - (pl.col(out_col) * (1 - pl.col(ce_col)) * pl.col(cce_col)) - .cum_sum() - .over("TID") - .alias(inc_col) - ]) - ) - - surv_names = [n.replace("pred_outcome", "surv") for n in outcome_names] - inc_names = [n.replace("pred_outcome", "inc") for n in outcome_names] - TxDT = TxDT.group_by("followup").agg([pl.col(col).mean() for col in surv_names + inc_names]).sort("followup") - main_col = "surv" - boot_cols = [col for col in surv_names if col != "surv"] - else: - TxDT = ( - TxDT.with_columns([(1 - pl.col(col)) - .cum_prod() - .over("TID") - .alias(col) for col in outcome_names - ]) - .group_by("followup").agg([pl.col(col).mean() for col in outcome_names]) - .sort("followup") - .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names]) - ) - main_col = "pred_outcome" - boot_cols = [col for col in outcome_names if col != "pred_outcome"] - - if boot_cols: - risk = ( - 
TxDT.select(["followup"] + boot_cols) - .unpivot(index="followup", on=boot_cols, variable_name="bootID", value_name="risk") - .group_by("followup").agg([ - pl.col("risk").std().alias("SE"), - pl.col("risk").quantile(lci).alias("LCI"), - pl.col("risk").quantile(uci).alias("UCI") - ]) - .join(TxDT.select(["followup", main_col]), on="followup") - ) - - if self.bootstrap_CI_method == "se": - from scipy.stats import norm - z = norm.ppf(1 - a / 2) - risk = risk.with_columns([ - (pl.col(main_col) - z * pl.col("SE")).alias("LCI"), - (pl.col(main_col) + z * pl.col("SE")).alias("UCI") - ]) - - fup0_val = 1.0 if self.compevent_colname else 0.0 - - if self.compevent_colname is not None: - inc_boot_cols = [col for col in inc_names if col != "inc"] - if inc_boot_cols: - inc_risk = ( - TxDT.select(["followup"] + inc_boot_cols) - .unpivot(index="followup", on=inc_boot_cols, variable_name="bootID", value_name="inc_val") - .group_by("followup").agg([ - pl.col("inc_val").std().alias("inc_SE"), - pl.col("inc_val").quantile(lci).alias("inc_LCI"), - pl.col("inc_val").quantile(uci).alias("inc_UCI") - ]) - .join(TxDT.select(["followup", "inc"]), on="followup") - ) - risk = risk.join(inc_risk, on="followup") - final_cols = ["followup", main_col, "SE", "LCI", "UCI", "inc", "inc_SE", "inc_LCI", "inc_UCI"] - risk = risk.select(final_cols).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - "SE": [0.0], - "LCI": [fup0_val], - "UCI": [fup0_val], - "inc": [0.0], - "inc_SE": [0.0], - "inc_LCI": [0.0], - "inc_UCI": [0.0], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) - else: - risk = risk.select(["followup", main_col, "SE", "LCI", "UCI"]).with_columns( - pl.lit(treatment_val) - .alias(self.treatment_col) - ) - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - "SE": [0.0], - "LCI": [fup0_val], - "UCI": 
[fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) - else: - risk = risk.select(["followup", main_col, "SE", "LCI", "UCI"]).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - "SE": [0.0], - "LCI": [fup0_val], - "UCI": [fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) - else: - fup0_val = 1.0 if self.compevent_colname else 0.0 - risk = TxDT.select(["followup", main_col]).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32)]) - - if self.compevent_colname is not None: - risk = risk.join(TxDT.select(["followup", "inc"]), on="followup") - fup0 = fup0.with_columns([pl.lit(0.0).alias("inc")]) - - risks.append(pl.concat([fup0, risk])) - out = pl.concat(risks) - - if self.compevent_colname is not None: - has_ci = "SE" in out.columns - - surv_cols = ["followup", self.treatment_col, "surv"] - if has_ci: - surv_cols.extend(["SE", "LCI", "UCI"]) - surv_out = out.select(surv_cols).rename({"surv": "pred"}).with_columns(pl.lit("survival").alias("estimate")) - - risk_cols = ["followup", self.treatment_col, (1 - pl.col("surv")).alias("pred")] - if has_ci: - risk_cols.extend([pl.col("SE"), (1 - pl.col("UCI")).alias("LCI"), (1 - pl.col("LCI")).alias("UCI")]) - risk_out = out.select(risk_cols).with_columns(pl.lit("risk").alias("estimate")) - - inc_cols = ["followup", self.treatment_col, pl.col("inc").alias("pred")] - if has_ci: - inc_cols.extend([pl.col("inc_SE").alias("SE"), pl.col("inc_LCI").alias("LCI"), pl.col("inc_UCI").alias("UCI")]) - inc_out = 
out.select(inc_cols).with_columns(pl.lit("incidence").alias("estimate")) - - out = pl.concat([surv_out, risk_out, inc_out]) - else: - out = out.rename({"pred_outcome": "pred"}).with_columns(pl.lit("risk").alias("estimate")) - - if val is not None: - out = out.with_columns(pl.lit(val).alias(self.subgroup_colname)) - - return out - -def _calculate_survival(self, risk_data): - if self.bootstrap_nboot > 0: - surv = risk_data.with_columns([ - (1 - pl.col(col)).alias(col) for col in ["pred", "LCI", "UCI"] - ]).with_columns(pl.lit("survival").alias("estimate")) - else: - surv = risk_data.with_columns([ - (1 - pl.col("pred")).alias("pred"), - pl.lit("survival").alias("estimate") - ]) - return surv diff --git a/pySEQ/docs/source/index.rst b/pySEQ/docs/source/index.rst deleted file mode 100644 index 167b37b..0000000 --- a/pySEQ/docs/source/index.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. SEQuential documentation master file, created by - sphinx-quickstart on Fri Nov 15 14:02:09 2024. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to SEQuential's documentation! -====================================== - -.. 
toctree:: - :maxdepth: 2 - :caption: Contents: - - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/pySEQ/error/__init__.py b/pySEQ/error/__init__.py deleted file mode 100644 index c9ee5ee..0000000 --- a/pySEQ/error/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ._param_checker import _param_checker -from ._datachecker import _datachecker \ No newline at end of file diff --git a/pySEQ/error/_datachecker.py b/pySEQ/error/_datachecker.py deleted file mode 100644 index da8277f..0000000 --- a/pySEQ/error/_datachecker.py +++ /dev/null @@ -1,29 +0,0 @@ -import polars as pl - -def _datachecker(self): - check = self.data.group_by(self.id_col).agg([ - pl.len().alias("row_count"), - pl.col(self.time_col).max().alias("max_time") - ]) - - invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1) - if len(invalid) > 0: - raise ValueError( - f"Data validation failed: {len(invalid)} ID(s) have mismatched " - f"This suggests invalid times" - f"Invalid IDs:\n{invalid}" - ) - - for col in self.excused_colnames: - violations = self.data.sort([self.id_col, self.time_col]).group_by(self.id_col).agg([ - ((pl.col(col).cum_sum().shift(1, fill_value=0) > 0) & (pl.col(col) == 0)) - .any() - .alias("has_violation") - ]).filter(pl.col("has_violation")) - - if len(violations) > 0: - raise ValueError( - f"Column '{col}' violates 'once one, always one' rule for excusing treatment " - f"{len(violations)} ID(s) have zeros after ones." 
- ) - \ No newline at end of file diff --git a/pySEQ/expansion/__init__.py b/pySEQ/expansion/__init__.py deleted file mode 100644 index 60947bb..0000000 --- a/pySEQ/expansion/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._binder import _binder -from ._dynamic import _dynamic -from ._mapper import _mapper -from ._selection import _random_selection -from ._diagnostics import _diagnostics \ No newline at end of file diff --git a/pySEQ/expansion/_binder.py b/pySEQ/expansion/_binder.py deleted file mode 100644 index 389ebac..0000000 --- a/pySEQ/expansion/_binder.py +++ /dev/null @@ -1,71 +0,0 @@ -import polars as pl -from ._mapper import _mapper - -def _binder(self, kept_cols): - """ - Internal function to bind data to the map created by __mapper - """ - excluded = {"dose", - f"dose{self.indicator_squared}", - "followup", - f"followup{self.indicator_squared}", - "tx_lag", - "trial", - f"trial{self.indicator_squared}", - self.time_col, - f"{self.time_col}{self.indicator_squared}"} - - cols = kept_cols.union({self.eligible_col, self.outcome_col, self.treatment_col}) - cols = {col for col in cols if col is not None} - - regular = {col for col in cols if not (self.indicator_baseline in col or self.indicator_squared in col) and col not in excluded} - - baseline = {col for col in cols if self.indicator_baseline in col and col not in excluded} - bas_kept = {col.replace(self.indicator_baseline, "") for col in baseline} - - squared = {col for col in cols if self.indicator_squared in col and col not in excluded} - sq_kept = {col.replace(self.indicator_squared, "") for col in squared} - - kept = list(regular.union(bas_kept).union(sq_kept)) - - if self.selection_first_trial: - DT = self.data.sort([self.id_col, self.time_col]) \ - .with_columns([ - pl.col(self.time_col).alias("period"), - pl.col(self.time_col).alias("followup"), - pl.lit(0).alias("trial") - ]).drop(self.time_col) - else: - DT = _mapper(self.data, self.id_col, self.time_col, self.followup_min, self.followup_max) - 
DT = DT.join( - self.data.select([self.id_col, self.time_col] + kept), - left_on=[self.id_col, "period"], - right_on=[self.id_col, self.time_col], - how="left" - ) - DT = DT.sort([self.id_col, "trial", "followup"]) \ - .with_columns([ - (pl.col("trial") ** 2).alias(f"trial{self.indicator_squared}"), - (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}") - ]) - - if squared: - squares = [] - for sq in squared: - col = sq.replace(self.indicator_squared, "") - squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}")) - DT = DT.with_columns(squares) - - baseline_cols = {bas.replace(self.indicator_baseline, "") for bas in baseline} - needed = {self.eligible_col, self.treatment_col} - baseline_cols.update({c for c in needed}) - - bas = [ - pl.col(c).first().over([self.id_col, "trial"]).alias(f"{c}{self.indicator_baseline}") - for c in baseline_cols - ] - - DT = DT.with_columns(bas).filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1) \ - .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col]) - - return DT \ No newline at end of file diff --git a/pySEQ/expansion/_diagnostics.py b/pySEQ/expansion/_diagnostics.py deleted file mode 100644 index 726312f..0000000 --- a/pySEQ/expansion/_diagnostics.py +++ /dev/null @@ -1,56 +0,0 @@ -import polars as pl - -def _diagnostics(self): - unique_out = _outcome_diag(self, unique = True) - nonunique_out = _outcome_diag(self, unique = False) - out = {"unique_outcomes": unique_out, - "nonunique_outcomes": nonunique_out} - - if self.method == "censoring": - unique_switch = _switch_diag(self, unique = True) - nonunique_switch = _switch_diag(self, unique = False) - out.update({"unique_switches": unique_switch, - "nonunique_switches": nonunique_switch}) - - self.diagnostics = out - -def _outcome_diag(self, unique): - if unique: - data = self.DT.select([ - self.id_col, - self.treatment_col, - self.outcome_col - ]).group_by(self.id_col).last() - else: - data = self.DT - out = 
data.group_by([self.treatment_col, - self.outcome_col]).len() - - return out - -def _switch_diag(self, unique): - if not self.excused: - data = self.DT.with_columns(pl.lit(False) - .alias("isExcused")) - else: - data = self.DT - - if unique: - data = data.select([ - self.id_col, - self.treatment_col, - "switch", - "isExcused" - ]).with_columns( - pl.when((pl.col("switch") == 0) & - (pl.col("isExcused"))) - .then(1) - .otherwise(pl.col("switch")) - .alias("switch") - ).group_by(self.id_col).last() - - out = data.group_by([self.treatment_col, - "isExcused", - "switch"]).len() - return out - \ No newline at end of file diff --git a/pySEQ/expansion/_mapper.py b/pySEQ/expansion/_mapper.py deleted file mode 100644 index e2cfe80..0000000 --- a/pySEQ/expansion/_mapper.py +++ /dev/null @@ -1,30 +0,0 @@ -import polars as pl -import math - -def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf): - """ - Internal function to create the expanded map to bind data to. - """ - - DT = ( - data.select([pl.col(id_col), pl.col(time_col)]) - .with_columns([ - pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial") - ]) - .with_columns([ - pl.struct([pl.col(time_col), pl.col(time_col).max().over(id_col).alias("max_time")]) - .map_elements(lambda x: list(range(x[time_col], x["max_time"] + 1)), - return_dtype=pl.List(pl.Int64)) - .alias("period") - ]) - .explode("period") - .drop(pl.col(time_col)) - .with_columns([ - pl.col(id_col).cum_count().over([id_col, "trial"]).sub(1).alias("followup") - ]) - .filter( - (pl.col("followup") >= min_followup) & - (pl.col("followup") <= max_followup) - ) - ) - return DT diff --git a/pySEQ/expansion/_selection.py b/pySEQ/expansion/_selection.py deleted file mode 100644 index 52e39f5..0000000 --- a/pySEQ/expansion/_selection.py +++ /dev/null @@ -1,29 +0,0 @@ -import polars as pl -def _random_selection(self): - """ - Handles the case where random selection is applied for data from - the __mapper -> __binder -> optionally 
__dynamic pipeline - """ - UIDs = self.DT.select([ - self.id_col, - "trial", - f"{self.treatment_col}{self.indicator_baseline}"]) \ - .with_columns( - (pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID")) \ - .filter( - pl.col(f"{self.treatment_col}{self.indicator_baseline}") == 0) \ - .unique("trialID").to_series().to_list() - - NIDs = len(UIDs) - sample = self._rng.choice( - UIDs, - size=int(self.selection_probability * NIDs), - replace=False - ) - - self.DT = self.DT.with_columns( - (pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID") - ).filter( - pl.col("trialID").is_in(sample) - ).drop("trialID") - \ No newline at end of file diff --git a/pySEQ/helpers/__init__.py b/pySEQ/helpers/__init__.py deleted file mode 100644 index 8f9a23b..0000000 --- a/pySEQ/helpers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ._col_string import _col_string -from ._bootstrap import bootstrap_loop -from ._format_time import _format_time -from ._predict_model import _predict_model -from ._prepare_data import _prepare_data -from ._pad import _pad \ No newline at end of file diff --git a/pySEQ/helpers/_prepare_data.py b/pySEQ/helpers/_prepare_data.py deleted file mode 100644 index efc0e3f..0000000 --- a/pySEQ/helpers/_prepare_data.py +++ /dev/null @@ -1,14 +0,0 @@ -import polars as pl - -def _prepare_data(self, DT): - binaries = [self.eligible_col, self.outcome_col, self.cense_colname] # self.excused_colnames + self.weight_eligible_colnames - binary_colnames = [col for col in binaries if col in DT.columns and not None] - - DT = DT.with_columns( - [ - *[pl.col(col).cast(pl.Categorical) for col in self.fixed_cols], - *[pl.col(col).cast(pl.Int8) for col in binary_colnames], - pl.col(self.id_col).cast(pl.Utf8), - ] - ) - return DT diff --git a/pySEQ/initialization/__init__.py b/pySEQ/initialization/__init__.py deleted file mode 100644 index 004b0ce..0000000 --- a/pySEQ/initialization/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._outcome import _outcome -from 
._censoring import _cense_numerator, _cense_denominator -from ._numerator import _numerator -from ._denominator import _denominator \ No newline at end of file diff --git a/pySEQ/initialization/_censoring.py b/pySEQ/initialization/_censoring.py deleted file mode 100644 index 94ce6a9..0000000 --- a/pySEQ/initialization/_censoring.py +++ /dev/null @@ -1,29 +0,0 @@ -def _cense_numerator(self) -> str: - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None - time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None - fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - - if self.weight_preexpansion: - out = "+".join(filter(None, ["tx_lag", time, fixed])) - else: - out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, tv_bas])) - - return out - -def _cense_denominator(self) -> str: - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None - time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None - fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - - if self.weight_preexpansion: - out = "+".join(filter(None, ["tx_lag", time, fixed, tv])) - else: - out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, tv, tv_bas])) - - return out - diff --git a/pySEQ/plot/__init__.py b/pySEQ/plot/__init__.py deleted file mode 100644 index 6d77c39..0000000 --- a/pySEQ/plot/__init__.py +++ /dev/null @@ 
-1 +0,0 @@ -from ._survival_plot import _survival_plot \ No newline at end of file diff --git a/pySEQ/weighting/__init__.py b/pySEQ/weighting/__init__.py deleted file mode 100644 index 10953df..0000000 --- a/pySEQ/weighting/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ._weight_fit import _fit_LTFU, _fit_numerator, _fit_denominator -from ._weight_pred import _weight_predict -from ._weight_bind import _weight_bind -from ._weight_data import _weight_setup -from ._weight_stats import _weight_stats \ No newline at end of file diff --git a/pySEQ/weighting/_weight_bind.py b/pySEQ/weighting/_weight_bind.py deleted file mode 100644 index be9d188..0000000 --- a/pySEQ/weighting/_weight_bind.py +++ /dev/null @@ -1,58 +0,0 @@ -import polars as pl - -def _weight_bind(self, WDT): - if self.weight_preexpansion: - join = "inner" - on = [self.id_col, "period"] - WDT = WDT.rename({self.time_col: "period"}) - else: - join = "left" - on = [self.id_col, "trial", "followup"] - - WDT = self.DT.join(WDT, on=on, how=join) - - if self.weight_preexpansion and self.excused: - trial = (pl.col("trial") == 0) & (pl.col("period") == 0) - excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0 - override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-7) - ) - elif not self.weight_preexpansion and self.excused: - trial = pl.col("followup") == 0 - excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0 - override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-7) | - (pl.col("numerator") < 1e-7) - ) - else: - trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & (pl.col("followup") == 0) - excused = pl.lit(False) - override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-15) | - pl.col("numerator").is_null() - ) - - self.DT = WDT.with_columns( - pl.when(override) - 
.then(pl.lit(1.0)) - .otherwise(pl.col("numerator") / pl.col("denominator")) - .alias("wt") - ).sort( - [self.id_col, "trial", "followup"] - ).with_columns( - pl.col("wt") - .fill_null(1.0) - .cum_prod() - .over([self.id_col, "trial"]) - .alias("weight") - ) - \ No newline at end of file diff --git a/pySEQ/weighting/_weight_data.py b/pySEQ/weighting/_weight_data.py deleted file mode 100644 index 4aa3780..0000000 --- a/pySEQ/weighting/_weight_data.py +++ /dev/null @@ -1,38 +0,0 @@ -import polars as pl - -def _weight_setup(self): - DT = self.DT - data = self.data - if not self.weight_preexpansion: - baseline_lag = data.select([self.treatment_col, self.id_col, self.time_col]) \ - .sort([self.id_col, self.time_col]) \ - .with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over(self.id_col) - .alias("tx_lag")) \ - .drop(self.treatment_col) \ - .rename({self.time_col : "period"}) - - fup0 = DT.filter(pl.col("followup") == 0) \ - .join( - baseline_lag, - on = [self.id_col, "period"], - how = "inner" - ) - - fup = DT.sort([self.id_col, "trial", "followup"]) \ - .with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over([self.id_col, "trial"]) - .alias("tx_lag") - ).filter(pl.col("followup") != 0) - - WDT = pl.concat([fup0, fup]).sort([self.id_col, "trial", "followup"]) - else: - WDT = data.with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over(self.id_col) - .alias("tx_lag"), - (pl.col(self.time_col) ** 2).alias(f"{self.time_col}{self.indicator_squared}")) - return WDT - \ No newline at end of file diff --git a/pySEQ/weighting/_weight_stats.py b/pySEQ/weighting/_weight_stats.py deleted file mode 100644 index ddbd190..0000000 --- a/pySEQ/weighting/_weight_stats.py +++ /dev/null @@ -1,20 +0,0 @@ -import polars as pl - -def _weight_stats(self): - stats = self.DT.select([ - pl.col("weight").min().alias("weight_min"), - pl.col("weight").max().alias("weight_max"), - 
pl.col("weight").mean().alias("weight_mean"), - pl.col("weight").std().alias("weight_std"), - pl.col("weight").quantile(0.01).alias("weight_p01"), - pl.col("weight").quantile(0.25).alias("weight_p25"), - pl.col("weight").quantile(0.50).alias("weight_p50"), - pl.col("weight").quantile(0.75).alias("weight_p75"), - pl.col("weight").quantile(0.99).alias("weight_p99") - ]) - - if self.weight_p99: - self.weight_min = stats.select("weight_p01").item() - self.weight_max = stats.select("weight_p99").item() - - return stats diff --git a/pySEQ/SEQopts.py b/pySEQTarget/SEQopts.py similarity index 76% rename from pySEQ/SEQopts.py rename to pySEQTarget/SEQopts.py index e5d737a..4a59de6 100644 --- a/pySEQ/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -1,6 +1,7 @@ import multiprocessing from dataclasses import dataclass, field -from typing import List, Optional, Literal +from typing import List, Literal, Optional + @dataclass class SEQopts: @@ -8,7 +9,7 @@ class SEQopts: bootstrap_sample: float = 0.8 bootstrap_CI: float = 0.95 bootstrap_CI_method: Literal["se", "percentile"] = "se" - cense_colname : Optional[str] = None + cense_colname: Optional[str] = None cense_denominator: Optional[str] = None cense_numerator: Optional[str] = None cense_eligible_colname: Optional[str] = None @@ -29,7 +30,9 @@ class SEQopts: ncores: int = multiprocessing.cpu_count() numerator: Optional[str] = None parallel: bool = False - plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]) + plot_colors: List[str] = field( + default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"] + ) plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None plot_type: Literal["risk", "survival", "incidence"] = "risk" @@ -47,14 +50,23 @@ class SEQopts: weight_p99: bool = False weight_preexpansion: bool = False weighted: bool = False - + def __post_init__(self): bools = [ - "excused", "followup_class", "followup_include", - "followup_spline", "hazard_estimate", 
"km_curves", - "parallel", "selection_first_trial", "selection_random", - "trial_include", "weight_lag_condition", "weight_p99", - "weight_preexpansion", "weighted" + "excused", + "followup_class", + "followup_include", + "followup_spline", + "hazard_estimate", + "km_curves", + "parallel", + "selection_first_trial", + "selection_random", + "trial_include", + "weight_lag_condition", + "weight_p99", + "weight_preexpansion", + "weighted", ] for i in bools: if not isinstance(getattr(self, i), bool): @@ -62,25 +74,32 @@ def __post_init__(self): if not isinstance(self.bootstrap_nboot, int) or self.bootstrap_nboot < 0: raise ValueError("bootstrap_nboot must be a positive integer.") - + if self.ncores < 1 or not isinstance(self.ncores, int): raise ValueError("ncores must be a positive integer.") - + if not (0.0 <= self.bootstrap_sample <= 1.0): raise ValueError("bootstrap_sample must be between 0 and 1.") if not (0.0 < self.bootstrap_CI < 1.0): raise ValueError("bootstrap_CI must be between 0 and 1.") if not (0.0 <= self.selection_probability <= 1.0): raise ValueError("selection_probability must be between 0 and 1.") - + if self.plot_type not in ["risk", "survival", "incidence"]: - raise ValueError("plot_type must be either 'risk', 'survival', or 'incidence'.") - + raise ValueError( + "plot_type must be either 'risk', 'survival', or 'incidence'." 
+ ) + if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") - for i in ("covariates", "numerator", "denominator", - "cense_numerator", "cense_denominator"): + for i in ( + "covariates", + "numerator", + "denominator", + "cense_numerator", + "cense_denominator", + ): attr = getattr(self, i) if attr is not None and not isinstance(attr, list): setattr(self, i, "".join(attr.split())) diff --git a/pySEQ/SEQoutput.py b/pySEQTarget/SEQoutput.py similarity index 75% rename from pySEQ/SEQoutput.py rename to pySEQTarget/SEQoutput.py index 1cb4d32..ed1ea74 100644 --- a/pySEQ/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -1,9 +1,12 @@ from dataclasses import dataclass -from typing import List, Optional, Literal -from .SEQopts import SEQopts -from statsmodels.base.wrapper import ResultsWrapper -import polars as pl +from typing import List, Literal, Optional + import matplotlib.figure +import polars as pl +from statsmodels.base.wrapper import ResultsWrapper + +from .SEQopts import SEQopts + @dataclass class SEQoutput: @@ -21,16 +24,13 @@ class SEQoutput: risk_difference: pl.DataFrame = None time: dict = None diagnostic_tables: dict = None - + def plot(self): print(self.km_graph) - - def summary(self, - type = Optional[Literal[ - "numerator", - "denominator", - "outcome", - "compevent"]]): + + def summary( + self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] + ): match type: case "numerator": models = self.numerator_models @@ -40,20 +40,24 @@ def summary(self, models = self.compevent_models case _: models = self.outcome_models - + return [model.summary() for model in models] - - def retrieve_data(self, - type = Optional[Literal[ - "km_data", - "hazard", - "risk_ratio", - "risk_difference", - "unique_outcomes", - "nonunique_outcomes", - "unique_switches", - "nonunique_switches" - ]]): + + def retrieve_data( + self, + type=Optional[ + Literal[ + "km_data", + "hazard", + 
"risk_ratio", + "risk_difference", + "unique_outcomes", + "nonunique_outcomes", + "unique_switches", + "nonunique_switches", + ] + ], + ): match type: case "hazard": data = self.hazard @@ -80,5 +84,3 @@ def retrieve_data(self, if data is None: raise ValueError("Data {type} was not created in the SEQuential process") return data - - \ No newline at end of file diff --git a/pySEQ/SEQuential.py b/pySEQTarget/SEQuential.py similarity index 53% rename from pySEQ/SEQuential.py rename to pySEQTarget/SEQuential.py index 3240e9e..1485939 100644 --- a/pySEQ/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -1,35 +1,57 @@ -from typing import Optional, List, Literal +import datetime import time -from dataclasses import asdict from collections import Counter -import polars as pl +from dataclasses import asdict +from typing import List, Literal, Optional + import numpy as np -import datetime +import polars as pl +from .analysis import ( + _calculate_hazard, + _calculate_survival, + _outcome_fit, + _pred_risk, + _risk_estimates, + _subgroup_fit, +) +from .error import _datachecker, _param_checker +from .expansion import _binder, _diagnostics, _dynamic, _random_selection +from .helpers import _col_string, _format_time, bootstrap_loop +from .initialization import ( + _cense_denominator, + _cense_numerator, + _denominator, + _numerator, + _outcome, +) +from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .error import _param_checker, _datachecker -from .helpers import _col_string, bootstrap_loop, _format_time -from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator -from .expansion import _binder, _dynamic, _random_selection, _diagnostics -from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats -from .analysis import _outcome_fit, _pred_risk, _calculate_survival, _subgroup_fit, _calculate_hazard, _risk_estimates -from 
.plot import _survival_plot +from .weighting import ( + _fit_denominator, + _fit_LTFU, + _fit_numerator, + _weight_bind, + _weight_predict, + _weight_setup, + _weight_stats, +) class SEQuential: def __init__( - self, - data: pl.DataFrame, - id_col: str, - time_col: str, - eligible_col: str, - treatment_col: str, - outcome_col: str, - time_varying_cols: Optional[List[str]] = None, - fixed_cols: Optional[List[str]] = None, - method: Literal["ITT", "dose-response", "censoring"] = "ITT", - parameters: Optional[SEQopts] = None + self, + data: pl.DataFrame, + id_col: str, + time_col: str, + eligible_col: str, + treatment_col: str, + outcome_col: str, + time_varying_cols: Optional[List[str]] = None, + fixed_cols: Optional[List[str]] = None, + method: Literal["ITT", "dose-response", "censoring"] = "ITT", + parameters: Optional[SEQopts] = None, ) -> None: self.data = data self.id_col = id_col @@ -40,17 +62,19 @@ def __init__( self.time_varying_cols = time_varying_cols self.fixed_cols = fixed_cols self.method = method - + self._time_initialized = datetime.datetime.now() - + if parameters is None: parameters = SEQopts() - + for name, value in asdict(parameters).items(): setattr(self, name, value) - - self._rng = np.random.RandomState(self.seed) if self.seed is not None else np.random - + + self._rng = ( + np.random.RandomState(self.seed) if self.seed is not None else np.random + ) + if self.covariates is None: self.covariates = _outcome(self) @@ -67,194 +91,227 @@ def __init__( if self.cense_denominator is None: self.cense_denominator = _cense_denominator(self) - + _param_checker(self) _datachecker(self) def expand(self): start = time.perf_counter() - kept = [self.cense_colname, self.cense_eligible_colname, - self.compevent_colname, - *self.weight_eligible_colnames, - *self.excused_colnames] - - self.data = self.data.with_columns([ - pl.when(pl.col(self.treatment_col).is_in(self.treatment_level)) - .then(self.eligible_col) - .otherwise(0) - .alias(self.eligible_col), - 
pl.col(self.treatment_col) - .shift(1) - .over([self.id_col]) - .alias("tx_lag"), - pl.lit(False).alias("switch") - ]).with_columns([ - pl.when(pl.col(self.time_col) == 0) - .then(pl.lit(False)) - .otherwise( - (pl.col("tx_lag").is_not_null()) & - (pl.col("tx_lag") != pl.col(self.treatment_col)) - ).cast(pl.Int8) - .alias("switch") - ]) - - self.DT = _binder(self, kept_cols= _col_string([self.covariates, - self.numerator, - self.denominator, - self.cense_numerator, - self.cense_denominator]).union(kept)) \ - .with_columns( - pl.col(self.id_col) - .cast(pl.Utf8) - .alias(self.id_col) - ) - + kept = [ + self.cense_colname, + self.cense_eligible_colname, + self.compevent_colname, + *self.weight_eligible_colnames, + *self.excused_colnames, + ] + self.data = self.data.with_columns( - pl.col(self.id_col) - .cast(pl.Utf8) - .alias(self.id_col) - ) - + [ + pl.when(pl.col(self.treatment_col).is_in(self.treatment_level)) + .then(self.eligible_col) + .otherwise(0) + .alias(self.eligible_col), + pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"), + pl.lit(False).alias("switch"), + ] + ).with_columns( + [ + pl.when(pl.col(self.time_col) == 0) + .then(pl.lit(False)) + .otherwise( + (pl.col("tx_lag").is_not_null()) + & (pl.col("tx_lag") != pl.col(self.treatment_col)) + ) + .cast(pl.Int8) + .alias("switch") + ] + ) + + self.DT = _binder( + self, + kept_cols=_col_string( + [ + self.covariates, + self.numerator, + self.denominator, + self.cense_numerator, + self.cense_denominator, + ] + ).union(kept), + ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)) + + self.data = self.data.with_columns( + pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col) + ) + if self.method != "ITT": _dynamic(self) if self.selection_random: _random_selection(self) _diagnostics(self) - + end = time.perf_counter() self._expansion_time = _format_time(start, end) - + def bootstrap(self, **kwargs): - allowed = {"bootstrap_nboot", "bootstrap_sample", - "bootstrap_CI", 
"bootstrap_method"} + allowed = { + "bootstrap_nboot", + "bootstrap_sample", + "bootstrap_CI", + "bootstrap_method", + } for key, value in kwargs.items(): if key in allowed: setattr(self, key, value) else: raise ValueError(f"Unknown argument: {key}") - + UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list() NIDs = len(UIDs) - + self._boot_samples = [] for _ in range(self.bootstrap_nboot): - sampled_IDs = self._rng.choice(UIDs, size=int(self.bootstrap_sample * NIDs), replace=True) + sampled_IDs = self._rng.choice( + UIDs, size=int(self.bootstrap_sample * NIDs), replace=True + ) id_counts = Counter(sampled_IDs) self._boot_samples.append(id_counts) return self - - @bootstrap_loop + + @bootstrap_loop def fit(self): if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"): - raise ValueError("Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping.") - + raise ValueError( + "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping." 
+ ) + if self.weighted: WDT = _weight_setup(self) if not self.weight_preexpansion and not self.excused: WDT = WDT.filter(pl.col("followup") > 0) - + WDT = WDT.to_pandas() for col in self.fixed_cols: if col in WDT.columns: WDT[col] = WDT[col].astype("category") - + _fit_LTFU(self, WDT) _fit_numerator(self, WDT) _fit_denominator(self, WDT) - + WDT = pl.from_pandas(WDT) WDT = _weight_predict(self, WDT) _weight_bind(self, WDT) self.weight_stats = _weight_stats(self) - + if self.subgroup_colname is not None: return _subgroup_fit(self) - - models = {'outcome': _outcome_fit(self, self.DT, - self.outcome_col, - self.covariates, - self.weighted, - "weight")} + + models = { + "outcome": _outcome_fit( + self, + self.DT, + self.outcome_col, + self.covariates, + self.weighted, + "weight", + ) + } if self.compevent_colname is not None: - models['compevent'] = _outcome_fit(self, self.DT, - self.compevent_colname, - self.covariates, - self.weighted, - "weight") + models["compevent"] = _outcome_fit( + self, + self.DT, + self.compevent_colname, + self.covariates, + self.weighted, + "weight", + ) return models - + def survival(self): if not hasattr(self, "outcome_model") or not self.outcome_model: - raise ValueError("Outcome model not found. Please run the 'fit' method before calculating survival.") - + raise ValueError( + "Outcome model not found. Please run the 'fit' method before calculating survival." + ) + start = time.perf_counter() - + risk_data = _pred_risk(self) surv_data = _calculate_survival(self, risk_data) self.km_data = pl.concat([risk_data, surv_data]) self.risk_estimates = _risk_estimates(self) - + end = time.perf_counter() self._survival_time = _format_time(start, end) - + def hazard(self): start = time.perf_counter() - + if not hasattr(self, "outcome_model") or not self.outcome_model: - raise ValueError("Outcome model not found. Please run the 'fit' method before calculating hazard ratio.") + raise ValueError( + "Outcome model not found. 
Please run the 'fit' method before calculating hazard ratio." + ) self.hazard_ratio = _calculate_hazard(self) - + end = time.perf_counter() self._hazard_time = _format_time(start, end) def plot(self): self.km_graph = _survival_plot(self) - + def collect(self): self._time_collected = datetime.datetime.now() - + generated = [ - "numerator_model", "denominator_model", + "numerator_model", + "denominator_model", "outcome_model", - "hazard_ratio", "risk_estimates", - "km_data", "km_graph", "diagnostics", - "_survival_time", "_hazard_time", - "_model_time", "_expansion_time", - "weight_stats" + "hazard_ratio", + "risk_estimates", + "km_data", + "km_graph", + "diagnostics", + "_survival_time", + "_hazard_time", + "_model_time", + "_expansion_time", + "weight_stats", ] for attr in generated: if not hasattr(self, attr): setattr(self, attr, None) - + # Options ========================== base = SEQopts() - + for name, value in vars(self).items(): if name in asdict(base).keys(): setattr(base, name, value) - - # Timing ========================= - time = {"start_time": self._time_initialized, - "expansion_time": self._expansion_time, - "model_time": self._model_time, - "survival_time": self._survival_time, - "hazard_time": self._hazard_time, - "collection_time": self._time_collected} - + + # Timing ========================= + time = { + "start_time": self._time_initialized, + "expansion_time": self._expansion_time, + "model_time": self._model_time, + "survival_time": self._survival_time, + "hazard_time": self._hazard_time, + "collection_time": self._time_collected, + } + if self.compevent_colname is not None: compevent_models = [model["compevent"] for model in self.outcome_models] else: compevent_models = None - + if self.outcome_model is not None: outcome_models = [model["outcome"] for model in self.outcome_model] - + if self.risk_estimates is None: risk_ratio = risk_difference = None else: risk_ratio = self.risk_estimates["risk_ratio"] risk_difference = 
self.risk_estimates["risk_difference"] - + output = SEQoutput( options=base, method=self.method, @@ -269,8 +326,7 @@ def collect(self): risk_ratio=risk_ratio, risk_difference=risk_difference, time=time, - diagnostic_tables=self.diagnostics + diagnostic_tables=self.diagnostics, ) - + return output - diff --git a/pySEQ/__init__.py b/pySEQTarget/__init__.py similarity index 66% rename from pySEQ/__init__.py rename to pySEQTarget/__init__.py index 4167c03..d8c687e 100644 --- a/pySEQ/__init__.py +++ b/pySEQTarget/__init__.py @@ -1,8 +1,5 @@ -from .SEQuential import SEQuential from .SEQopts import SEQopts from .SEQoutput import SEQoutput +from .SEQuential import SEQuential -__all__ = [ - "SEQuential", - "SEQopts" -] \ No newline at end of file +__all__ = ["SEQuential", "SEQopts", "SEQoutput"] diff --git a/pySEQTarget/analysis/__init__.py b/pySEQTarget/analysis/__init__.py new file mode 100644 index 0000000..e35ceb7 --- /dev/null +++ b/pySEQTarget/analysis/__init__.py @@ -0,0 +1,7 @@ +from ._hazard import _calculate_hazard as _calculate_hazard +from ._outcome_fit import _outcome_fit as _outcome_fit +from ._risk_estimates import _risk_estimates as _risk_estimates +from ._subgroup_fit import _subgroup_fit as _subgroup_fit +from ._survival_pred import _calculate_survival as _calculate_survival +from ._survival_pred import _get_outcome_predictions as _get_outcome_predictions +from ._survival_pred import _pred_risk as _pred_risk diff --git a/pySEQ/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py similarity index 57% rename from pySEQ/analysis/_hazard.py rename to pySEQTarget/analysis/_hazard.py index ec4ca80..4c667c9 100644 --- a/pySEQ/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -1,20 +1,22 @@ -import polars as pl +import warnings + import numpy as np +import polars as pl from lifelines import CoxPHFitter -import warnings + def _calculate_hazard(self): if self.subgroup_colname is None: return _calculate_hazard_single(self, self.DT, idx=None, val=None) - + 
all_hazards = [] original_DT = self.DT - + for i, val in enumerate(self._unique_subgroups): subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val) hazard = _calculate_hazard_single(self, subgroup_DT, i, val) all_hazards.append(hazard) - + self.DT = original_DT return pl.concat(all_hazards) @@ -24,30 +26,31 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if full_hr is None or np.isnan(full_hr): return _create_hazard_output(None, None, None, val, self) - + if self.bootstrap_nboot > 0: boot_hrs = [] - + for boot_idx in range(len(self._boot_samples)): id_counts = self._boot_samples[boot_idx] - + boot_data_list = [] for id_val, count in id_counts.items(): id_data = data.filter(pl.col(self.id_col) == id_val) for _ in range(count): boot_data_list.append(id_data) - + boot_data = pl.concat(boot_data_list) - + boot_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng) if boot_hr is not None and not np.isnan(boot_hr): boot_hrs.append(boot_hr) - + if len(boot_hrs) == 0: return _create_hazard_output(full_hr, None, None, val, self) - + if self.bootstrap_CI_method == "se": from scipy.stats import norm + z = norm.ppf(1 - (1 - self.bootstrap_CI) / 2) se = np.std(boot_hrs) lci = full_hr - z * se @@ -57,102 +60,132 @@ def _calculate_hazard_single(self, data, idx=None, val=None): uci = np.quantile(boot_hrs, 1 - (1 - self.bootstrap_CI) / 2) else: lci, uci = None, None - + return _create_hazard_output(full_hr, lci, uci, val, self) def _hazard_handler(self, data, idx, boot_idx, rng): - exclude_cols = ["followup", f"followup{self.indicator_squared}", self.treatment_col, - f"{self.treatment_col}{self.indicator_baseline}", "period", self.outcome_col] + exclude_cols = [ + "followup", + f"followup{self.indicator_squared}", + self.treatment_col, + f"{self.treatment_col}{self.indicator_baseline}", + "period", + self.outcome_col, + ] if self.compevent_colname: exclude_cols.append(self.compevent_colname) keep_cols = [col for col in data.columns if 
col not in exclude_cols] - + trials = ( data.select(keep_cols) - .group_by([self.id_col, "trial"]).first() + .group_by([self.id_col, "trial"]) + .first() .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")]) .explode("followup") - .with_columns([(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]) + .with_columns( + [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")] + ) ) - + if idx is not None: model_dict = self.outcome_model[boot_idx][idx] else: model_dict = self.outcome_model[boot_idx] - - outcome_model = model_dict['outcome'] - ce_model = model_dict.get('compevent', None) if self.compevent_colname else None - + + outcome_model = model_dict["outcome"] + ce_model = model_dict.get("compevent", None) if self.compevent_colname else None + all_treatments = [] for val in self.treatment_level: - tmp = trials.with_columns([ - pl.lit(val) - .alias(f"{self.treatment_col}{self.indicator_baseline}") - ]) - + tmp = trials.with_columns( + [pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")] + ) + tmp_pd = tmp.to_pandas() outcome_prob = outcome_model.predict(tmp_pd) outcome_sim = rng.binomial(1, outcome_prob) - + tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)]) - + if ce_model is not None: ce_prob = ce_model.predict(tmp_pd) ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) - - tmp = tmp.with_columns([ - pl.when((pl.col("outcome") == 1) | - (pl.col("ce") == 1)) - .then(1) - .otherwise(0) - .alias("any_event") - ]).with_columns([ - pl.col("any_event") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ]).filter(pl.col("event_cumsum") <= 1) + + tmp = ( + tmp.with_columns( + [ + pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) + .then(1) + .otherwise(0) + .alias("any_event") + ] + ) + .with_columns( + [ + pl.col("any_event") + .cum_sum() + .over([self.id_col, "trial"]) + .alias("event_cumsum") + ] + ) + 
.filter(pl.col("event_cumsum") <= 1) + ) else: - tmp = tmp.with_columns([ - pl.col("outcome") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ]).filter(pl.col("event_cumsum") <= 1) - + tmp = tmp.with_columns( + [ + pl.col("outcome") + .cum_sum() + .over([self.id_col, "trial"]) + .alias("event_cumsum") + ] + ).filter(pl.col("event_cumsum") <= 1) + tmp = tmp.group_by([self.id_col, "trial"]).last() all_treatments.append(tmp) - + sim_data = pl.concat(all_treatments) - + if ce_model is not None: - sim_data = sim_data.with_columns([ - pl.when(pl.col("outcome") == 1).then(pl.lit(1)) - .when(pl.col("ce") == 1).then(pl.lit(2)) - .otherwise(pl.lit(0)).alias("event") - ]) + sim_data = sim_data.with_columns( + [ + pl.when(pl.col("outcome") == 1) + .then(pl.lit(1)) + .when(pl.col("ce") == 1) + .then(pl.lit(2)) + .otherwise(pl.lit(0)) + .alias("event") + ] + ) else: sim_data = sim_data.with_columns([pl.col("outcome").alias("event")]) - + sim_data_pd = sim_data.to_pandas() - + try: - #COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() - warnings.filterwarnings('ignore', message='.*datetime.datetime.utcnow.*') + # COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() + warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*") if ce_model is not None: - cox_data = sim_data_pd[sim_data_pd['event'].isin([0, 1])].copy() - cox_data['event_binary'] = (cox_data['event'] == 1).astype(int) - + cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy() + cox_data["event_binary"] = (cox_data["event"] == 1).astype(int) + cph = CoxPHFitter() - cph.fit(cox_data, duration_col='followup', event_col='event_binary', - formula=f"`{self.treatment_col}{self.indicator_baseline}`") + cph.fit( + cox_data, + duration_col="followup", + event_col="event_binary", + formula=f"`{self.treatment_col}{self.indicator_baseline}`", + ) else: cph = CoxPHFitter() - cph.fit(sim_data_pd, duration_col='followup', event_col='event', - 
formula=f"`{self.treatment_col}{self.indicator_baseline}`") - + cph.fit( + sim_data_pd, + duration_col="followup", + event_col="event", + formula=f"`{self.treatment_col}{self.indicator_baseline}`", + ) + hr = np.exp(cph.params_.values[0]) return hr except Exception as e: @@ -162,17 +195,17 @@ def _hazard_handler(self, data, idx, boot_idx, rng): def _create_hazard_output(hr, lci, uci, val, self): if lci is not None and uci is not None: - output = pl.DataFrame({ - "Hazard": [hr if hr is not None else float('nan')], - "LCI": [lci], - "UCI": [uci] - }) + output = pl.DataFrame( + { + "Hazard": [hr if hr is not None else float("nan")], + "LCI": [lci], + "UCI": [uci], + } + ) else: - output = pl.DataFrame({ - "Hazard": [hr if hr is not None else float('nan')] - }) - + output = pl.DataFrame({"Hazard": [hr if hr is not None else float("nan")]}) + if val is not None: output = output.with_columns(pl.lit(val).alias(self.subgroup_colname)) - + return output diff --git a/pySEQ/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py similarity index 82% rename from pySEQ/analysis/_outcome_fit.py rename to pySEQTarget/analysis/_outcome_fit.py index f357680..7ed823f 100644 --- a/pySEQ/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -1,7 +1,9 @@ +import re + +import polars as pl import statsmodels.api as sm import statsmodels.formula.api as smf -import polars as pl -import re + def _outcome_fit( self, @@ -9,60 +11,62 @@ def _outcome_fit( outcome: str, formula: str, weighted: bool = False, - weight_col: str = "weight"): + weight_col: str = "weight", +): if weighted: df = df.with_columns( pl.col(weight_col).clip( - lower_bound=self.weight_min, - upper_bound=self.weight_max + lower_bound=self.weight_min, upper_bound=self.weight_max ) ) - + df_pd = df.to_pandas() - + df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") tx_bas = f"{self.treatment_col}{self.indicator_baseline}" df_pd[tx_bas] = df_pd[tx_bas].astype("category") - + if 
def _risk_estimates(self):
    """
    Build pairwise risk-difference and risk-ratio tables at the last follow-up.

    Rows of ``self.km_data`` whose ``estimate`` is ``"risk"`` at the maximum
    ``followup`` are compared for every ordered pair of distinct values in
    ``self.treatment_level`` (matched on ``self.subgroup_colname`` when set).

    When ``self.bootstrap_nboot > 0`` the ``SE`` column is used to add
    normal-approximation intervals:

    * difference: SE = sqrt(se_x^2 + se_y^2)
    * ratio: SE on the log scale = sqrt((se_x/risk_x)^2 + (se_y/risk_y)^2)

    NOTE(review): the interval columns are always labelled "95%" even though
    the critical value is derived from ``self.bootstrap_CI``.

    Returns
    -------
    dict
        ``{"risk_difference": DataFrame, "risk_ratio": DataFrame}``; each is an
        empty ``pl.DataFrame()`` when no pair of treatment levels exists.
    """
    final_time = self.km_data["followup"].max()
    end_risk = self.km_data.filter(
        (pl.col("followup") == final_time) & (pl.col("estimate") == "risk")
    )

    strata = [self.subgroup_colname] if self.subgroup_colname else []
    with_ci = self.bootstrap_nboot > 0
    if with_ci:
        alpha = 1 - self.bootstrap_CI
        z = stats.norm.ppf(1 - alpha / 2)

    def _arm(level, source, target):
        # One treatment arm's value of `source`, renamed to `target`.
        return (
            end_risk.filter(pl.col("tx_init") == level)
            .select(strata + [source])
            .rename({source: target})
        )

    def _attach(left, right):
        # Match on the subgroup column when present, otherwise cross join.
        if strata:
            return left.join(right, on=strata, how="left")
        return left.join(right, how="cross")

    def _finish(frame, scratch, measures):
        # Drop working columns and fix the output column order.
        frame = frame.drop(scratch)
        wanted = strata + ["A_x", "A_y"] + measures
        return frame.select([c for c in wanted if c in frame.columns])

    difference = pl.col("risk_x") - pl.col("risk_y")
    ratio = pl.col("risk_x") / pl.col("risk_y")

    rd_frames = []
    rr_frames = []

    for tx_x in self.treatment_level:
        for tx_y in self.treatment_level:
            if tx_x == tx_y:
                continue

            pair = _attach(
                _arm(tx_x, "pred", "risk_x"), _arm(tx_y, "pred", "risk_y")
            ).with_columns([pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")])

            if not with_ci:
                scratch = ["risk_x", "risk_y"]
                rd_frames.append(
                    _finish(
                        pair.with_columns(difference.alias("Risk Difference")),
                        scratch,
                        ["Risk Difference"],
                    )
                )
                rr_frames.append(
                    _finish(
                        pair.with_columns(ratio.alias("Risk Ratio")),
                        scratch,
                        ["Risk Ratio"],
                    )
                )
                continue

            pair = _attach(pair, _arm(tx_x, "SE", "se_x"))
            pair = _attach(pair, _arm(tx_y, "SE", "se_y"))
            scratch = ["risk_x", "risk_y", "se_x", "se_y"]

            rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
            rd_frames.append(
                _finish(
                    pair.with_columns(
                        [
                            difference.alias("Risk Difference"),
                            (difference - z * rd_se).alias("RD 95% LCI"),
                            (difference + z * rd_se).alias("RD 95% UCI"),
                        ]
                    ),
                    scratch,
                    ["Risk Difference", "RD 95% LCI", "RD 95% UCI"],
                )
            )

            rr_log_se = (
                (pl.col("se_x") / pl.col("risk_x")).pow(2)
                + (pl.col("se_y") / pl.col("risk_y")).pow(2)
            ).sqrt()
            rr_frames.append(
                _finish(
                    pair.with_columns(
                        [
                            ratio.alias("Risk Ratio"),
                            (ratio * (-z * rr_log_se).exp()).alias("RR 95% LCI"),
                            (ratio * (z * rr_log_se).exp()).alias("RR 95% UCI"),
                        ]
                    ),
                    scratch,
                    ["Risk Ratio", "RR 95% LCI", "RR 95% UCI"],
                )
            )

    return {
        "risk_difference": pl.concat(rd_frames) if rd_frames else pl.DataFrame(),
        "risk_ratio": pl.concat(rr_frames) if rr_frames else pl.DataFrame(),
    }
+ self.compevent_colname, + self.covariates, + self.weighted, + "weight", + ) + models_list.append(models) + return models_list diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py new file mode 100644 index 0000000..617fc2c --- /dev/null +++ b/pySEQTarget/analysis/_survival_pred.py @@ -0,0 +1,363 @@ +import polars as pl + + +def _get_outcome_predictions(self, TxDT, idx=None): + data = TxDT.to_pandas() + predictions = {"outcome": []} + if self.compevent_colname is not None: + predictions["compevent"] = [] + + for boot_model in self.outcome_model: + model_dict = boot_model[idx] if idx is not None else boot_model + predictions["outcome"].append(model_dict["outcome"].predict(data)) + if self.compevent_colname is not None: + predictions["compevent"].append(model_dict["compevent"].predict(data)) + + return predictions + + +def _pred_risk(self): + has_subgroups = ( + isinstance(self.outcome_model[0], list) if self.outcome_model else False + ) + + if not has_subgroups: + return _calculate_risk(self, self.DT, idx=None, val=None) + + all_risks = [] + original_DT = self.DT + + for i, val in enumerate(self._unique_subgroups): + subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val) + risk = _calculate_risk(self, subgroup_DT, i, val) + all_risks.append(risk) + + self.DT = original_DT + return pl.concat(all_risks) + + +def _calculate_risk(self, data, idx=None, val=None): + a = 1 - self.bootstrap_CI + lci = a / 2 + uci = 1 - lci + + SDT = ( + data.with_columns( + [ + ( + pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8) + ).alias("TID") + ] + ) + .group_by("TID") + .first() + .drop(["followup", f"followup{self.indicator_squared}"]) + .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")]) + .explode("followup") + .with_columns( + [ + (pl.col("followup") + 1).alias("followup"), + (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"), + ] + ) + ).sort([self.id_col, "trial", 
"followup"]) + + risks = [] + for treatment_val in self.treatment_level: + TxDT = SDT.with_columns( + [ + pl.lit(treatment_val).alias( + f"{self.treatment_col}{self.indicator_baseline}" + ) + ] + ) + + if self.method == "dose-response": + if treatment_val == self.treatment_level[0]: + TxDT = TxDT.with_columns( + [pl.lit(0.0).alias("dose"), pl.lit(0.0).alias("dose_sq")] + ) + else: + TxDT = TxDT.with_columns( + [ + pl.col("followup").alias("dose"), + pl.col(f"followup{self.indicator_squared}").alias("dose_sq"), + ] + ) + + preds = _get_outcome_predictions(self, TxDT, idx=idx) + pred_series = [pl.Series("pred_outcome", preds["outcome"][0])] + + if self.bootstrap_nboot > 0: + for boot_idx, pred in enumerate(preds["outcome"][1:], start=1): + pred_series.append(pl.Series(f"pred_outcome_{boot_idx}", pred)) + + if self.compevent_colname is not None: + pred_series.append(pl.Series("pred_ce", preds["compevent"][0])) + if self.bootstrap_nboot > 0: + for boot_idx, pred in enumerate(preds["compevent"][1:], start=1): + pred_series.append(pl.Series(f"pred_ce_{boot_idx}", pred)) + + outcome_names = [s.name for s in pred_series if "outcome" in s.name] + ce_names = [s.name for s in pred_series if "ce" in s.name] + + TxDT = TxDT.with_columns(pred_series) + + if self.compevent_colname is not None: + for out_col, ce_col in zip(outcome_names, ce_names): + surv_col = out_col.replace("pred_outcome", "surv") + cce_col = out_col.replace("pred_outcome", "cce") + inc_col = out_col.replace("pred_outcome", "inc") + + TxDT = TxDT.with_columns( + [ + (1 - pl.col(out_col)).cum_prod().over("TID").alias(surv_col), + ((1 - pl.col(out_col)) * (1 - pl.col(ce_col))) + .cum_prod() + .over("TID") + .alias(cce_col), + ] + ).with_columns( + [ + (pl.col(out_col) * (1 - pl.col(ce_col)) * pl.col(cce_col)) + .cum_sum() + .over("TID") + .alias(inc_col) + ] + ) + + surv_names = [n.replace("pred_outcome", "surv") for n in outcome_names] + inc_names = [n.replace("pred_outcome", "inc") for n in outcome_names] + 
TxDT = ( + TxDT.group_by("followup") + .agg([pl.col(col).mean() for col in surv_names + inc_names]) + .sort("followup") + ) + main_col = "surv" + boot_cols = [col for col in surv_names if col != "surv"] + else: + TxDT = ( + TxDT.with_columns( + [ + (1 - pl.col(col)).cum_prod().over("TID").alias(col) + for col in outcome_names + ] + ) + .group_by("followup") + .agg([pl.col(col).mean() for col in outcome_names]) + .sort("followup") + .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names]) + ) + main_col = "pred_outcome" + boot_cols = [col for col in outcome_names if col != "pred_outcome"] + + if boot_cols: + risk = ( + TxDT.select(["followup"] + boot_cols) + .unpivot( + index="followup", + on=boot_cols, + variable_name="bootID", + value_name="risk", + ) + .group_by("followup") + .agg( + [ + pl.col("risk").std().alias("SE"), + pl.col("risk").quantile(lci).alias("LCI"), + pl.col("risk").quantile(uci).alias("UCI"), + ] + ) + .join(TxDT.select(["followup", main_col]), on="followup") + ) + + if self.bootstrap_CI_method == "se": + from scipy.stats import norm + + z = norm.ppf(1 - a / 2) + risk = risk.with_columns( + [ + (pl.col(main_col) - z * pl.col("SE")).alias("LCI"), + (pl.col(main_col) + z * pl.col("SE")).alias("UCI"), + ] + ) + + fup0_val = 1.0 if self.compevent_colname else 0.0 + + if self.compevent_colname is not None: + inc_boot_cols = [col for col in inc_names if col != "inc"] + if inc_boot_cols: + inc_risk = ( + TxDT.select(["followup"] + inc_boot_cols) + .unpivot( + index="followup", + on=inc_boot_cols, + variable_name="bootID", + value_name="inc_val", + ) + .group_by("followup") + .agg( + [ + pl.col("inc_val").std().alias("inc_SE"), + pl.col("inc_val").quantile(lci).alias("inc_LCI"), + pl.col("inc_val").quantile(uci).alias("inc_UCI"), + ] + ) + .join(TxDT.select(["followup", "inc"]), on="followup") + ) + risk = risk.join(inc_risk, on="followup") + final_cols = [ + "followup", + main_col, + "SE", + "LCI", + "UCI", + "inc", + "inc_SE", + 
"inc_LCI", + "inc_UCI", + ] + risk = risk.select(final_cols).with_columns( + pl.lit(treatment_val).alias(self.treatment_col) + ) + + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + "SE": [0.0], + "LCI": [fup0_val], + "UCI": [fup0_val], + "inc": [0.0], + "inc_SE": [0.0], + "inc_LCI": [0.0], + "inc_UCI": [0.0], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + else: + risk = risk.select( + ["followup", main_col, "SE", "LCI", "UCI"] + ).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + "SE": [0.0], + "LCI": [fup0_val], + "UCI": [fup0_val], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + else: + risk = risk.select( + ["followup", main_col, "SE", "LCI", "UCI"] + ).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + "SE": [0.0], + "LCI": [fup0_val], + "UCI": [fup0_val], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + else: + fup0_val = 1.0 if self.compevent_colname else 0.0 + risk = TxDT.select(["followup", main_col]).with_columns( + pl.lit(treatment_val).alias(self.treatment_col) + ) + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + + if self.compevent_colname is not None: + risk = risk.join(TxDT.select(["followup", "inc"]), on="followup") + fup0 = fup0.with_columns([pl.lit(0.0).alias("inc")]) + + risks.append(pl.concat([fup0, risk])) + out = pl.concat(risks) + + if self.compevent_colname is 
def _calculate_survival(self, risk_data):
    """
    Convert a risk table into a survival table (survival = 1 - risk).

    ``pred`` is replaced by ``1 - pred`` and ``estimate`` is set to
    ``"survival"``. When ``self.bootstrap_nboot > 0`` the interval columns
    are converted as well.

    Because ``1 - x`` reverses ordering, the bounds are swapped: the lower
    survival bound is ``1 - UCI`` of the risk and the upper survival bound is
    ``1 - LCI``. This matches the survival-to-risk conversion in
    ``_calculate_risk``; the earlier code left ``LCI`` above ``UCI``.
    """
    if self.bootstrap_nboot > 0:
        # All expressions in one with_columns read the ORIGINAL columns, so
        # the LCI/UCI swap is safe without temporaries.
        surv = risk_data.with_columns(
            [
                (1 - pl.col("pred")).alias("pred"),
                (1 - pl.col("UCI")).alias("LCI"),
                (1 - pl.col("LCI")).alias("UCI"),
            ]
        ).with_columns(pl.lit("survival").alias("estimate"))
    else:
        surv = risk_data.with_columns(
            [(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")]
        )
    return surv
pySEQ/data/SEQdata_multitreatment.csv rename to pySEQTarget/data/SEQdata_multitreatment.csv diff --git a/pySEQ/data/__init__.py b/pySEQTarget/data/__init__.py similarity index 73% rename from pySEQ/data/__init__.py rename to pySEQTarget/data/__init__.py index 50e5493..e65d31d 100644 --- a/pySEQ/data/__init__.py +++ b/pySEQTarget/data/__init__.py @@ -1,8 +1,10 @@ from importlib.resources import files + import polars as pl + def load_data(name: str = "SEQdata") -> pl.DataFrame: - loc = files("pySEQ.data") + loc = files("pySEQTarget.data") if name in ["SEQdata", "SEQdata_multitreatment", "SEQdata_LTFU"]: if name == "SEQdata": data_path = loc.joinpath("SEQdata.csv") @@ -12,4 +14,6 @@ def load_data(name: str = "SEQdata") -> pl.DataFrame: data_path = loc.joinpath("SEQdata_LTFU.csv") return pl.read_csv(data_path) else: - raise ValueError(f"Dataset '{name}' not available. Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']") \ No newline at end of file + raise ValueError( + f"Dataset '{name}' not available. 
def _datachecker(self):
    """
    Validate ``self.data`` before expansion.

    Two checks are performed:

    1. For every ID, the number of rows must equal ``max(time) + 1``.
    2. For every column in ``self.excused_colnames``, once the column has been
       non-zero for an ID it may not return to 0 ("once one, always one").

    Raises
    ------
    ValueError
        If either check fails.
    """
    check = self.data.group_by(self.id_col).agg(
        [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")]
    )

    invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
    if len(invalid) > 0:
        # Fragments previously ran together with no separators.
        raise ValueError(
            f"Data validation failed: {len(invalid)} ID(s) have a row count that "
            f"does not match their maximum time + 1. This suggests invalid times.\n"
            f"Invalid IDs:\n{invalid}"
        )

    for col in self.excused_colnames:
        # Entries may be None placeholders (``_dynamic`` checks for the same).
        if col is None:
            continue

        violations = (
            self.data.sort([self.id_col, self.time_col])
            .group_by(self.id_col)
            .agg(
                [
                    (
                        # "some earlier row was non-zero" AND "this row is zero"
                        (pl.col(col).cum_sum().shift(1, fill_value=0) > 0)
                        & (pl.col(col) == 0)
                    )
                    .any()
                    .alias("has_violation")
                ]
            )
            .filter(pl.col("has_violation"))
        )

        if len(violations) > 0:
            raise ValueError(
                f"Column '{col}' violates 'once one, always one' rule for excusing "
                f"treatment. {len(violations)} ID(s) have zeros after ones."
            )
+ ) + if self.weighted and self.method == "ITT" and self.cense_colname is None: raise ValueError("For weighted ITT analyses, cense_colname must be provided.") - + if self.excused: _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames) - _, self.weight_eligible_colnames = _pad(self.treatment_level, self.weight_eligible_colnames) - + _, self.weight_eligible_colnames = _pad( + self.treatment_level, self.weight_eligible_colnames + ) + return - \ No newline at end of file diff --git a/pySEQTarget/expansion/__init__.py b/pySEQTarget/expansion/__init__.py new file mode 100644 index 0000000..1262af8 --- /dev/null +++ b/pySEQTarget/expansion/__init__.py @@ -0,0 +1,5 @@ +from ._binder import _binder as _binder +from ._diagnostics import _diagnostics as _diagnostics +from ._dynamic import _dynamic as _dynamic +from ._mapper import _mapper as _mapper +from ._selection import _random_selection as _random_selection diff --git a/pySEQTarget/expansion/_binder.py b/pySEQTarget/expansion/_binder.py new file mode 100644 index 0000000..727e0e6 --- /dev/null +++ b/pySEQTarget/expansion/_binder.py @@ -0,0 +1,98 @@ +import polars as pl + +from ._mapper import _mapper + + +def _binder(self, kept_cols): + """ + Internal function to bind data to the map created by __mapper + """ + excluded = { + "dose", + f"dose{self.indicator_squared}", + "followup", + f"followup{self.indicator_squared}", + "tx_lag", + "trial", + f"trial{self.indicator_squared}", + self.time_col, + f"{self.time_col}{self.indicator_squared}", + } + + cols = kept_cols.union({self.eligible_col, self.outcome_col, self.treatment_col}) + cols = {col for col in cols if col is not None} + + regular = { + col + for col in cols + if not (self.indicator_baseline in col or self.indicator_squared in col) + and col not in excluded + } + + baseline = { + col for col in cols if self.indicator_baseline in col and col not in excluded + } + bas_kept = {col.replace(self.indicator_baseline, "") for col in baseline} + + 
squared = { + col for col in cols if self.indicator_squared in col and col not in excluded + } + sq_kept = {col.replace(self.indicator_squared, "") for col in squared} + + kept = list(regular.union(bas_kept).union(sq_kept)) + + if self.selection_first_trial: + DT = ( + self.data.sort([self.id_col, self.time_col]) + .with_columns( + [ + pl.col(self.time_col).alias("period"), + pl.col(self.time_col).alias("followup"), + pl.lit(0).alias("trial"), + ] + ) + .drop(self.time_col) + ) + else: + DT = _mapper( + self.data, self.id_col, self.time_col, self.followup_min, self.followup_max + ) + DT = DT.join( + self.data.select([self.id_col, self.time_col] + kept), + left_on=[self.id_col, "period"], + right_on=[self.id_col, self.time_col], + how="left", + ) + DT = DT.sort([self.id_col, "trial", "followup"]).with_columns( + [ + (pl.col("trial") ** 2).alias(f"trial{self.indicator_squared}"), + (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"), + ] + ) + + if squared: + squares = [] + for sq in squared: + col = sq.replace(self.indicator_squared, "") + squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}")) + DT = DT.with_columns(squares) + + baseline_cols = {bas.replace(self.indicator_baseline, "") for bas in baseline} + needed = {self.eligible_col, self.treatment_col} + baseline_cols.update({c for c in needed}) + + bas = [ + pl.col(c) + .first() + .over([self.id_col, "trial"]) + .alias(f"{c}{self.indicator_baseline}") + for c in baseline_cols + ] + + DT = ( + DT.with_columns(bas) + .filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1) + .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col]) + ) + + return DT diff --git a/pySEQTarget/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py new file mode 100644 index 0000000..178062a --- /dev/null +++ b/pySEQTarget/expansion/_diagnostics.py @@ -0,0 +1,53 @@ +import polars as pl + + +def _diagnostics(self): + unique_out = _outcome_diag(self, 
def _switch_diag(self, unique):
    """
    Count rows of ``self.DT`` by treatment, ``isExcused`` and ``switch``.

    When ``self.excused`` is falsy an all-False ``isExcused`` column is set
    first. When ``unique`` is truthy the table is reduced to the last row per
    ID, after recoding ``switch`` to 1 on rows where ``switch == 0`` and
    ``isExcused`` is true.
    """
    frame = (
        self.DT
        if self.excused
        else self.DT.with_columns(pl.lit(False).alias("isExcused"))
    )

    if unique:
        excused_without_switch = (pl.col("switch") == 0) & (pl.col("isExcused"))
        recoded = (
            pl.when(excused_without_switch)
            .then(1)
            .otherwise(pl.col("switch"))
            .alias("switch")
        )
        per_id = frame.select(
            [self.id_col, self.treatment_col, "switch", "isExcused"]
        ).with_columns(recoded)
        frame = per_id.group_by(self.id_col).last()

    return frame.group_by([self.treatment_col, "isExcused", "switch"]).len()
"followup"]).with_columns( pl.col(self.treatment_col) @@ -23,13 +21,13 @@ def _dynamic(self): .over([self.id_col, "trial"]) .alias("tx_lag") ) - + switch = ( pl.when(pl.col("followup") == 0) .then(pl.lit(False)) .otherwise( - (pl.col("tx_lag").is_not_null()) & - (pl.col("tx_lag") != pl.col(self.treatment_col)) + (pl.col("tx_lag").is_not_null()) + & (pl.col("tx_lag") != pl.col(self.treatment_col)) ) ) is_excused = pl.lit(False) @@ -39,35 +37,25 @@ def _dynamic(self): colname = self.excused_colnames[i] if colname is not None: conditions.append( - (pl.col(colname) == 1) & - (pl.col(self.treatment_col) == self.treatment_level[i]) + (pl.col(colname) == 1) + & (pl.col(self.treatment_col) == self.treatment_level[i]) ) - + if conditions: excused = pl.any_horizontal(conditions) is_excused = switch & excused - switch = ( - pl.when(excused) - .then(pl.lit(False)) - .otherwise(switch) - ) - - DT = DT.with_columns([ - switch.alias("switch"), - is_excused.alias("isExcused") - ]).sort([self.id_col, "trial", "followup"]) \ + switch = pl.when(excused).then(pl.lit(False)).otherwise(switch) + + DT = ( + DT.with_columns([switch.alias("switch"), is_excused.alias("isExcused")]) + .sort([self.id_col, "trial", "followup"]) .filter( - ( - pl.col("switch") - .cum_max() - .shift(1, fill_value=False) + (pl.col("switch").cum_max().shift(1, fill_value=False)).over( + [self.id_col, "trial"] ) - .over([self.id_col, "trial"]) == 0 - ).with_columns( - pl.col("switch") - .cast(pl.Int8) - .alias("switch") + ) + .with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) ) - - self.DT = DT.drop(["tx_lag"]) \ No newline at end of file + + self.DT = DT.drop(["tx_lag"]) diff --git a/pySEQTarget/expansion/_mapper.py b/pySEQTarget/expansion/_mapper.py new file mode 100644 index 0000000..c169669 --- /dev/null +++ b/pySEQTarget/expansion/_mapper.py @@ -0,0 +1,44 @@ +import math + +import polars as pl + + +def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf): + """ + Internal 
def _random_selection(self):
    """
    Handles the case where random selection is applied for data from
    the __mapper -> __binder -> optionally __dynamic pipeline

    A ``trialID`` ("<id>_<trial>") is built for every row. The unique
    ``trialID`` values whose baseline treatment equals 0 form the sampling
    pool; ``int(self.selection_probability * pool_size)`` of them are drawn
    without replacement using ``self._rng``, and ``self.DT`` is filtered to
    the drawn ``trialID`` values.

    Fixes relative to the previous version:

    * the pool is taken from the ``trialID`` column; ``.to_series()`` returned
      the first selected column (the ID), so the sampled values never matched
      the ``trialID`` filter;
    * both parts of ``trialID`` are cast to strings explicitly;
    * ``unique`` keeps input order so a seeded ``self._rng`` yields a
      reproducible sample.

    NOTE(review): the final filter keeps only the sampled trials, so trials
    with a non-zero baseline treatment are dropped as well. Confirm that is
    intended rather than keeping them alongside the sample.
    """
    baseline_tx = f"{self.treatment_col}{self.indicator_baseline}"
    trial_id = (
        pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("trial").cast(pl.Utf8)
    ).alias("trialID")

    UIDs = (
        self.DT.select([self.id_col, "trial", baseline_tx])
        .with_columns(trial_id)
        .filter(pl.col(baseline_tx) == 0)
        .unique("trialID", maintain_order=True)
        .get_column("trialID")
        .to_list()
    )

    NIDs = len(UIDs)
    sample = self._rng.choice(
        UIDs, size=int(self.selection_probability * NIDs), replace=False
    )

    self.DT = (
        self.DT.with_columns(trial_id)
        .filter(pl.col("trialID").is_in(sample))
        .drop("trialID")
    )
+from ._col_string import _col_string as _col_string +from ._format_time import _format_time as _format_time +from ._pad import _pad as _pad +from ._predict_model import _predict_model as _predict_model +from ._prepare_data import _prepare_data as _prepare_data diff --git a/pySEQ/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py similarity index 62% rename from pySEQ/helpers/_bootstrap.py rename to pySEQTarget/helpers/_bootstrap.py index 5e2987e..7aefef6 100644 --- a/pySEQ/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -1,84 +1,111 @@ -from functools import wraps +import copy +import time from concurrent.futures import ProcessPoolExecutor, as_completed -import polars as pl +from functools import wraps + import numpy as np +import polars as pl from tqdm import tqdm -import copy -import time + from ._format_time import _format_time + def _prepare_boot_data(self, data, boot_id): id_counts = self._boot_samples[boot_id] - - counts = pl.DataFrame({ - self.id_col: list(id_counts.keys()), - "count": list(id_counts.values()) - }) - + + counts = pl.DataFrame( + {self.id_col: list(id_counts.keys()), "count": list(id_counts.values())} + ) + bootstrapped = data.join(counts, on=self.id_col, how="inner") - bootstrapped = bootstrapped.with_columns( - pl.int_ranges(0, pl.col("count")).alias("replicate") - ).explode("replicate").with_columns( - (pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("replicate").cast(pl.Utf8)).alias(self.id_col) - ).drop("count", "replicate") - + bootstrapped = ( + bootstrapped.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) + .explode("replicate") + .with_columns( + ( + pl.col(self.id_col).cast(pl.Utf8) + + "_" + + pl.col("replicate").cast(pl.Utf8) + ).alias(self.id_col) + ) + .drop("count", "replicate") + ) + return bootstrapped + def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): obj = copy.deepcopy(obj) - obj._rng = np.random.RandomState(seed + i) if seed is not None else 
np.random.RandomState() + obj._rng = ( + np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() + ) obj.DT = _prepare_boot_data(obj, original_DT, i) - + # Disable bootstrapping to prevent recursion obj.bootstrap_nboot = 0 - + method = getattr(obj, method_name) result = method(*args, **kwargs) obj._rng = None return result + def bootstrap_loop(method): @wraps(method) def wrapper(self, *args, **kwargs): if not hasattr(self, "outcome_model"): self.outcome_model = [] start = time.perf_counter() - + results = [] full = method(self, *args, **kwargs) results.append(full) - - if getattr(self, "bootstrap_nboot") > 0 and getattr(self, "_boot_samples", None): + + if getattr(self, "bootstrap_nboot") > 0 and getattr( + self, "_boot_samples", None + ): original_DT = self.DT nboot = self.bootstrap_nboot ncores = self.ncores seed = getattr(self, "seed", None) method_name = method.__name__ - + if getattr(self, "parallel", False): original_rng = getattr(self, "_rng", None) self._rng = None - + with ProcessPoolExecutor(max_workers=ncores) as executor: futures = [ - executor.submit(_bootstrap_worker, self, method_name, original_DT, i, seed, args, kwargs) + executor.submit( + _bootstrap_worker, + self, + method_name, + original_DT, + i, + seed, + args, + kwargs, + ) for i in range(nboot) ] - for j in tqdm(as_completed(futures), total=nboot, desc="Bootstrapping..."): + for j in tqdm( + as_completed(futures), total=nboot, desc="Bootstrapping..." 
+ ): results.append(j.result()) - + self._rng = original_rng else: for i in tqdm(range(nboot), desc="Bootstrapping..."): self.DT = _prepare_boot_data(self, original_DT, i) boot_fit = method(self, *args, **kwargs) results.append(boot_fit) - + self.DT = original_DT - + end = time.perf_counter() self._model_time = _format_time(start, end) - + self.outcome_model = results return results + return wrapper diff --git a/pySEQ/helpers/_col_string.py b/pySEQTarget/helpers/_col_string.py similarity index 92% rename from pySEQ/helpers/_col_string.py rename to pySEQTarget/helpers/_col_string.py index 62b9f16..907ef80 100644 --- a/pySEQ/helpers/_col_string.py +++ b/pySEQTarget/helpers/_col_string.py @@ -3,4 +3,4 @@ def _col_string(expressions): for expression in expressions: if expression is not None: cols.update(expression.replace("+", " ").replace("*", " ").split()) - return cols \ No newline at end of file + return cols diff --git a/pySEQ/helpers/_format_time.py b/pySEQTarget/helpers/_format_time.py similarity index 100% rename from pySEQ/helpers/_format_time.py rename to pySEQTarget/helpers/_format_time.py diff --git a/pySEQ/helpers/_pad.py b/pySEQTarget/helpers/_pad.py similarity index 100% rename from pySEQ/helpers/_pad.py rename to pySEQTarget/helpers/_pad.py diff --git a/pySEQ/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py similarity index 74% rename from pySEQ/helpers/_predict_model.py rename to pySEQTarget/helpers/_predict_model.py index 08ecc46..5ddd731 100644 --- a/pySEQ/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -1,8 +1,9 @@ import numpy as np + def _predict_model(self, model, newdata): newdata = newdata.to_pandas() for col in self.fixed_cols: if col in newdata.columns: - newdata[col] = newdata[col].astype("category") + newdata[col] = newdata[col].astype("category") return np.array(model.predict(newdata)) diff --git a/pySEQTarget/helpers/_prepare_data.py b/pySEQTarget/helpers/_prepare_data.py new file mode 100644 
index 0000000..682e23d --- /dev/null +++ b/pySEQTarget/helpers/_prepare_data.py @@ -0,0 +1,19 @@ +import polars as pl + + +def _prepare_data(self, DT): + binaries = [ + self.eligible_col, + self.outcome_col, + self.cense_colname, + ] # self.excused_colnames + self.weight_eligible_colnames + binary_colnames = [col for col in binaries if col in DT.columns and not None] + + DT = DT.with_columns( + [ + *[pl.col(col).cast(pl.Categorical) for col in self.fixed_cols], + *[pl.col(col).cast(pl.Int8) for col in binary_colnames], + pl.col(self.id_col).cast(pl.Utf8), + ] + ) + return DT diff --git a/pySEQTarget/initialization/__init__.py b/pySEQTarget/initialization/__init__.py new file mode 100644 index 0000000..4f026ca --- /dev/null +++ b/pySEQTarget/initialization/__init__.py @@ -0,0 +1,5 @@ +from ._censoring import _cense_denominator as _cense_denominator +from ._censoring import _cense_numerator as _cense_numerator +from ._denominator import _denominator as _denominator +from ._numerator import _numerator as _numerator +from ._outcome import _outcome as _outcome diff --git a/pySEQTarget/initialization/_censoring.py b/pySEQTarget/initialization/_censoring.py new file mode 100644 index 0000000..828d584 --- /dev/null +++ b/pySEQTarget/initialization/_censoring.py @@ -0,0 +1,53 @@ +def _cense_numerator(self) -> str: + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) + time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) + fixed = "+".join(self.fixed_cols) if self.fixed_cols else None + + if self.weight_preexpansion: + out = "+".join(filter(None, ["tx_lag", time, fixed])) + else: + out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, 
tv_bas])) + + return out + + +def _cense_denominator(self) -> str: + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) + time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) + tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) + fixed = "+".join(self.fixed_cols) if self.fixed_cols else None + + if self.weight_preexpansion: + out = "+".join(filter(None, ["tx_lag", time, fixed, tv])) + else: + out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, tv, tv_bas])) + + return out diff --git a/pySEQ/initialization/_denominator.py b/pySEQTarget/initialization/_denominator.py similarity index 70% rename from pySEQ/initialization/_denominator.py rename to pySEQTarget/initialization/_denominator.py index 4d74bdc..0a081b3 100644 --- a/pySEQ/initialization/_denominator.py +++ b/pySEQTarget/initialization/_denominator.py @@ -1,14 +1,26 @@ def _denominator(self) -> str: if self.method == "ITT": return - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - + tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + tv_bas = ( + 
"+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: if self.method == "dose-response": out = "+".join(filter(None, [fixed, tv, time])) @@ -24,4 +36,4 @@ def _denominator(self) -> str: elif self.method == "censoring" and self.excused: out = "+".join(filter(None, [fixed, tv, tv_bas, followup, trial])) - return out \ No newline at end of file + return out diff --git a/pySEQ/initialization/_numerator.py b/pySEQTarget/initialization/_numerator.py similarity index 68% rename from pySEQ/initialization/_numerator.py rename to pySEQTarget/initialization/_numerator.py index 8f232d6..b47ba23 100644 --- a/pySEQ/initialization/_numerator.py +++ b/pySEQTarget/initialization/_numerator.py @@ -1,13 +1,25 @@ def _numerator(self) -> str: if self.method == "ITT": return - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: if self.method == "dose-response": out = "+".join(filter(None, [fixed, time])) @@ -22,4 +34,4 @@ def _numerator(self) -> str: out = "+".join(filter(None, [fixed, tv_bas, followup, trial])) elif self.method 
== "censoring" and self.excused: out = "+".join(filter(None, [fixed, tv_bas, followup, trial])) - return out \ No newline at end of file + return out diff --git a/pySEQ/initialization/_outcome.py b/pySEQTarget/initialization/_outcome.py similarity index 83% rename from pySEQ/initialization/_outcome.py rename to pySEQTarget/initialization/_outcome.py index 8e30e9f..4ec0fc9 100644 --- a/pySEQ/initialization/_outcome.py +++ b/pySEQTarget/initialization/_outcome.py @@ -2,28 +2,36 @@ def _outcome(self) -> str: tx_bas = f"{self.treatment_col}{self.indicator_baseline}" dose = "+".join(["dose", f"dose{self.indicator_squared}"]) interaction = f"{tx_bas}*followup" - interaction_dose = "+".join(["followup*dose", f"followup*dose{self.indicator_squared}"]) - + interaction_dose = "+".join( + ["followup*dose", f"followup*dose{self.indicator_squared}"] + ) + if self.hazard or not self.km_curves: interaction = interaction_dose = None - + tv_bas = ( - "+".join([f"{v}_bas" for v in self.time_varying_cols]) if self.time_varying_cols else None + "+".join([f"{v}_bas" for v in self.time_varying_cols]) + if self.time_varying_cols + else None ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + if self.followup_include: followup = "+".join(["followup", f"followup{self.indicator_squared}"]) elif (self.followup_spline or self.followup_class) and not self.followup_include: followup = "followup" else: followup = None - + if self.method == "ITT": parts = [tx_bas, followup, trial, fixed, tv_bas, interaction] return "+".join(filter(None, parts)) - + if self.weighted: if self.weight_preexpansion: if self.method == "dose-response": @@ -39,7 +47,7 @@ def _outcome(self) -> str: elif self.method == "censoring": parts = [tx_bas, followup, trial, fixed, tv_bas, interaction] return 
"+".join(filter(None, parts)) - + if self.method == "dose-response": parts = [dose, followup, trial, fixed, tv_bas, interaction_dose] elif self.method == "censoring": diff --git a/pySEQTarget/plot/__init__.py b/pySEQTarget/plot/__init__.py new file mode 100644 index 0000000..f417f55 --- /dev/null +++ b/pySEQTarget/plot/__init__.py @@ -0,0 +1 @@ +from ._survival_plot import _survival_plot as _survival_plot diff --git a/pySEQ/plot/_survival_plot.py b/pySEQTarget/plot/_survival_plot.py similarity index 80% rename from pySEQ/plot/_survival_plot.py rename to pySEQTarget/plot/_survival_plot.py index 9641da8..0592036 100644 --- a/pySEQ/plot/_survival_plot.py +++ b/pySEQTarget/plot/_survival_plot.py @@ -1,16 +1,18 @@ import itertools + import matplotlib.pyplot as plt -import polars as pl import numpy as np +import polars as pl + def _survival_plot(self): if self.plot_type == "risk": plot_data = self.km_data.filter(pl.col("estimate") == "risk") - elif self.plot_type == "survival": + elif self.plot_type == "survival": plot_data = self.km_data.filter(pl.col("estimate") == "survival") else: plot_data = self.km_data.filter(pl.col("estimate") == "incidence") - + if self.subgroup_colname is None: _plot_single(self, plot_data) else: @@ -20,10 +22,10 @@ def _survival_plot(self): def _plot_single(self, plot_data): plt.figure(figsize=(10, 6)) _plot_data(self, plot_data, plt.gca()) - + if self.plot_title is None: self.plot_title = f"Cumulative {self.plot_type.title()}" - + plt.xlabel("Followup") plt.ylabel(self.plot_type.title()) plt.title(self.plot_title) @@ -35,60 +37,62 @@ def _plot_single(self, plot_data): def _plot_subgroups(self, plot_data): subgroups = sorted(plot_data[self.subgroup_colname].unique().to_list()) n_subgroups = len(subgroups) - + n_cols = min(3, n_subgroups) n_rows = (n_subgroups + n_cols - 1) // n_cols - - fig, axes = plt.subplots(n_rows, n_cols, figsize=(7*n_cols, 6*n_rows)) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows)) axes = 
np.atleast_1d(axes).flatten() - + for idx, subgroup_val in enumerate(subgroups): ax = axes[idx] subgroup_data = plot_data.filter(pl.col(self.subgroup_colname) == subgroup_val) _plot_data(self, subgroup_data, ax) - - subgroup_label = str(subgroup_val).title() if isinstance(subgroup_val, str) else subgroup_val - + + subgroup_label = ( + str(subgroup_val).title() if isinstance(subgroup_val, str) else subgroup_val + ) + ax.set_xlabel("Followup") ax.set_ylabel(self.plot_type.title()) - ax.set_title(f"{self.subgroup_colname.title()}: {subgroup_label}", fontsize=10, style="italic") + ax.set_title( + f"{self.subgroup_colname.title()}: {subgroup_label}", + fontsize=10, + style="italic", + ) ax.legend() ax.grid() - + for idx in range(n_subgroups, len(axes)): axes[idx].set_visible(False) - + if self.plot_title: fig.suptitle(self.plot_title, fontsize=14) else: fig.suptitle(f"Cumulative {self.plot_type.title()}", fontsize=14) - + plt.tight_layout() plt.show() def _plot_data(self, plot_data, ax): color_cycle = itertools.cycle(self.plot_colors) if self.plot_colors else None - + for idx, i in enumerate(self.treatment_level): subset = plot_data.filter(pl.col(self.treatment_col) == i) if subset.is_empty(): continue - + label = f"treatment = {i}" if self.plot_labels and idx < len(self.plot_labels): label = self.plot_labels[idx] - + color = next(color_cycle) if color_cycle else None - - line, = ax.plot( - subset["followup"], - subset["pred"], - "-", - label=label, - color=color + + (line,) = ax.plot( + subset["followup"], subset["pred"], "-", label=label, color=color ) - + if "LCI" in subset.columns and "UCI" in subset.columns: ax.fill_between( subset["followup"], @@ -96,5 +100,5 @@ def _plot_data(self, plot_data, ax): subset["UCI"], color=line.get_color(), alpha=0.2, - label="_nolegend_" + label="_nolegend_", ) diff --git a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py new file mode 100644 index 0000000..4874865 --- /dev/null +++ 
b/pySEQTarget/weighting/__init__.py @@ -0,0 +1,7 @@ +from ._weight_bind import _weight_bind as _weight_bind +from ._weight_data import _weight_setup as _weight_setup +from ._weight_fit import _fit_denominator as _fit_denominator +from ._weight_fit import _fit_LTFU as _fit_LTFU +from ._weight_fit import _fit_numerator as _fit_numerator +from ._weight_pred import _weight_predict as _weight_predict +from ._weight_stats import _weight_stats as _weight_stats diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py new file mode 100644 index 0000000..8f8f4a8 --- /dev/null +++ b/pySEQTarget/weighting/_weight_bind.py @@ -0,0 +1,68 @@ +import polars as pl + + +def _weight_bind(self, WDT): + if self.weight_preexpansion: + join = "inner" + on = [self.id_col, "period"] + WDT = WDT.rename({self.time_col: "period"}) + else: + join = "left" + on = [self.id_col, "trial", "followup"] + + WDT = self.DT.join(WDT, on=on, how=join) + + if self.weight_preexpansion and self.excused: + trial = (pl.col("trial") == 0) & (pl.col("period") == 0) + excused = ( + pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) + > 0 + ) + override = ( + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-7) + ) + elif not self.weight_preexpansion and self.excused: + trial = pl.col("followup") == 0 + excused = ( + pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) + > 0 + ) + override = ( + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-7) + | (pl.col("numerator") < 1e-7) + ) + else: + trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & ( + pl.col("followup") == 0 + ) + excused = pl.lit(False) + override = ( + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-15) + | pl.col("numerator").is_null() + ) + + self.DT = ( + WDT.with_columns( + pl.when(override) + .then(pl.lit(1.0)) + 
.otherwise(pl.col("numerator") / pl.col("denominator")) + .alias("wt") + ) + .sort([self.id_col, "trial", "followup"]) + .with_columns( + pl.col("wt") + .fill_null(1.0) + .cum_prod() + .over([self.id_col, "trial"]) + .alias("weight") + ) + ) diff --git a/pySEQTarget/weighting/_weight_data.py b/pySEQTarget/weighting/_weight_data.py new file mode 100644 index 0000000..72f8ae7 --- /dev/null +++ b/pySEQTarget/weighting/_weight_data.py @@ -0,0 +1,47 @@ +import polars as pl + + +def _weight_setup(self): + DT = self.DT + data = self.data + if not self.weight_preexpansion: + baseline_lag = ( + data.select([self.treatment_col, self.id_col, self.time_col]) + .sort([self.id_col, self.time_col]) + .with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over(self.id_col) + .alias("tx_lag") + ) + .drop(self.treatment_col) + .rename({self.time_col: "period"}) + ) + + fup0 = DT.filter(pl.col("followup") == 0).join( + baseline_lag, on=[self.id_col, "period"], how="inner" + ) + + fup = ( + DT.sort([self.id_col, "trial", "followup"]) + .with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over([self.id_col, "trial"]) + .alias("tx_lag") + ) + .filter(pl.col("followup") != 0) + ) + + WDT = pl.concat([fup0, fup]).sort([self.id_col, "trial", "followup"]) + else: + WDT = data.with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over(self.id_col) + .alias("tx_lag"), + (pl.col(self.time_col) ** 2).alias( + f"{self.time_col}{self.indicator_squared}" + ), + ) + return WDT diff --git a/pySEQ/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py similarity index 77% rename from pySEQ/weighting/_weight_fit.py rename to pySEQTarget/weighting/_weight_fit.py index ea0ef71..70fc85d 100644 --- a/pySEQ/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -1,6 +1,7 @@ import statsmodels.api as sm import statsmodels.formula.api as smf + def _fit_LTFU(self, WDT): if 
self.cense_colname is None: return @@ -8,20 +9,17 @@ def _fit_LTFU(self, WDT): fits = [] if self.cense_eligible_colname is not None: WDT = WDT[WDT[self.cense_eligible_colname] == 1] - + for i in [self.cense_numerator, self.cense_denominator]: formula = f"{self.cense_colname}~{i}" - model = smf.glm( - formula, - WDT, - family=sm.families.Binomial() - ) + model = smf.glm(formula, WDT, family=sm.families.Binomial()) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.cense_numerator = fits[0] self.cense_denominator = fits[1] + def _fit_numerator(self, WDT): if self.weight_preexpansion and self.excused: return @@ -29,7 +27,9 @@ def _fit_numerator(self, WDT): return predictor = "switch" if self.excused else self.treatment_col formula = f"{predictor}~{self.numerator}" - tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" + tx_bas = ( + f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" + ) fits = [] for i, level in enumerate(self.treatment_level): if self.excused and self.excused_colnames[i] is not None: @@ -40,20 +40,22 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - - model = smf.mnlogit( - formula, - DT_subset - ) + + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.numerator_model = fits - + + def _fit_denominator(self, WDT): if self.method == "ITT": return - predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col + predictor = ( + "switch" + if self.excused and not self.weight_preexpansion + else self.treatment_col + ) formula = f"{predictor}~{self.denominator}" fits = [] for i, level in enumerate(self.treatment_level): @@ -62,18 +64,14 @@ def _fit_denominator(self, WDT): else: DT_subset = WDT if self.weight_lag_condition: - DT_subset = 
DT_subset[DT_subset["tx_lag"] == level] + DT_subset = DT_subset[DT_subset["tx_lag"] == level] if not self.weight_preexpansion and not self.excused: DT_subset = DT_subset[DT_subset["followup"] != 0] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - - model = smf.mnlogit( - formula, - DT_subset - ) + + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.denominator_model = fits - \ No newline at end of file diff --git a/pySEQ/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py similarity index 61% rename from pySEQ/weighting/_weight_pred.py rename to pySEQTarget/weighting/_weight_pred.py index aff0c6f..5a858a8 100644 --- a/pySEQ/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -1,27 +1,27 @@ -from ..helpers import _predict_model -import polars as pl import numpy as np +import polars as pl + +from ..helpers import _predict_model + def _weight_predict(self, WDT): grouping = [self.id_col] grouping += ["trial"] if not self.weight_preexpansion else [] time = self.time_col if self.weight_preexpansion else "followup" - + if self.method == "ITT": - WDT = WDT.with_columns([ - pl.lit(1.).alias("numerator"), - pl.lit(1.).alias("denominator") - ]) + WDT = WDT.with_columns( + [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] + ) else: - WDT = WDT.with_columns([ - pl.lit(1.).alias("numerator"), - pl.lit(1.).alias("denominator") - ]) - + WDT = WDT.with_columns( + [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] + ) + for i, level in enumerate(self.treatment_level): mask = pl.col("tx_lag") == level lag_mask = (WDT["tx_lag"] == level).to_numpy() - + if self.denominator_model[i] is not None: pred_denom = np.ones(WDT.height) if lag_mask.sum() > 0: @@ -30,11 +30,13 @@ def _weight_predict(self, WDT): if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] - switched_treatment = 
(subset[self.treatment_col] != subset["tx_lag"]).to_numpy() - pred_denom[lag_mask] = np.where(switched_treatment, 1. - p, p) + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) else: pred_denom = np.ones(WDT.height) - + if hasattr(self, "numerator_model") and self.numerator_model[i] is not None: pred_num = np.ones(WDT.height) if lag_mask.sum() > 0: @@ -43,34 +45,40 @@ def _weight_predict(self, WDT): if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] - switched_treatment = (subset[self.treatment_col] != subset["tx_lag"]).to_numpy() - pred_num[lag_mask] = np.where(switched_treatment, 1. - p, p) + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) else: pred_num = np.ones(WDT.height) - - WDT = WDT.with_columns([ - pl.when(mask) - .then(pl.Series(pred_num)) - .otherwise(pl.col("numerator")) - .alias("numerator"), - pl.when(mask) - .then(pl.Series(pred_denom)) - .otherwise(pl.col("denominator")) - .alias("denominator") - ]) - + + WDT = WDT.with_columns( + [ + pl.when(mask) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator"), + pl.when(mask) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator"), + ] + ) + if self.cense_colname is not None: p_num = _predict_model(self, self.cense_numerator, WDT).flatten() p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() - WDT = WDT.with_columns([ - pl.Series("cense_numerator", p_num), - pl.Series("cense_denominator", p_denom) - ]).with_columns( + WDT = WDT.with_columns( + [ + pl.Series("cense_numerator", p_num), + pl.Series("cense_denominator", p_denom), + ] + ).with_columns( (pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense") ) else: - WDT = WDT.with_columns(pl.lit(1.).alias("cense")) - + WDT = 
WDT.with_columns(pl.lit(1.0).alias("cense")) + kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"] exists = [col for col in kept if col in WDT.columns] return WDT.select(exists).sort(grouping + [time]) diff --git a/pySEQTarget/weighting/_weight_stats.py b/pySEQTarget/weighting/_weight_stats.py new file mode 100644 index 0000000..b6da331 --- /dev/null +++ b/pySEQTarget/weighting/_weight_stats.py @@ -0,0 +1,23 @@ +import polars as pl + + +def _weight_stats(self): + stats = self.DT.select( + [ + pl.col("weight").min().alias("weight_min"), + pl.col("weight").max().alias("weight_max"), + pl.col("weight").mean().alias("weight_mean"), + pl.col("weight").std().alias("weight_std"), + pl.col("weight").quantile(0.01).alias("weight_p01"), + pl.col("weight").quantile(0.25).alias("weight_p25"), + pl.col("weight").quantile(0.50).alias("weight_p50"), + pl.col("weight").quantile(0.75).alias("weight_p75"), + pl.col("weight").quantile(0.99).alias("weight_p99"), + ] + ) + + if self.weight_p99: + self.weight_min = stats.select("weight_p01").item() + self.weight_max = stats.select("weight_p99").item() + + return stats diff --git a/pyproject.toml b/pyproject.toml index 956a738..9b83517 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=60", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "pySEQ" +name = "pySEQTarget" version = "0.9.0" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" @@ -41,9 +41,9 @@ dependencies = [ ] [project.urls] -Homepage = "https://github.com/CausalInference/pySEQ" -Repository = "https://github.com/CausalInference/pySEQ" -"Bug Tracker" = "https://github.com/CausalInference/pySEQ/issues" +Homepage = "https://github.com/CausalInference/pySEQTarget" +Repository = "https://github.com/CausalInference/pySEQTarget" +"Bug Tracker" = "https://github.com/CausalInference/pySEQTarget/issues" "Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546" 
"Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X" @@ -53,7 +53,7 @@ Repository = "https://github.com/CausalInference/pySEQ" "Harvard University (ROR)" = "https://ror.org/03vek6s52" [tool.setuptools] -packages = ["pySEQ", "pySEQ.data"] +packages = ["pySEQTarget", "pySEQTarget.data"] [tool.setuptools.package-data] SEQdata = ["data/*.csv"] diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 317ef5c..ab9b796 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -1,10 +1,12 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data import pytest +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + def test_ITT_collector(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -14,12 +16,12 @@ def test_ITT_collector(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), ) s.expand() s.fit() collector = s.collect() - outcomes = collector.retrieve_data("unique_outcomes") + collector.retrieve_data("unique_outcomes") with pytest.raises(ValueError): - collector.retrieve_data("km_data") \ No newline at end of file + collector.retrieve_data("km_data") diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index f240137..07435c7 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,9 +1,10 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_ITT_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,21 +14,30 @@ def test_ITT_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), ) s.expand() s.fit() - matrix = 
s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.828506035553407, 0.18935003090041902, 0.12717241010542563, - 0.033715156987629266, -0.00014691202235029346, 0.044566165558944326, - 0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771, - 0.20072409918428052] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.828506035553407, + 0.18935003090041902, + 0.12717241010542563, + 0.033715156987629266, + -0.00014691202235029346, + 0.044566165558944326, + 0.0005787770439053261, + 0.0032906669395295026, + -0.01339242049205771, + 0.20072409918428052, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_dose_response_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -37,21 +47,28 @@ def test_PreE_dose_response_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-4.842735939069144, 0.14286755770151904, 0.055221018477671927, - -0.000581657931537684, -0.008484541900408258, 0.00021073328759912806, - 0.010537967151467553, 0.0007772316818101141] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -4.842735939069144, + 0.14286755770151904, + 0.055221018477671927, + -0.000581657931537684, + -0.008484541900408258, + 0.00021073328759912806, + 0.010537967151467553, + 0.0007772316818101141, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_dose_response_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -61,22 +78,32 @@ def test_PostE_dose_response_coefs(): outcome_col="outcome", 
time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True) + method="dose-response", + parameters=SEQopts(weighted=True), ) - + s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.265901713761531, 0.14065954021957594, 0.048626017624679704, - -0.0004688287307505834, -0.003975906839775267, 0.00016676441745740924, - 0.03866279977096397, 0.0005928449623613982, 0.0030001459817949844, - -0.02106338184559446, 0.14867250693140854] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.265901713761531, + 0.14065954021957594, + 0.048626017624679704, + -0.0004688287307505834, + -0.003975906839775267, + 0.00016676441745740924, + 0.03866279977096397, + 0.0005928449623613982, + 0.0030001459817949844, + -0.02106338184559446, + 0.14867250693140854, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_censoring_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -86,21 +113,27 @@ def test_PreE_censoring_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="censoring", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-4.818288687908951, 0.511665606890965, 0.062028316788368384, - 0.025489681857269905, 0.00018215948440046585, -0.014019017637918164, - 0.001110238926667272] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -4.818288687908951, + 0.511665606890965, + 0.062028316788368384, + 0.025489681857269905, + 0.00018215948440046585, + -0.014019017637918164, + 0.001110238926667272, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in 
expected] + def test_PostE_censoring_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -110,21 +143,30 @@ def test_PostE_censoring_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True) + method="censoring", + parameters=SEQopts(weighted=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-7.9113179326280445, 0.49092190701455873, 0.08903087485402544, - 0.026160806382879903, 0.00019078148503570062, 0.04445697224987294, - 0.0007051968052005897, 0.004316239095295115, 0.013762799304812959, - 0.3196331024454665] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -7.9113179326280445, + 0.49092190701455873, + 0.08903087485402544, + 0.026160806382879903, + 0.00019078148503570062, + 0.04445697224987294, + 0.0007051968052005897, + 0.004316239095295115, + 0.013762799304812959, + 0.3196331024454665, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_censoring_excused_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -134,22 +176,31 @@ def test_PreE_censoring_excused_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.175691049418418, 1.3493634846413598, 0.1072284696749134, - -0.003977965364113033, 0.06959432825811135, -0.00034297574787048573] + matrix = 
s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.175691049418418, + 1.3493634846413598, + 0.1072284696749134, + -0.003977965364113033, + 0.06959432825811135, + -0.00034297574787048573, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_censoring_excused_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -159,24 +210,35 @@ def test_PostE_censoring_excused_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"], - weight_max=1) + method="censoring", + parameters=SEQopts( + weighted=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + weight_max=1, + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-7.126398786875262, 0.2632047482928519, 0.13345454814736696, - 0.03967181206032395, -0.00033089446793392585, 0.03763545026332514, - 0.0007588725152627089, 0.0036793093608788923, -0.022372677571544992, - 0.2441842617520696] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -7.126398786875262, + 0.2632047482928519, + 0.13345454814736696, + 0.03967181206032395, + -0.00033089446793392585, + 0.03763545026332514, + 0.0007588725152627089, + 0.0036793093608788923, + -0.022372677571544992, + 0.2441842617520696, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] - + + def test_PreE_LTFU_ITT(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -186,23 +248,32 @@ def test_PreE_LTFU_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - cense_colname="LTFU") + method="ITT", + parameters=SEQopts( + weighted=True, 
weight_preexpansion=True, cense_colname="LTFU" + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-21.640523091572796, 0.0685235184372898, -0.19006360662228572, - 0.028750950193838918, -0.0005762057433736666, 0.28554312978583757, - -0.001373044229623057, 0.006589141394458155, -0.44898959259422394, - 1.3875089788036237] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -21.640523091572796, + 0.0685235184372898, + -0.19006360662228572, + 0.028750950193838918, + -0.0005762057433736666, + 0.28554312978583757, + -0.001373044229623057, + 0.006589141394458155, + -0.44898959259422394, + 1.3875089788036237, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_LTFU_ITT(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -212,22 +283,30 @@ def test_PostE_LTFU_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(weighted=True, - cense_colname="LTFU") + method="ITT", + parameters=SEQopts(weighted=True, cense_colname="LTFU"), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-21.640523091572796, 0.0685235184372898, -0.19006360662228572, - 0.028750950193838918, -0.0005762057433736666, 0.28554312978583757, - -0.001373044229623057, 0.006589141394458155, -0.44898959259422394, - 1.3875089788036237] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -21.640523091572796, + 0.0685235184372898, + -0.19006360662228572, + 0.028750950193838918, + -0.0005762057433736666, + 0.28554312978583757, + -0.001373044229623057, + 0.006589141394458155, + -0.44898959259422394, + 1.3875089788036237, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_ITT_multinomial(): data = 
load_data("SEQdata_multitreatment") - + s = SEQuential( data, id_col="ID", @@ -237,22 +316,31 @@ def test_ITT_multinomial(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(treatment_level=[1,2]) + method="ITT", + parameters=SEQopts(treatment_level=[1, 2]), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-47.505262164163625, 1.76628017234151, 22.79205044396338, - 0.14473536056627245, -0.003725499516376173, 0.2893070991930884, - -0.004266608123938117, 0.05574429164512122, 0.7847862691929901, - 1.4703411759229423] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -47.505262164163625, + 1.76628017234151, + 22.79205044396338, + 0.14473536056627245, + -0.003725499516376173, + 0.2893070991930884, + -0.004266608123938117, + 0.05574429164512122, + 0.7847862691929901, + 1.4703411759229423, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] - + + def test_weighted_multinomial(): - data = load_data("SEQdata_multitreatment") - - s = SEQuential( + data = load_data("SEQdata_multitreatment") + + s = SEQuential( data, id_col="ID", time_col="time", @@ -261,15 +349,21 @@ def test_weighted_multinomial(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted = True, - weight_preexpansion=True, - treatment_level=[1,2]) + method="censoring", + parameters=SEQopts( + weighted=True, weight_preexpansion=True, treatment_level=[1, 2] + ), ) - s.expand() - s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-111.35419661939163, -12.571187230338328, 9.234157699403015, - -0.6336774763031923, 0.016754692338530056, 5.8240772329087225, - -0.08598454090661659] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] \ No newline at end of file + 
s.expand() + s.fit() + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -111.35419661939163, + -12.571187230338328, + 9.234157699403015, + -0.6336774763031923, + 0.016754692338530056, + 5.8240772329087225, + -0.08598454090661659, + ] + assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] diff --git a/tests/test_covariates.py b/tests/test_covariates.py index f07b2c4..9863f8f 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -1,9 +1,10 @@ -from pySEQ.data import load_data -from pySEQ import SEQuential, SEQopts +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_ITT_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,18 +14,22 @@ def test_ITT_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), + ) + + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator is None assert s.denominator is None return + def test_PreE_dose_response_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -34,19 +39,19 @@ def test_PreE_dose_response_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) - + assert s.covariates == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex" assert s.numerator == "sex+time+time_sq" assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_dose_response_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -56,17 
+61,24 @@ def test_PostE_dose_response_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True) + method="dose-response", + parameters=SEQopts(weighted=True), + ) + assert ( + s.covariates + == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + assert ( + s.denominator + == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) return + def test_PreE_censoring_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -76,18 +88,18 @@ def test_PreE_censoring_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="censoring", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex" assert s.numerator == "sex+time+time_sq" assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_censoring_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -97,18 +109,25 @@ def test_PostE_censoring_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True) + method="censoring", + parameters=SEQopts(weighted=True), + ) + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - 
assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - + assert ( + s.denominator + == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) + return + def test_PreE_censoring_excused_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -118,20 +137,23 @@ def test_PreE_censoring_excused_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + ), ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq" assert s.numerator is None assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_censoring_excused_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -141,13 +163,18 @@ def test_PostE_censoring_excused_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + weighted=True, excused=True, excused_colnames=["excusedZero", "excusedOne"] + ), + ) + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + assert ( + s.denominator + == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) return - \ No newline at end of file diff --git 
a/tests/test_followup_options.py b/tests/test_followup_options.py index 2cee22c..74d5451 100644 --- a/tests/test_followup_options.py +++ b/tests/test_followup_options.py @@ -1,9 +1,10 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_followup_class(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,24 +14,33 @@ def test_followup_class(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_class=True, - followup_include=False, - followup_max=5) + method="ITT", + parameters=SEQopts(followup_class=True, followup_include=False, followup_max=5), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.6000834193414635, 0.36024705241286203, 0.04326409573404126, - 0.07627958175273072, 0.11375627612408938, 0.14496108664292745, - 0.1798424095611678, 0.09066206802273916, 0.015738693166264354, - 0.0009258560187318309, 0.011267393559242982, 0.022194521411244304, - 0.19115237121222872] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.6000834193414635, + 0.36024705241286203, + 0.04326409573404126, + 0.07627958175273072, + 0.11375627612408938, + 0.14496108664292745, + 0.1798424095611678, + 0.09066206802273916, + 0.015738693166264354, + 0.0009258560187318309, + 0.011267393559242982, + 0.022194521411244304, + 0.19115237121222872, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_followup_spline(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -40,22 +50,31 @@ def test_followup_spline(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_spline=True, - followup_include=False) + method="ITT", + 
parameters=SEQopts(followup_spline=True, followup_include=False), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.264817962084417, 0.20125056343026881, 0.12568743032952776, - 0.03823426390103046, 0.0006607691746414019, 0.003343365539743267, - -0.01319460158923785, 0.19601796921732118, -0.5186462478511427, - 0.37598656666756103, 1.6553848469346044] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.264817962084417, + 0.20125056343026881, + 0.12568743032952776, + 0.03823426390103046, + 0.0006607691746414019, + 0.003343365539743267, + -0.01319460158923785, + 0.19601796921732118, + -0.5186462478511427, + 0.37598656666756103, + 1.6553848469346044, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_no_followup(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -65,13 +84,20 @@ def test_no_followup(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_include=False) + method="ITT", + parameters=SEQopts(followup_include=False), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.062350570326165, 0.17748844870984498, 0.11209431124681817, - 0.03344595751001804, 0.0005457002039545119, 0.0032236473201563585, - -0.014463448024337773, 0.20398559747503964] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] \ No newline at end of file + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.062350570326165, + 0.17748844870984498, + 0.11209431124681817, + 0.03344595751001804, + 0.0005457002039545119, + 0.0032236473201563585, + -0.014463448024337773, + 0.20398559747503964, + ] + assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] diff --git a/tests/test_hazard.py 
b/tests/test_hazard.py index a7e0cf8..f057dfd 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,9 +1,10 @@ -from pySEQ.data import load_data -from pySEQ import SEQuential, SEQopts +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_ITT_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,16 +14,17 @@ def test_ITT_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True) + method="ITT", + parameters=SEQopts(hazard_estimate=True), ) s.expand() s.fit() s.hazard() + def test_bootstrap_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -32,18 +34,18 @@ def test_bootstrap_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True, - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2), ) s.expand() s.bootstrap() s.fit() s.hazard() - + + def test_subgroup_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -53,9 +55,8 @@ def test_subgroup_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True, - subgroup_colname="sex") + method="ITT", + parameters=SEQopts(hazard_estimate=True, subgroup_colname="sex"), ) s.expand() s.bootstrap() diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 3d691fd..2ed0351 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,15 +1,17 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data import os + import pytest +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + @pytest.mark.skipif( - os.getenv('CI') == 'true', - reason="Parallelism test hangs in CI environment" + os.getenv("CI") == "true", 
reason="Parallelism test hangs in CI environment" ) def test_parallel_ITT(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -19,16 +21,22 @@ def test_parallel_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(parallel=True, - bootstrap_nboot=2, - ncores=1) + method="ITT", + parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1), ) s.expand() s.bootstrap() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - assert matrix == [-6.828506035553407, 0.18935003090041902, 0.12717241010542563, - 0.033715156987629266, -0.00014691202235029346, 0.044566165558944326, - 0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771, - 0.20072409918428052] \ No newline at end of file + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + assert matrix == [ + -6.828506035553407, + 0.18935003090041902, + 0.12717241010542563, + 0.033715156987629266, + -0.00014691202235029346, + 0.044566165558944326, + 0.0005787770439053261, + 0.0032906669395295026, + -0.01339242049205771, + 0.20072409918428052, + ] diff --git a/tests/test_survival.py b/tests/test_survival.py index 4466655..0e4ddfe 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,9 +1,10 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_regular_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,17 +14,18 @@ def test_regular_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True) + method="ITT", + parameters=SEQopts(km_curves=True), ) s.expand() s.fit() s.survival() return - + + def test_bootstrapped_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -33,9 +35,8 @@ def 
test_bootstrapped_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(km_curves=True, bootstrap_nboot=2), ) s.expand() s.bootstrap() @@ -43,9 +44,10 @@ def test_bootstrapped_survival(): s.survival() return + def test_subgroup_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -55,18 +57,18 @@ def test_subgroup_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - subgroup_colname="sex") + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex"), ) s.expand() s.fit() s.survival() return - + + def test_subgroup_bootstrapped_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -76,10 +78,8 @@ def test_subgroup_bootstrapped_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - subgroup_colname="sex", - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2), ) s.expand() s.bootstrap() @@ -87,9 +87,10 @@ def test_subgroup_bootstrapped_survival(): s.survival() return + def test_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -99,19 +100,20 @@ def test_compevent(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence") + method="ITT", + parameters=SEQopts( + km_curves=True, compevent_colname="LTFU", plot_type="incidence" + ), ) s.expand() s.fit() s.survival() return - + + def test_bootstrapped_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -121,21 +123,24 @@ def test_bootstrapped_compevent(): 
outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence", - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts( + km_curves=True, + compevent_colname="LTFU", + plot_type="incidence", + bootstrap_nboot=2, + ), ) s.expand() s.bootstrap() s.fit() s.survival() return - + + def test_subgroup_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -145,11 +150,13 @@ def test_subgroup_compevent(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence", - subgroup_colname = "sex") + method="ITT", + parameters=SEQopts( + km_curves=True, + compevent_colname="LTFU", + plot_type="incidence", + subgroup_colname="sex", + ), ) s.expand() s.bootstrap()