From 0fc0c1223aa2c1d53c96ecfd5a185c939c4a3568 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:52:10 +0100 Subject: [PATCH 1/9] rename package --- .github/workflows/pre-commit.yml | 0 {pySEQ => pySEQTarget}/SEQopts.py | 0 {pySEQ => pySEQTarget}/SEQoutput.py | 0 {pySEQ => pySEQTarget}/SEQuential.py | 0 {pySEQ => pySEQTarget}/__init__.py | 0 {pySEQ => pySEQTarget}/analysis/__init__.py | 0 {pySEQ => pySEQTarget}/analysis/_hazard.py | 0 {pySEQ => pySEQTarget}/analysis/_outcome_fit.py | 0 {pySEQ => pySEQTarget}/analysis/_risk_estimates.py | 0 {pySEQ => pySEQTarget}/analysis/_subgroup_fit.py | 0 {pySEQ => pySEQTarget}/analysis/_survival_pred.py | 0 {pySEQ => pySEQTarget}/data/SEQdata.csv | 0 {pySEQ => pySEQTarget}/data/SEQdata_LTFU.csv | 0 {pySEQ => pySEQTarget}/data/SEQdata_multitreatment.csv | 0 {pySEQ => pySEQTarget}/data/__init__.py | 0 {pySEQ => pySEQTarget}/docs/Makefile | 0 {pySEQ => pySEQTarget}/docs/make.bat | 0 {pySEQ => pySEQTarget}/docs/source/conf.py | 0 {pySEQ => pySEQTarget}/docs/source/index.rst | 0 {pySEQ => pySEQTarget}/error/__init__.py | 0 {pySEQ => pySEQTarget}/error/_datachecker.py | 0 {pySEQ => pySEQTarget}/error/_param_checker.py | 0 {pySEQ => pySEQTarget}/expansion/__init__.py | 0 {pySEQ => pySEQTarget}/expansion/_binder.py | 0 {pySEQ => pySEQTarget}/expansion/_diagnostics.py | 0 {pySEQ => pySEQTarget}/expansion/_dynamic.py | 0 {pySEQ => pySEQTarget}/expansion/_mapper.py | 0 {pySEQ => pySEQTarget}/expansion/_selection.py | 0 {pySEQ => pySEQTarget}/helpers/__init__.py | 0 {pySEQ => pySEQTarget}/helpers/_bootstrap.py | 0 {pySEQ => pySEQTarget}/helpers/_col_string.py | 0 {pySEQ => pySEQTarget}/helpers/_format_time.py | 0 {pySEQ => pySEQTarget}/helpers/_pad.py | 0 {pySEQ => pySEQTarget}/helpers/_predict_model.py | 0 {pySEQ => pySEQTarget}/helpers/_prepare_data.py | 0 {pySEQ => pySEQTarget}/initialization/__init__.py | 0 {pySEQ => pySEQTarget}/initialization/_censoring.py | 0 {pySEQ => 
pySEQTarget}/initialization/_denominator.py | 0 {pySEQ => pySEQTarget}/initialization/_numerator.py | 0 {pySEQ => pySEQTarget}/initialization/_outcome.py | 0 {pySEQ => pySEQTarget}/plot/__init__.py | 0 {pySEQ => pySEQTarget}/plot/_survival_plot.py | 0 {pySEQ => pySEQTarget}/weighting/__init__.py | 0 {pySEQ => pySEQTarget}/weighting/_weight_bind.py | 0 {pySEQ => pySEQTarget}/weighting/_weight_data.py | 0 {pySEQ => pySEQTarget}/weighting/_weight_fit.py | 0 {pySEQ => pySEQTarget}/weighting/_weight_pred.py | 0 {pySEQ => pySEQTarget}/weighting/_weight_stats.py | 0 pyproject.toml | 2 +- 49 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 .github/workflows/pre-commit.yml rename {pySEQ => pySEQTarget}/SEQopts.py (100%) rename {pySEQ => pySEQTarget}/SEQoutput.py (100%) rename {pySEQ => pySEQTarget}/SEQuential.py (100%) rename {pySEQ => pySEQTarget}/__init__.py (100%) rename {pySEQ => pySEQTarget}/analysis/__init__.py (100%) rename {pySEQ => pySEQTarget}/analysis/_hazard.py (100%) rename {pySEQ => pySEQTarget}/analysis/_outcome_fit.py (100%) rename {pySEQ => pySEQTarget}/analysis/_risk_estimates.py (100%) rename {pySEQ => pySEQTarget}/analysis/_subgroup_fit.py (100%) rename {pySEQ => pySEQTarget}/analysis/_survival_pred.py (100%) rename {pySEQ => pySEQTarget}/data/SEQdata.csv (100%) rename {pySEQ => pySEQTarget}/data/SEQdata_LTFU.csv (100%) rename {pySEQ => pySEQTarget}/data/SEQdata_multitreatment.csv (100%) rename {pySEQ => pySEQTarget}/data/__init__.py (100%) rename {pySEQ => pySEQTarget}/docs/Makefile (100%) rename {pySEQ => pySEQTarget}/docs/make.bat (100%) rename {pySEQ => pySEQTarget}/docs/source/conf.py (100%) rename {pySEQ => pySEQTarget}/docs/source/index.rst (100%) rename {pySEQ => pySEQTarget}/error/__init__.py (100%) rename {pySEQ => pySEQTarget}/error/_datachecker.py (100%) rename {pySEQ => pySEQTarget}/error/_param_checker.py (100%) rename {pySEQ => pySEQTarget}/expansion/__init__.py (100%) rename {pySEQ => pySEQTarget}/expansion/_binder.py 
(100%) rename {pySEQ => pySEQTarget}/expansion/_diagnostics.py (100%) rename {pySEQ => pySEQTarget}/expansion/_dynamic.py (100%) rename {pySEQ => pySEQTarget}/expansion/_mapper.py (100%) rename {pySEQ => pySEQTarget}/expansion/_selection.py (100%) rename {pySEQ => pySEQTarget}/helpers/__init__.py (100%) rename {pySEQ => pySEQTarget}/helpers/_bootstrap.py (100%) rename {pySEQ => pySEQTarget}/helpers/_col_string.py (100%) rename {pySEQ => pySEQTarget}/helpers/_format_time.py (100%) rename {pySEQ => pySEQTarget}/helpers/_pad.py (100%) rename {pySEQ => pySEQTarget}/helpers/_predict_model.py (100%) rename {pySEQ => pySEQTarget}/helpers/_prepare_data.py (100%) rename {pySEQ => pySEQTarget}/initialization/__init__.py (100%) rename {pySEQ => pySEQTarget}/initialization/_censoring.py (100%) rename {pySEQ => pySEQTarget}/initialization/_denominator.py (100%) rename {pySEQ => pySEQTarget}/initialization/_numerator.py (100%) rename {pySEQ => pySEQTarget}/initialization/_outcome.py (100%) rename {pySEQ => pySEQTarget}/plot/__init__.py (100%) rename {pySEQ => pySEQTarget}/plot/_survival_plot.py (100%) rename {pySEQ => pySEQTarget}/weighting/__init__.py (100%) rename {pySEQ => pySEQTarget}/weighting/_weight_bind.py (100%) rename {pySEQ => pySEQTarget}/weighting/_weight_data.py (100%) rename {pySEQ => pySEQTarget}/weighting/_weight_fit.py (100%) rename {pySEQ => pySEQTarget}/weighting/_weight_pred.py (100%) rename {pySEQ => pySEQTarget}/weighting/_weight_stats.py (100%) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..e69de29 diff --git a/pySEQ/SEQopts.py b/pySEQTarget/SEQopts.py similarity index 100% rename from pySEQ/SEQopts.py rename to pySEQTarget/SEQopts.py diff --git a/pySEQ/SEQoutput.py b/pySEQTarget/SEQoutput.py similarity index 100% rename from pySEQ/SEQoutput.py rename to pySEQTarget/SEQoutput.py diff --git a/pySEQ/SEQuential.py b/pySEQTarget/SEQuential.py similarity index 100% rename from 
pySEQ/SEQuential.py rename to pySEQTarget/SEQuential.py diff --git a/pySEQ/__init__.py b/pySEQTarget/__init__.py similarity index 100% rename from pySEQ/__init__.py rename to pySEQTarget/__init__.py diff --git a/pySEQ/analysis/__init__.py b/pySEQTarget/analysis/__init__.py similarity index 100% rename from pySEQ/analysis/__init__.py rename to pySEQTarget/analysis/__init__.py diff --git a/pySEQ/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py similarity index 100% rename from pySEQ/analysis/_hazard.py rename to pySEQTarget/analysis/_hazard.py diff --git a/pySEQ/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py similarity index 100% rename from pySEQ/analysis/_outcome_fit.py rename to pySEQTarget/analysis/_outcome_fit.py diff --git a/pySEQ/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py similarity index 100% rename from pySEQ/analysis/_risk_estimates.py rename to pySEQTarget/analysis/_risk_estimates.py diff --git a/pySEQ/analysis/_subgroup_fit.py b/pySEQTarget/analysis/_subgroup_fit.py similarity index 100% rename from pySEQ/analysis/_subgroup_fit.py rename to pySEQTarget/analysis/_subgroup_fit.py diff --git a/pySEQ/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py similarity index 100% rename from pySEQ/analysis/_survival_pred.py rename to pySEQTarget/analysis/_survival_pred.py diff --git a/pySEQ/data/SEQdata.csv b/pySEQTarget/data/SEQdata.csv similarity index 100% rename from pySEQ/data/SEQdata.csv rename to pySEQTarget/data/SEQdata.csv diff --git a/pySEQ/data/SEQdata_LTFU.csv b/pySEQTarget/data/SEQdata_LTFU.csv similarity index 100% rename from pySEQ/data/SEQdata_LTFU.csv rename to pySEQTarget/data/SEQdata_LTFU.csv diff --git a/pySEQ/data/SEQdata_multitreatment.csv b/pySEQTarget/data/SEQdata_multitreatment.csv similarity index 100% rename from pySEQ/data/SEQdata_multitreatment.csv rename to pySEQTarget/data/SEQdata_multitreatment.csv diff --git a/pySEQ/data/__init__.py b/pySEQTarget/data/__init__.py 
similarity index 100% rename from pySEQ/data/__init__.py rename to pySEQTarget/data/__init__.py diff --git a/pySEQ/docs/Makefile b/pySEQTarget/docs/Makefile similarity index 100% rename from pySEQ/docs/Makefile rename to pySEQTarget/docs/Makefile diff --git a/pySEQ/docs/make.bat b/pySEQTarget/docs/make.bat similarity index 100% rename from pySEQ/docs/make.bat rename to pySEQTarget/docs/make.bat diff --git a/pySEQ/docs/source/conf.py b/pySEQTarget/docs/source/conf.py similarity index 100% rename from pySEQ/docs/source/conf.py rename to pySEQTarget/docs/source/conf.py diff --git a/pySEQ/docs/source/index.rst b/pySEQTarget/docs/source/index.rst similarity index 100% rename from pySEQ/docs/source/index.rst rename to pySEQTarget/docs/source/index.rst diff --git a/pySEQ/error/__init__.py b/pySEQTarget/error/__init__.py similarity index 100% rename from pySEQ/error/__init__.py rename to pySEQTarget/error/__init__.py diff --git a/pySEQ/error/_datachecker.py b/pySEQTarget/error/_datachecker.py similarity index 100% rename from pySEQ/error/_datachecker.py rename to pySEQTarget/error/_datachecker.py diff --git a/pySEQ/error/_param_checker.py b/pySEQTarget/error/_param_checker.py similarity index 100% rename from pySEQ/error/_param_checker.py rename to pySEQTarget/error/_param_checker.py diff --git a/pySEQ/expansion/__init__.py b/pySEQTarget/expansion/__init__.py similarity index 100% rename from pySEQ/expansion/__init__.py rename to pySEQTarget/expansion/__init__.py diff --git a/pySEQ/expansion/_binder.py b/pySEQTarget/expansion/_binder.py similarity index 100% rename from pySEQ/expansion/_binder.py rename to pySEQTarget/expansion/_binder.py diff --git a/pySEQ/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py similarity index 100% rename from pySEQ/expansion/_diagnostics.py rename to pySEQTarget/expansion/_diagnostics.py diff --git a/pySEQ/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py similarity index 100% rename from pySEQ/expansion/_dynamic.py 
rename to pySEQTarget/expansion/_dynamic.py diff --git a/pySEQ/expansion/_mapper.py b/pySEQTarget/expansion/_mapper.py similarity index 100% rename from pySEQ/expansion/_mapper.py rename to pySEQTarget/expansion/_mapper.py diff --git a/pySEQ/expansion/_selection.py b/pySEQTarget/expansion/_selection.py similarity index 100% rename from pySEQ/expansion/_selection.py rename to pySEQTarget/expansion/_selection.py diff --git a/pySEQ/helpers/__init__.py b/pySEQTarget/helpers/__init__.py similarity index 100% rename from pySEQ/helpers/__init__.py rename to pySEQTarget/helpers/__init__.py diff --git a/pySEQ/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py similarity index 100% rename from pySEQ/helpers/_bootstrap.py rename to pySEQTarget/helpers/_bootstrap.py diff --git a/pySEQ/helpers/_col_string.py b/pySEQTarget/helpers/_col_string.py similarity index 100% rename from pySEQ/helpers/_col_string.py rename to pySEQTarget/helpers/_col_string.py diff --git a/pySEQ/helpers/_format_time.py b/pySEQTarget/helpers/_format_time.py similarity index 100% rename from pySEQ/helpers/_format_time.py rename to pySEQTarget/helpers/_format_time.py diff --git a/pySEQ/helpers/_pad.py b/pySEQTarget/helpers/_pad.py similarity index 100% rename from pySEQ/helpers/_pad.py rename to pySEQTarget/helpers/_pad.py diff --git a/pySEQ/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py similarity index 100% rename from pySEQ/helpers/_predict_model.py rename to pySEQTarget/helpers/_predict_model.py diff --git a/pySEQ/helpers/_prepare_data.py b/pySEQTarget/helpers/_prepare_data.py similarity index 100% rename from pySEQ/helpers/_prepare_data.py rename to pySEQTarget/helpers/_prepare_data.py diff --git a/pySEQ/initialization/__init__.py b/pySEQTarget/initialization/__init__.py similarity index 100% rename from pySEQ/initialization/__init__.py rename to pySEQTarget/initialization/__init__.py diff --git a/pySEQ/initialization/_censoring.py b/pySEQTarget/initialization/_censoring.py 
similarity index 100% rename from pySEQ/initialization/_censoring.py rename to pySEQTarget/initialization/_censoring.py diff --git a/pySEQ/initialization/_denominator.py b/pySEQTarget/initialization/_denominator.py similarity index 100% rename from pySEQ/initialization/_denominator.py rename to pySEQTarget/initialization/_denominator.py diff --git a/pySEQ/initialization/_numerator.py b/pySEQTarget/initialization/_numerator.py similarity index 100% rename from pySEQ/initialization/_numerator.py rename to pySEQTarget/initialization/_numerator.py diff --git a/pySEQ/initialization/_outcome.py b/pySEQTarget/initialization/_outcome.py similarity index 100% rename from pySEQ/initialization/_outcome.py rename to pySEQTarget/initialization/_outcome.py diff --git a/pySEQ/plot/__init__.py b/pySEQTarget/plot/__init__.py similarity index 100% rename from pySEQ/plot/__init__.py rename to pySEQTarget/plot/__init__.py diff --git a/pySEQ/plot/_survival_plot.py b/pySEQTarget/plot/_survival_plot.py similarity index 100% rename from pySEQ/plot/_survival_plot.py rename to pySEQTarget/plot/_survival_plot.py diff --git a/pySEQ/weighting/__init__.py b/pySEQTarget/weighting/__init__.py similarity index 100% rename from pySEQ/weighting/__init__.py rename to pySEQTarget/weighting/__init__.py diff --git a/pySEQ/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py similarity index 100% rename from pySEQ/weighting/_weight_bind.py rename to pySEQTarget/weighting/_weight_bind.py diff --git a/pySEQ/weighting/_weight_data.py b/pySEQTarget/weighting/_weight_data.py similarity index 100% rename from pySEQ/weighting/_weight_data.py rename to pySEQTarget/weighting/_weight_data.py diff --git a/pySEQ/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py similarity index 100% rename from pySEQ/weighting/_weight_fit.py rename to pySEQTarget/weighting/_weight_fit.py diff --git a/pySEQ/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py similarity index 100% rename from 
pySEQ/weighting/_weight_pred.py rename to pySEQTarget/weighting/_weight_pred.py diff --git a/pySEQ/weighting/_weight_stats.py b/pySEQTarget/weighting/_weight_stats.py similarity index 100% rename from pySEQ/weighting/_weight_stats.py rename to pySEQTarget/weighting/_weight_stats.py diff --git a/pyproject.toml b/pyproject.toml index 956a738..bb41f23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ Repository = "https://github.com/CausalInference/pySEQ" "Harvard University (ROR)" = "https://ror.org/03vek6s52" [tool.setuptools] -packages = ["pySEQ", "pySEQ.data"] +packages = ["pySEQTarget", "pySEQTarget.data"] [tool.setuptools.package-data] SEQdata = ["data/*.csv"] From 005d02273d76e44bd271ba7f6fc899f42fc8d1ac Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:56:30 +0100 Subject: [PATCH 2/9] add workflows --- .github/workflows/autoformat.yml | 34 +++++++++++++++++++++++++++++++ .github/workflows/docs.yml | 35 ++++++++++++++++++++++++++++++++ .github/workflows/lint.yml | 34 +++++++++++++++++++++++++++++++ .github/workflows/pre-commit.yml | 0 pyproject.toml | 6 +++--- 5 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/autoformat.yml create mode 100644 .github/workflows/docs.yml create mode 100644 .github/workflows/lint.yml delete mode 100644 .github/workflows/pre-commit.yml diff --git a/.github/workflows/autoformat.yml b/.github/workflows/autoformat.yml new file mode 100644 index 0000000..2dcf15c --- /dev/null +++ b/.github/workflows/autoformat.yml @@ -0,0 +1,34 @@ +name: Auto-format Code + +on: + push: + branches: [main, develop] + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install formatters + run: pip install black isort + + - name: Format with Black + run: black . 
+ + - name: Sort imports with isort + run: isort . + + - name: Commit changes + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git diff --quiet || (git add -A && git commit -m "Auto-format code") + git push \ No newline at end of file diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..d6e089f --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,35 @@ +name: Build and Deploy Docs + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + build-docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + pip install sphinx sphinx-rtd-theme sphinx-autodoc-typehints + pip install -e . # Install your package + + - name: Build documentation + run: | + cd docs + make html + + - name: Deploy to GitHub Pages + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs/_build/html \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..d573f35 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,34 @@ +name: Lint and Format + +on: + push: + branches: [main, develop] + pull_request: + branches: [main, develop] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + pip install ruff black isort mypy + + - name: Run Black + run: black --check . + + - name: Run isort + run: isort --check-only . + + - name: Run Ruff + run: ruff check . + + - name: Run mypy + run: mypy . 
--ignore-missing-imports \ No newline at end of file diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml deleted file mode 100644 index e69de29..0000000 diff --git a/pyproject.toml b/pyproject.toml index bb41f23..b33e142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,9 @@ dependencies = [ ] [project.urls] -Homepage = "https://github.com/CausalInference/pySEQ" -Repository = "https://github.com/CausalInference/pySEQ" -"Bug Tracker" = "https://github.com/CausalInference/pySEQ/issues" +Homepage = "https://github.com/CausalInference/pySEQTarget" +Repository = "https://github.com/CausalInference/pySEQTarget" +"Bug Tracker" = "https://github.com/CausalInference/pySEQTarget/issues" "Ryan O'Dea (ORCID)" = "https://orcid.org/0009-0000-0103-9546" "Alejandro Szmulewicz (ORCID)" = "https://orcid.org/0000-0002-2664-802X" From 1f7260463e2408a484c3672ee06046327fd70921 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:59:46 +0100 Subject: [PATCH 3/9] Create publish.yml --- .github/workflows/publish.yml | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 .github/workflows/publish.yml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..1fa8189 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,70 @@ +# This workflow will upload a Python Package to PyPI when a release is created +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries + +# This workflow uses actions that are not certified by GitHub. +# They are provided by a third-party and are governed by +# separate terms of service, privacy policy, and support +# documentation. 
+ +name: Upload Python Package + +on: + release: + types: [published] + +permissions: + contents: read + +jobs: + release-build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Build release distributions + run: | + # NOTE: put your own distribution build steps here. + python -m pip install build + python -m build + + - name: Upload distributions + uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: + runs-on: ubuntu-latest + needs: + - release-build + permissions: + # IMPORTANT: this permission is mandatory for trusted publishing + id-token: write + + # Dedicated environments with protections for publishing are strongly recommended. + # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules + environment: + name: pypi + # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: + # url: https://pypi.org/p/YOURPROJECT + # + # ALTERNATIVE: if your GitHub Release name is the PyPI project version string + # ALTERNATIVE: exactly, uncomment the following line instead: + # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }} + + steps: + - name: Retrieve release distributions + uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + + - name: Publish release distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ \ No newline at end of file From dad16edf88baa5dfffc6cff4e71ebfa17c5a6530 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:05:49 +0100 Subject: [PATCH 4/9] fixed tests --- tests/test_accessor.py | 4 ++-- tests/test_coefficients.py | 4 ++-- tests/test_covariates.py | 4 ++-- tests/test_followup_options.py | 4 ++-- 
tests/test_hazard.py | 4 ++-- tests/test_parallel.py | 4 ++-- tests/test_survival.py | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 317ef5c..9039030 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -1,5 +1,5 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQuential, SEQopts +from pySEQTarget.data import load_data import pytest def test_ITT_collector(): diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index f240137..ae76aa5 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,5 +1,5 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQuential, SEQopts +from pySEQTarget.data import load_data def test_ITT_coefs(): data = load_data("SEQdata") diff --git a/tests/test_covariates.py b/tests/test_covariates.py index f07b2c4..31cc727 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -1,5 +1,5 @@ -from pySEQ.data import load_data -from pySEQ import SEQuential, SEQopts +from pySEQTarget.data import load_data +from pySEQTarget import SEQuential, SEQopts def test_ITT_covariates(): data = load_data("SEQdata") diff --git a/tests/test_followup_options.py b/tests/test_followup_options.py index 2cee22c..6bc0c08 100644 --- a/tests/test_followup_options.py +++ b/tests/test_followup_options.py @@ -1,5 +1,5 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQuential, SEQopts +from pySEQTarget.data import load_data def test_followup_class(): data = load_data("SEQdata") diff --git a/tests/test_hazard.py b/tests/test_hazard.py index a7e0cf8..2244bed 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,5 +1,5 @@ -from pySEQ.data import load_data -from pySEQ import SEQuential, SEQopts +from pySEQTarget.data import load_data +from pySEQTarget import SEQuential, 
SEQopts def test_ITT_hazard(): data = load_data("SEQdata") diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 3d691fd..b3e63db 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,5 +1,5 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQuential, SEQopts +from pySEQTarget.data import load_data import os import pytest diff --git a/tests/test_survival.py b/tests/test_survival.py index 4466655..abb89a7 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,5 +1,5 @@ -from pySEQ import SEQuential, SEQopts -from pySEQ.data import load_data +from pySEQTarget import SEQuential, SEQopts +from pySEQTarget.data import load_data def test_regular_survival(): data = load_data("SEQdata") From 57fe984c00b68862c01170f78243c974d4efd650 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:09:40 +0100 Subject: [PATCH 5/9] black . --- pySEQTarget/SEQopts.py | 49 ++- pySEQTarget/SEQoutput.py | 46 +-- pySEQTarget/SEQuential.py | 299 +++++++++------- pySEQTarget/__init__.py | 5 +- pySEQTarget/analysis/_hazard.py | 186 ++++++---- pySEQTarget/analysis/_outcome_fit.py | 37 +- pySEQTarget/analysis/_risk_estimates.py | 167 +++++---- pySEQTarget/analysis/_subgroup_fit.py | 30 +- pySEQTarget/analysis/_survival_pred.py | 389 +++++++++++++-------- pySEQTarget/data/__init__.py | 5 +- pySEQTarget/docs/source/conf.py | 11 +- pySEQTarget/error/__init__.py | 2 +- pySEQTarget/error/_datachecker.py | 35 +- pySEQTarget/error/_param_checker.py | 41 ++- pySEQTarget/expansion/__init__.py | 2 +- pySEQTarget/expansion/_binder.py | 104 +++--- pySEQTarget/expansion/_diagnostics.py | 77 ++-- pySEQTarget/expansion/_dynamic.py | 52 ++- pySEQTarget/expansion/_mapper.py | 45 ++- pySEQTarget/expansion/_selection.py | 44 +-- pySEQTarget/helpers/__init__.py | 2 +- pySEQTarget/helpers/_bootstrap.py | 77 ++-- pySEQTarget/helpers/_col_string.py | 2 +- 
pySEQTarget/helpers/_predict_model.py | 3 +- pySEQTarget/helpers/_prepare_data.py | 19 +- pySEQTarget/initialization/__init__.py | 2 +- pySEQTarget/initialization/_censoring.py | 46 ++- pySEQTarget/initialization/_denominator.py | 24 +- pySEQTarget/initialization/_numerator.py | 24 +- pySEQTarget/initialization/_outcome.py | 26 +- pySEQTarget/plot/__init__.py | 2 +- pySEQTarget/plot/_survival_plot.py | 57 +-- pySEQTarget/weighting/__init__.py | 2 +- pySEQTarget/weighting/_weight_bind.py | 80 +++-- pySEQTarget/weighting/_weight_data.py | 65 ++-- pySEQTarget/weighting/_weight_fit.py | 46 ++- pySEQTarget/weighting/_weight_pred.py | 75 ++-- pySEQTarget/weighting/_weight_stats.py | 29 +- 38 files changed, 1299 insertions(+), 908 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index e5d737a..944c86b 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -2,13 +2,14 @@ from dataclasses import dataclass, field from typing import List, Optional, Literal + @dataclass class SEQopts: bootstrap_nboot: int = 0 bootstrap_sample: float = 0.8 bootstrap_CI: float = 0.95 bootstrap_CI_method: Literal["se", "percentile"] = "se" - cense_colname : Optional[str] = None + cense_colname: Optional[str] = None cense_denominator: Optional[str] = None cense_numerator: Optional[str] = None cense_eligible_colname: Optional[str] = None @@ -29,7 +30,9 @@ class SEQopts: ncores: int = multiprocessing.cpu_count() numerator: Optional[str] = None parallel: bool = False - plot_colors: List[str] = field(default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]) + plot_colors: List[str] = field( + default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"] + ) plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None plot_type: Literal["risk", "survival", "incidence"] = "risk" @@ -47,14 +50,23 @@ class SEQopts: weight_p99: bool = False weight_preexpansion: bool = False weighted: bool = False - + def __post_init__(self): bools = [ - "excused", 
"followup_class", "followup_include", - "followup_spline", "hazard_estimate", "km_curves", - "parallel", "selection_first_trial", "selection_random", - "trial_include", "weight_lag_condition", "weight_p99", - "weight_preexpansion", "weighted" + "excused", + "followup_class", + "followup_include", + "followup_spline", + "hazard_estimate", + "km_curves", + "parallel", + "selection_first_trial", + "selection_random", + "trial_include", + "weight_lag_condition", + "weight_p99", + "weight_preexpansion", + "weighted", ] for i in bools: if not isinstance(getattr(self, i), bool): @@ -62,25 +74,32 @@ def __post_init__(self): if not isinstance(self.bootstrap_nboot, int) or self.bootstrap_nboot < 0: raise ValueError("bootstrap_nboot must be a positive integer.") - + if self.ncores < 1 or not isinstance(self.ncores, int): raise ValueError("ncores must be a positive integer.") - + if not (0.0 <= self.bootstrap_sample <= 1.0): raise ValueError("bootstrap_sample must be between 0 and 1.") if not (0.0 < self.bootstrap_CI < 1.0): raise ValueError("bootstrap_CI must be between 0 and 1.") if not (0.0 <= self.selection_probability <= 1.0): raise ValueError("selection_probability must be between 0 and 1.") - + if self.plot_type not in ["risk", "survival", "incidence"]: - raise ValueError("plot_type must be either 'risk', 'survival', or 'incidence'.") - + raise ValueError( + "plot_type must be either 'risk', 'survival', or 'incidence'." 
+ ) + if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") - for i in ("covariates", "numerator", "denominator", - "cense_numerator", "cense_denominator"): + for i in ( + "covariates", + "numerator", + "denominator", + "cense_numerator", + "cense_denominator", + ): attr = getattr(self, i) if attr is not None and not isinstance(attr, list): setattr(self, i, "".join(attr.split())) diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index 1cb4d32..ca22a20 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -5,6 +5,7 @@ import polars as pl import matplotlib.figure + @dataclass class SEQoutput: options: SEQopts = None @@ -21,16 +22,13 @@ class SEQoutput: risk_difference: pl.DataFrame = None time: dict = None diagnostic_tables: dict = None - + def plot(self): print(self.km_graph) - - def summary(self, - type = Optional[Literal[ - "numerator", - "denominator", - "outcome", - "compevent"]]): + + def summary( + self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] + ): match type: case "numerator": models = self.numerator_models @@ -40,20 +38,24 @@ def summary(self, models = self.compevent_models case _: models = self.outcome_models - + return [model.summary() for model in models] - - def retrieve_data(self, - type = Optional[Literal[ - "km_data", - "hazard", - "risk_ratio", - "risk_difference", - "unique_outcomes", - "nonunique_outcomes", - "unique_switches", - "nonunique_switches" - ]]): + + def retrieve_data( + self, + type=Optional[ + Literal[ + "km_data", + "hazard", + "risk_ratio", + "risk_difference", + "unique_outcomes", + "nonunique_outcomes", + "unique_switches", + "nonunique_switches", + ] + ], + ): match type: case "hazard": data = self.hazard @@ -80,5 +82,3 @@ def retrieve_data(self, if data is None: raise ValueError("Data {type} was not created in the SEQuential process") return data - - \ No newline at end of file diff 
--git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 3240e9e..05158e2 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -10,26 +10,47 @@ from .SEQoutput import SEQoutput from .error import _param_checker, _datachecker from .helpers import _col_string, bootstrap_loop, _format_time -from .initialization import _outcome, _numerator, _denominator, _cense_numerator, _cense_denominator +from .initialization import ( + _outcome, + _numerator, + _denominator, + _cense_numerator, + _cense_denominator, +) from .expansion import _binder, _dynamic, _random_selection, _diagnostics -from .weighting import _weight_setup, _fit_LTFU, _fit_numerator, _fit_denominator, _weight_bind, _weight_predict, _weight_stats -from .analysis import _outcome_fit, _pred_risk, _calculate_survival, _subgroup_fit, _calculate_hazard, _risk_estimates +from .weighting import ( + _weight_setup, + _fit_LTFU, + _fit_numerator, + _fit_denominator, + _weight_bind, + _weight_predict, + _weight_stats, +) +from .analysis import ( + _outcome_fit, + _pred_risk, + _calculate_survival, + _subgroup_fit, + _calculate_hazard, + _risk_estimates, +) from .plot import _survival_plot class SEQuential: def __init__( - self, - data: pl.DataFrame, - id_col: str, - time_col: str, - eligible_col: str, - treatment_col: str, - outcome_col: str, - time_varying_cols: Optional[List[str]] = None, - fixed_cols: Optional[List[str]] = None, - method: Literal["ITT", "dose-response", "censoring"] = "ITT", - parameters: Optional[SEQopts] = None + self, + data: pl.DataFrame, + id_col: str, + time_col: str, + eligible_col: str, + treatment_col: str, + outcome_col: str, + time_varying_cols: Optional[List[str]] = None, + fixed_cols: Optional[List[str]] = None, + method: Literal["ITT", "dose-response", "censoring"] = "ITT", + parameters: Optional[SEQopts] = None, ) -> None: self.data = data self.id_col = id_col @@ -40,17 +61,19 @@ def __init__( self.time_varying_cols = time_varying_cols self.fixed_cols = 
fixed_cols self.method = method - + self._time_initialized = datetime.datetime.now() - + if parameters is None: parameters = SEQopts() - + for name, value in asdict(parameters).items(): setattr(self, name, value) - - self._rng = np.random.RandomState(self.seed) if self.seed is not None else np.random - + + self._rng = ( + np.random.RandomState(self.seed) if self.seed is not None else np.random + ) + if self.covariates is None: self.covariates = _outcome(self) @@ -67,194 +90,227 @@ def __init__( if self.cense_denominator is None: self.cense_denominator = _cense_denominator(self) - + _param_checker(self) _datachecker(self) def expand(self): start = time.perf_counter() - kept = [self.cense_colname, self.cense_eligible_colname, - self.compevent_colname, - *self.weight_eligible_colnames, - *self.excused_colnames] - - self.data = self.data.with_columns([ - pl.when(pl.col(self.treatment_col).is_in(self.treatment_level)) - .then(self.eligible_col) - .otherwise(0) - .alias(self.eligible_col), - pl.col(self.treatment_col) - .shift(1) - .over([self.id_col]) - .alias("tx_lag"), - pl.lit(False).alias("switch") - ]).with_columns([ - pl.when(pl.col(self.time_col) == 0) - .then(pl.lit(False)) - .otherwise( - (pl.col("tx_lag").is_not_null()) & - (pl.col("tx_lag") != pl.col(self.treatment_col)) - ).cast(pl.Int8) - .alias("switch") - ]) - - self.DT = _binder(self, kept_cols= _col_string([self.covariates, - self.numerator, - self.denominator, - self.cense_numerator, - self.cense_denominator]).union(kept)) \ - .with_columns( - pl.col(self.id_col) - .cast(pl.Utf8) - .alias(self.id_col) - ) - + kept = [ + self.cense_colname, + self.cense_eligible_colname, + self.compevent_colname, + *self.weight_eligible_colnames, + *self.excused_colnames, + ] + self.data = self.data.with_columns( - pl.col(self.id_col) - .cast(pl.Utf8) - .alias(self.id_col) - ) - + [ + pl.when(pl.col(self.treatment_col).is_in(self.treatment_level)) + .then(self.eligible_col) + .otherwise(0) + .alias(self.eligible_col), + 
pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"), + pl.lit(False).alias("switch"), + ] + ).with_columns( + [ + pl.when(pl.col(self.time_col) == 0) + .then(pl.lit(False)) + .otherwise( + (pl.col("tx_lag").is_not_null()) + & (pl.col("tx_lag") != pl.col(self.treatment_col)) + ) + .cast(pl.Int8) + .alias("switch") + ] + ) + + self.DT = _binder( + self, + kept_cols=_col_string( + [ + self.covariates, + self.numerator, + self.denominator, + self.cense_numerator, + self.cense_denominator, + ] + ).union(kept), + ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)) + + self.data = self.data.with_columns( + pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col) + ) + if self.method != "ITT": _dynamic(self) if self.selection_random: _random_selection(self) _diagnostics(self) - + end = time.perf_counter() self._expansion_time = _format_time(start, end) - + def bootstrap(self, **kwargs): - allowed = {"bootstrap_nboot", "bootstrap_sample", - "bootstrap_CI", "bootstrap_method"} + allowed = { + "bootstrap_nboot", + "bootstrap_sample", + "bootstrap_CI", + "bootstrap_method", + } for key, value in kwargs.items(): if key in allowed: setattr(self, key, value) else: raise ValueError(f"Unknown argument: {key}") - + UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list() NIDs = len(UIDs) - + self._boot_samples = [] for _ in range(self.bootstrap_nboot): - sampled_IDs = self._rng.choice(UIDs, size=int(self.bootstrap_sample * NIDs), replace=True) + sampled_IDs = self._rng.choice( + UIDs, size=int(self.bootstrap_sample * NIDs), replace=True + ) id_counts = Counter(sampled_IDs) self._boot_samples.append(id_counts) return self - - @bootstrap_loop + + @bootstrap_loop def fit(self): if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"): - raise ValueError("Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping.") - + raise ValueError( + "Bootstrap sampling not found. 
Please run the 'bootstrap' method before fitting with bootstrapping." + ) + if self.weighted: WDT = _weight_setup(self) if not self.weight_preexpansion and not self.excused: WDT = WDT.filter(pl.col("followup") > 0) - + WDT = WDT.to_pandas() for col in self.fixed_cols: if col in WDT.columns: WDT[col] = WDT[col].astype("category") - + _fit_LTFU(self, WDT) _fit_numerator(self, WDT) _fit_denominator(self, WDT) - + WDT = pl.from_pandas(WDT) WDT = _weight_predict(self, WDT) _weight_bind(self, WDT) self.weight_stats = _weight_stats(self) - + if self.subgroup_colname is not None: return _subgroup_fit(self) - - models = {'outcome': _outcome_fit(self, self.DT, - self.outcome_col, - self.covariates, - self.weighted, - "weight")} + + models = { + "outcome": _outcome_fit( + self, + self.DT, + self.outcome_col, + self.covariates, + self.weighted, + "weight", + ) + } if self.compevent_colname is not None: - models['compevent'] = _outcome_fit(self, self.DT, - self.compevent_colname, - self.covariates, - self.weighted, - "weight") + models["compevent"] = _outcome_fit( + self, + self.DT, + self.compevent_colname, + self.covariates, + self.weighted, + "weight", + ) return models - + def survival(self): if not hasattr(self, "outcome_model") or not self.outcome_model: - raise ValueError("Outcome model not found. Please run the 'fit' method before calculating survival.") - + raise ValueError( + "Outcome model not found. Please run the 'fit' method before calculating survival." + ) + start = time.perf_counter() - + risk_data = _pred_risk(self) surv_data = _calculate_survival(self, risk_data) self.km_data = pl.concat([risk_data, surv_data]) self.risk_estimates = _risk_estimates(self) - + end = time.perf_counter() self._survival_time = _format_time(start, end) - + def hazard(self): start = time.perf_counter() - + if not hasattr(self, "outcome_model") or not self.outcome_model: - raise ValueError("Outcome model not found. 
Please run the 'fit' method before calculating hazard ratio.") + raise ValueError( + "Outcome model not found. Please run the 'fit' method before calculating hazard ratio." + ) self.hazard_ratio = _calculate_hazard(self) - + end = time.perf_counter() self._hazard_time = _format_time(start, end) def plot(self): self.km_graph = _survival_plot(self) - + def collect(self): self._time_collected = datetime.datetime.now() - + generated = [ - "numerator_model", "denominator_model", + "numerator_model", + "denominator_model", "outcome_model", - "hazard_ratio", "risk_estimates", - "km_data", "km_graph", "diagnostics", - "_survival_time", "_hazard_time", - "_model_time", "_expansion_time", - "weight_stats" + "hazard_ratio", + "risk_estimates", + "km_data", + "km_graph", + "diagnostics", + "_survival_time", + "_hazard_time", + "_model_time", + "_expansion_time", + "weight_stats", ] for attr in generated: if not hasattr(self, attr): setattr(self, attr, None) - + # Options ========================== base = SEQopts() - + for name, value in vars(self).items(): if name in asdict(base).keys(): setattr(base, name, value) - - # Timing ========================= - time = {"start_time": self._time_initialized, - "expansion_time": self._expansion_time, - "model_time": self._model_time, - "survival_time": self._survival_time, - "hazard_time": self._hazard_time, - "collection_time": self._time_collected} - + + # Timing ========================= + time = { + "start_time": self._time_initialized, + "expansion_time": self._expansion_time, + "model_time": self._model_time, + "survival_time": self._survival_time, + "hazard_time": self._hazard_time, + "collection_time": self._time_collected, + } + if self.compevent_colname is not None: compevent_models = [model["compevent"] for model in self.outcome_models] else: compevent_models = None - + if self.outcome_model is not None: outcome_models = [model["outcome"] for model in self.outcome_model] - + if self.risk_estimates is None: risk_ratio = 
risk_difference = None else: risk_ratio = self.risk_estimates["risk_ratio"] risk_difference = self.risk_estimates["risk_difference"] - + output = SEQoutput( options=base, method=self.method, @@ -269,8 +325,7 @@ def collect(self): risk_ratio=risk_ratio, risk_difference=risk_difference, time=time, - diagnostic_tables=self.diagnostics + diagnostic_tables=self.diagnostics, ) - + return output - diff --git a/pySEQTarget/__init__.py b/pySEQTarget/__init__.py index 4167c03..540b3eb 100644 --- a/pySEQTarget/__init__.py +++ b/pySEQTarget/__init__.py @@ -2,7 +2,4 @@ from .SEQopts import SEQopts from .SEQoutput import SEQoutput -__all__ = [ - "SEQuential", - "SEQopts" -] \ No newline at end of file +__all__ = ["SEQuential", "SEQopts"] diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index ec4ca80..195120e 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -3,18 +3,19 @@ from lifelines import CoxPHFitter import warnings + def _calculate_hazard(self): if self.subgroup_colname is None: return _calculate_hazard_single(self, self.DT, idx=None, val=None) - + all_hazards = [] original_DT = self.DT - + for i, val in enumerate(self._unique_subgroups): subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val) hazard = _calculate_hazard_single(self, subgroup_DT, i, val) all_hazards.append(hazard) - + self.DT = original_DT return pl.concat(all_hazards) @@ -24,30 +25,31 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if full_hr is None or np.isnan(full_hr): return _create_hazard_output(None, None, None, val, self) - + if self.bootstrap_nboot > 0: boot_hrs = [] - + for boot_idx in range(len(self._boot_samples)): id_counts = self._boot_samples[boot_idx] - + boot_data_list = [] for id_val, count in id_counts.items(): id_data = data.filter(pl.col(self.id_col) == id_val) for _ in range(count): boot_data_list.append(id_data) - + boot_data = pl.concat(boot_data_list) - + boot_hr = _hazard_handler(self, 
boot_data, idx, boot_idx + 1, self._rng) if boot_hr is not None and not np.isnan(boot_hr): boot_hrs.append(boot_hr) - + if len(boot_hrs) == 0: return _create_hazard_output(full_hr, None, None, val, self) - + if self.bootstrap_CI_method == "se": from scipy.stats import norm + z = norm.ppf(1 - (1 - self.bootstrap_CI) / 2) se = np.std(boot_hrs) lci = full_hr - z * se @@ -57,102 +59,132 @@ def _calculate_hazard_single(self, data, idx=None, val=None): uci = np.quantile(boot_hrs, 1 - (1 - self.bootstrap_CI) / 2) else: lci, uci = None, None - + return _create_hazard_output(full_hr, lci, uci, val, self) def _hazard_handler(self, data, idx, boot_idx, rng): - exclude_cols = ["followup", f"followup{self.indicator_squared}", self.treatment_col, - f"{self.treatment_col}{self.indicator_baseline}", "period", self.outcome_col] + exclude_cols = [ + "followup", + f"followup{self.indicator_squared}", + self.treatment_col, + f"{self.treatment_col}{self.indicator_baseline}", + "period", + self.outcome_col, + ] if self.compevent_colname: exclude_cols.append(self.compevent_colname) keep_cols = [col for col in data.columns if col not in exclude_cols] - + trials = ( data.select(keep_cols) - .group_by([self.id_col, "trial"]).first() + .group_by([self.id_col, "trial"]) + .first() .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")]) .explode("followup") - .with_columns([(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]) + .with_columns( + [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")] + ) ) - + if idx is not None: model_dict = self.outcome_model[boot_idx][idx] else: model_dict = self.outcome_model[boot_idx] - - outcome_model = model_dict['outcome'] - ce_model = model_dict.get('compevent', None) if self.compevent_colname else None - + + outcome_model = model_dict["outcome"] + ce_model = model_dict.get("compevent", None) if self.compevent_colname else None + all_treatments = [] for val in self.treatment_level: - tmp = 
trials.with_columns([ - pl.lit(val) - .alias(f"{self.treatment_col}{self.indicator_baseline}") - ]) - + tmp = trials.with_columns( + [pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")] + ) + tmp_pd = tmp.to_pandas() outcome_prob = outcome_model.predict(tmp_pd) outcome_sim = rng.binomial(1, outcome_prob) - + tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)]) - + if ce_model is not None: ce_prob = ce_model.predict(tmp_pd) ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) - - tmp = tmp.with_columns([ - pl.when((pl.col("outcome") == 1) | - (pl.col("ce") == 1)) - .then(1) - .otherwise(0) - .alias("any_event") - ]).with_columns([ - pl.col("any_event") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ]).filter(pl.col("event_cumsum") <= 1) + + tmp = ( + tmp.with_columns( + [ + pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1)) + .then(1) + .otherwise(0) + .alias("any_event") + ] + ) + .with_columns( + [ + pl.col("any_event") + .cum_sum() + .over([self.id_col, "trial"]) + .alias("event_cumsum") + ] + ) + .filter(pl.col("event_cumsum") <= 1) + ) else: - tmp = tmp.with_columns([ - pl.col("outcome") - .cum_sum() - .over([self.id_col, "trial"]) - .alias("event_cumsum") - ]).filter(pl.col("event_cumsum") <= 1) - + tmp = tmp.with_columns( + [ + pl.col("outcome") + .cum_sum() + .over([self.id_col, "trial"]) + .alias("event_cumsum") + ] + ).filter(pl.col("event_cumsum") <= 1) + tmp = tmp.group_by([self.id_col, "trial"]).last() all_treatments.append(tmp) - + sim_data = pl.concat(all_treatments) - + if ce_model is not None: - sim_data = sim_data.with_columns([ - pl.when(pl.col("outcome") == 1).then(pl.lit(1)) - .when(pl.col("ce") == 1).then(pl.lit(2)) - .otherwise(pl.lit(0)).alias("event") - ]) + sim_data = sim_data.with_columns( + [ + pl.when(pl.col("outcome") == 1) + .then(pl.lit(1)) + .when(pl.col("ce") == 1) + .then(pl.lit(2)) + .otherwise(pl.lit(0)) + .alias("event") + ] + ) else: sim_data = 
sim_data.with_columns([pl.col("outcome").alias("event")]) - + sim_data_pd = sim_data.to_pandas() - + try: - #COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() - warnings.filterwarnings('ignore', message='.*datetime.datetime.utcnow.*') + # COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() + warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*") if ce_model is not None: - cox_data = sim_data_pd[sim_data_pd['event'].isin([0, 1])].copy() - cox_data['event_binary'] = (cox_data['event'] == 1).astype(int) - + cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy() + cox_data["event_binary"] = (cox_data["event"] == 1).astype(int) + cph = CoxPHFitter() - cph.fit(cox_data, duration_col='followup', event_col='event_binary', - formula=f"`{self.treatment_col}{self.indicator_baseline}`") + cph.fit( + cox_data, + duration_col="followup", + event_col="event_binary", + formula=f"`{self.treatment_col}{self.indicator_baseline}`", + ) else: cph = CoxPHFitter() - cph.fit(sim_data_pd, duration_col='followup', event_col='event', - formula=f"`{self.treatment_col}{self.indicator_baseline}`") - + cph.fit( + sim_data_pd, + duration_col="followup", + event_col="event", + formula=f"`{self.treatment_col}{self.indicator_baseline}`", + ) + hr = np.exp(cph.params_.values[0]) return hr except Exception as e: @@ -162,17 +194,17 @@ def _hazard_handler(self, data, idx, boot_idx, rng): def _create_hazard_output(hr, lci, uci, val, self): if lci is not None and uci is not None: - output = pl.DataFrame({ - "Hazard": [hr if hr is not None else float('nan')], - "LCI": [lci], - "UCI": [uci] - }) + output = pl.DataFrame( + { + "Hazard": [hr if hr is not None else float("nan")], + "LCI": [lci], + "UCI": [uci], + } + ) else: - output = pl.DataFrame({ - "Hazard": [hr if hr is not None else float('nan')] - }) - + output = pl.DataFrame({"Hazard": [hr if hr is not None else float("nan")]}) + if val is not None: output = 
output.with_columns(pl.lit(val).alias(self.subgroup_colname)) - + return output diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index f357680..451dc43 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -3,66 +3,69 @@ import polars as pl import re + def _outcome_fit( self, df: pl.DataFrame, outcome: str, formula: str, weighted: bool = False, - weight_col: str = "weight"): + weight_col: str = "weight", +): if weighted: df = df.with_columns( pl.col(weight_col).clip( - lower_bound=self.weight_min, - upper_bound=self.weight_max + lower_bound=self.weight_min, upper_bound=self.weight_max ) ) - + df_pd = df.to_pandas() - + df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") tx_bas = f"{self.treatment_col}{self.indicator_baseline}" df_pd[tx_bas] = df_pd[tx_bas].astype("category") - + if self.followup_class and not self.followup_spline: df_pd["followup"] = df_pd["followup"].astype("category") squared_col = f"followup{self.indicator_squared}" if squared_col in df_pd.columns: df_pd[squared_col] = df_pd[squared_col].astype("category") - + if self.followup_spline: spline = f"cr(followup, df=3)" - + formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula) formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula) - formula = re.sub(rf"\bfollowup{re.escape(self.indicator_squared)}\b", "", formula) + formula = re.sub( + rf"\bfollowup{re.escape(self.indicator_squared)}\b", "", formula + ) formula = re.sub(r"\bfollowup\b", "", formula) - + formula = re.sub(r"\s+", " ", formula) formula = re.sub(r"\+\s*\+", "+", formula) formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip() - + if formula: formula = f"{formula} + I({spline}**2)" else: formula = f"I({spline}**2)" - + if self.fixed_cols: for col in self.fixed_cols: if col in df_pd.columns: df_pd[col] = df_pd[col].astype("category") - + full_formula = f"{outcome} ~ {formula}" - + glm_kwargs = { 
"formula": full_formula, "data": df_pd, - "family": sm.families.Binomial() + "family": sm.families.Binomial(), } - + if weighted: glm_kwargs["var_weights"] = df_pd[weight_col] - + model = smf.glm(**glm_kwargs) model_fit = model.fit() return model_fit diff --git a/pySEQTarget/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py index 2303fa7..32c0dc4 100644 --- a/pySEQTarget/analysis/_risk_estimates.py +++ b/pySEQTarget/analysis/_risk_estimates.py @@ -1,107 +1,136 @@ import polars as pl from scipy import stats + def _risk_estimates(self): - last_followup = self.km_data['followup'].max() + last_followup = self.km_data["followup"].max() risk = self.km_data.filter( - (pl.col('followup') == last_followup) & - (pl.col('estimate') == 'risk') + (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk") ) - + group_cols = [self.subgroup_colname] if self.subgroup_colname else [] rd_comparisons = [] rr_comparisons = [] - + if self.bootstrap_nboot > 0: alpha = 1 - self.bootstrap_CI z = stats.norm.ppf(1 - alpha / 2) - + for tx_x in self.treatment_level: for tx_y in self.treatment_level: if tx_x == tx_y: continue - - risk_x = risk.filter(pl.col('tx_init') == tx_x).select( - group_cols + ['pred'] - ).rename({'pred': 'risk_x'}) - - risk_y = risk.filter(pl.col('tx_init') == tx_y).select( - group_cols + ['pred'] - ).rename({'pred': 'risk_y'}) - + + risk_x = ( + risk.filter(pl.col("tx_init") == tx_x) + .select(group_cols + ["pred"]) + .rename({"pred": "risk_x"}) + ) + + risk_y = ( + risk.filter(pl.col("tx_init") == tx_y) + .select(group_cols + ["pred"]) + .rename({"pred": "risk_y"}) + ) + if group_cols: - comp = risk_x.join(risk_y, on=group_cols, how='left') + comp = risk_x.join(risk_y, on=group_cols, how="left") else: - comp = risk_x.join(risk_y, how='cross') - - comp = comp.with_columns([ - pl.lit(tx_x).alias('A_x'), - pl.lit(tx_y).alias('A_y') - ]) - + comp = risk_x.join(risk_y, how="cross") + + comp = comp.with_columns( + [pl.lit(tx_x).alias("A_x"), 
pl.lit(tx_y).alias("A_y")] + ) + if self.bootstrap_nboot > 0: - se_x = risk.filter(pl.col('tx_init') == tx_x).select( - group_cols + ['SE'] - ).rename({'SE': 'se_x'}) - - se_y = risk.filter(pl.col('tx_init') == tx_y).select( - group_cols + ['SE'] - ).rename({'SE': 'se_y'}) - + se_x = ( + risk.filter(pl.col("tx_init") == tx_x) + .select(group_cols + ["SE"]) + .rename({"SE": "se_x"}) + ) + + se_y = ( + risk.filter(pl.col("tx_init") == tx_y) + .select(group_cols + ["SE"]) + .rename({"SE": "se_y"}) + ) + if group_cols: - comp = comp.join(se_x, on=group_cols, how='left') - comp = comp.join(se_y, on=group_cols, how='left') + comp = comp.join(se_x, on=group_cols, how="left") + comp = comp.join(se_y, on=group_cols, how="left") else: - comp = comp.join(se_x, how='cross') - comp = comp.join(se_y, how='cross') - - rd_se = (pl.col('se_x').pow(2) + pl.col('se_y').pow(2)).sqrt() - rd_comp = comp.with_columns([ - (pl.col('risk_x') - pl.col('risk_y')).alias('Risk Difference'), - (pl.col('risk_x') - pl.col('risk_y') - z * rd_se).alias('RD 95% LCI'), - (pl.col('risk_x') - pl.col('risk_y') + z * rd_se).alias('RD 95% UCI') - ]) - rd_comp = rd_comp.drop(['risk_x', 'risk_y', 'se_x', 'se_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Difference', 'RD 95% LCI', 'RD 95% UCI'] + comp = comp.join(se_x, how="cross") + comp = comp.join(se_y, how="cross") + + rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt() + rd_comp = comp.with_columns( + [ + (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"), + (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias( + "RD 95% LCI" + ), + (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias( + "RD 95% UCI" + ), + ] + ) + rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"]) + col_order = group_cols + [ + "A_x", + "A_y", + "Risk Difference", + "RD 95% LCI", + "RD 95% UCI", + ] rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns]) rd_comparisons.append(rd_comp) - + rr_log_se = ( - (pl.col('se_x') / 
pl.col('risk_x')).pow(2) + - (pl.col('se_y') / pl.col('risk_y')).pow(2) + (pl.col("se_x") / pl.col("risk_x")).pow(2) + + (pl.col("se_y") / pl.col("risk_y")).pow(2) ).sqrt() - rr_comp = comp.with_columns([ - (pl.col('risk_x') / pl.col('risk_y')).alias('Risk Ratio'), - ((pl.col('risk_x') / pl.col('risk_y')) * (-z * rr_log_se).exp()).alias('RR 95% LCI'), - ((pl.col('risk_x') / pl.col('risk_y')) * (z * rr_log_se).exp()).alias('RR 95% UCI') - ]) - rr_comp = rr_comp.drop(['risk_x', 'risk_y', 'se_x', 'se_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Ratio', 'RR 95% LCI', 'RR 95% UCI'] + rr_comp = comp.with_columns( + [ + (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"), + ( + (pl.col("risk_x") / pl.col("risk_y")) + * (-z * rr_log_se).exp() + ).alias("RR 95% LCI"), + ( + (pl.col("risk_x") / pl.col("risk_y")) + * (z * rr_log_se).exp() + ).alias("RR 95% UCI"), + ] + ) + rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"]) + col_order = group_cols + [ + "A_x", + "A_y", + "Risk Ratio", + "RR 95% LCI", + "RR 95% UCI", + ] rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns]) rr_comparisons.append(rr_comp) - + else: rd_comp = comp.with_columns( - (pl.col('risk_x') - pl.col('risk_y')).alias('Risk Difference') + (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference") ) - rd_comp = rd_comp.drop(['risk_x', 'risk_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Difference'] + rd_comp = rd_comp.drop(["risk_x", "risk_y"]) + col_order = group_cols + ["A_x", "A_y", "Risk Difference"] rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns]) rd_comparisons.append(rd_comp) - + rr_comp = comp.with_columns( - (pl.col('risk_x') / pl.col('risk_y')).alias('Risk Ratio') + (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio") ) - rr_comp = rr_comp.drop(['risk_x', 'risk_y']) - col_order = group_cols + ['A_x', 'A_y', 'Risk Ratio'] + rr_comp = rr_comp.drop(["risk_x", "risk_y"]) + col_order = group_cols + ["A_x", "A_y", "Risk 
Ratio"] rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns]) rr_comparisons.append(rr_comp) - + risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame() risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame() - - return { - 'risk_difference': risk_difference, - 'risk_ratio': risk_ratio - } - \ No newline at end of file + + return {"risk_difference": risk_difference, "risk_ratio": risk_ratio} diff --git a/pySEQTarget/analysis/_subgroup_fit.py b/pySEQTarget/analysis/_subgroup_fit.py index 10ab448..b6a6c04 100644 --- a/pySEQTarget/analysis/_subgroup_fit.py +++ b/pySEQTarget/analysis/_subgroup_fit.py @@ -1,25 +1,29 @@ import polars as pl from ._outcome_fit import _outcome_fit + def _subgroup_fit(self): subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list()) self._unique_subgroups = subgroups - + models_list = [] for val in subgroups: subDT = self.DT.filter(pl.col(self.subgroup_colname) == val) - - models = {'outcome': _outcome_fit(self, subDT, - self.outcome_col, - self.covariates, - self.weighted, - "weight")} - + + models = { + "outcome": _outcome_fit( + self, subDT, self.outcome_col, self.covariates, self.weighted, "weight" + ) + } + if self.compevent_colname is not None: - models['compevent'] = _outcome_fit(self, subDT, - self.compevent_colname, - self.covariates, - self.weighted, - "weight") + models["compevent"] = _outcome_fit( + self, + subDT, + self.compevent_colname, + self.covariates, + self.weighted, + "weight", + ) models_list.append(models) return models_list diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index ce8b0d5..617fc2c 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -1,276 +1,363 @@ import polars as pl + def _get_outcome_predictions(self, TxDT, idx=None): data = TxDT.to_pandas() predictions = {"outcome": []} if self.compevent_colname is not None: predictions["compevent"] = 
[] - + for boot_model in self.outcome_model: model_dict = boot_model[idx] if idx is not None else boot_model predictions["outcome"].append(model_dict["outcome"].predict(data)) if self.compevent_colname is not None: predictions["compevent"].append(model_dict["compevent"].predict(data)) - + return predictions + def _pred_risk(self): - has_subgroups = (isinstance(self.outcome_model[0], list) if self.outcome_model else False) - + has_subgroups = ( + isinstance(self.outcome_model[0], list) if self.outcome_model else False + ) + if not has_subgroups: return _calculate_risk(self, self.DT, idx=None, val=None) - + all_risks = [] original_DT = self.DT - + for i, val in enumerate(self._unique_subgroups): subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val) risk = _calculate_risk(self, subgroup_DT, i, val) all_risks.append(risk) - + self.DT = original_DT return pl.concat(all_risks) + def _calculate_risk(self, data, idx=None, val=None): a = 1 - self.bootstrap_CI lci = a / 2 uci = 1 - lci - + SDT = ( - data - .with_columns([ - (pl.col(self.id_col) - .cast(pl.Utf8) + - pl.col("trial") - .cast(pl.Utf8)) - .alias("TID") - ]) - .group_by("TID").first() + data.with_columns( + [ + ( + pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8) + ).alias("TID") + ] + ) + .group_by("TID") + .first() .drop(["followup", f"followup{self.indicator_squared}"]) .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")]) .explode("followup") - .with_columns([ - (pl.col("followup") + 1).alias("followup"), - (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}") - ]) + .with_columns( + [ + (pl.col("followup") + 1).alias("followup"), + (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"), + ] + ) ).sort([self.id_col, "trial", "followup"]) - + risks = [] for treatment_val in self.treatment_level: - TxDT = SDT.with_columns([ - pl.lit(treatment_val) - .alias(f"{self.treatment_col}{self.indicator_baseline}") - ]) - + TxDT = 
SDT.with_columns( + [ + pl.lit(treatment_val).alias( + f"{self.treatment_col}{self.indicator_baseline}" + ) + ] + ) + if self.method == "dose-response": if treatment_val == self.treatment_level[0]: - TxDT = TxDT.with_columns([ - pl.lit(0.0).alias("dose"), - pl.lit(0.0).alias("dose_sq") - ]) + TxDT = TxDT.with_columns( + [pl.lit(0.0).alias("dose"), pl.lit(0.0).alias("dose_sq")] + ) else: - TxDT = TxDT.with_columns([ - pl.col("followup").alias("dose"), - pl.col(f"followup{self.indicator_squared}").alias("dose_sq") - ]) - + TxDT = TxDT.with_columns( + [ + pl.col("followup").alias("dose"), + pl.col(f"followup{self.indicator_squared}").alias("dose_sq"), + ] + ) + preds = _get_outcome_predictions(self, TxDT, idx=idx) pred_series = [pl.Series("pred_outcome", preds["outcome"][0])] - + if self.bootstrap_nboot > 0: for boot_idx, pred in enumerate(preds["outcome"][1:], start=1): pred_series.append(pl.Series(f"pred_outcome_{boot_idx}", pred)) - + if self.compevent_colname is not None: pred_series.append(pl.Series("pred_ce", preds["compevent"][0])) if self.bootstrap_nboot > 0: for boot_idx, pred in enumerate(preds["compevent"][1:], start=1): pred_series.append(pl.Series(f"pred_ce_{boot_idx}", pred)) - + outcome_names = [s.name for s in pred_series if "outcome" in s.name] ce_names = [s.name for s in pred_series if "ce" in s.name] - + TxDT = TxDT.with_columns(pred_series) - + if self.compevent_colname is not None: for out_col, ce_col in zip(outcome_names, ce_names): surv_col = out_col.replace("pred_outcome", "surv") cce_col = out_col.replace("pred_outcome", "cce") inc_col = out_col.replace("pred_outcome", "inc") - - TxDT = ( - TxDT.with_columns([ - (1 - pl.col(out_col)) + + TxDT = TxDT.with_columns( + [ + (1 - pl.col(out_col)).cum_prod().over("TID").alias(surv_col), + ((1 - pl.col(out_col)) * (1 - pl.col(ce_col))) .cum_prod() .over("TID") - .alias(surv_col), - ((1 - pl.col(out_col)) * (1 - pl.col(ce_col))) - .cum_prod().over("TID") - .alias(cce_col) - ]) - .with_columns([ + 
.alias(cce_col), + ] + ).with_columns( + [ (pl.col(out_col) * (1 - pl.col(ce_col)) * pl.col(cce_col)) .cum_sum() .over("TID") .alias(inc_col) - ]) + ] ) - + surv_names = [n.replace("pred_outcome", "surv") for n in outcome_names] inc_names = [n.replace("pred_outcome", "inc") for n in outcome_names] - TxDT = TxDT.group_by("followup").agg([pl.col(col).mean() for col in surv_names + inc_names]).sort("followup") + TxDT = ( + TxDT.group_by("followup") + .agg([pl.col(col).mean() for col in surv_names + inc_names]) + .sort("followup") + ) main_col = "surv" boot_cols = [col for col in surv_names if col != "surv"] else: TxDT = ( - TxDT.with_columns([(1 - pl.col(col)) - .cum_prod() - .over("TID") - .alias(col) for col in outcome_names - ]) - .group_by("followup").agg([pl.col(col).mean() for col in outcome_names]) + TxDT.with_columns( + [ + (1 - pl.col(col)).cum_prod().over("TID").alias(col) + for col in outcome_names + ] + ) + .group_by("followup") + .agg([pl.col(col).mean() for col in outcome_names]) .sort("followup") .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names]) ) main_col = "pred_outcome" boot_cols = [col for col in outcome_names if col != "pred_outcome"] - + if boot_cols: risk = ( TxDT.select(["followup"] + boot_cols) - .unpivot(index="followup", on=boot_cols, variable_name="bootID", value_name="risk") - .group_by("followup").agg([ - pl.col("risk").std().alias("SE"), - pl.col("risk").quantile(lci).alias("LCI"), - pl.col("risk").quantile(uci).alias("UCI") - ]) + .unpivot( + index="followup", + on=boot_cols, + variable_name="bootID", + value_name="risk", + ) + .group_by("followup") + .agg( + [ + pl.col("risk").std().alias("SE"), + pl.col("risk").quantile(lci).alias("LCI"), + pl.col("risk").quantile(uci).alias("UCI"), + ] + ) .join(TxDT.select(["followup", main_col]), on="followup") ) - + if self.bootstrap_CI_method == "se": from scipy.stats import norm + z = norm.ppf(1 - a / 2) - risk = risk.with_columns([ - (pl.col(main_col) - z * 
pl.col("SE")).alias("LCI"), - (pl.col(main_col) + z * pl.col("SE")).alias("UCI") - ]) - + risk = risk.with_columns( + [ + (pl.col(main_col) - z * pl.col("SE")).alias("LCI"), + (pl.col(main_col) + z * pl.col("SE")).alias("UCI"), + ] + ) + fup0_val = 1.0 if self.compevent_colname else 0.0 - + if self.compevent_colname is not None: inc_boot_cols = [col for col in inc_names if col != "inc"] if inc_boot_cols: inc_risk = ( TxDT.select(["followup"] + inc_boot_cols) - .unpivot(index="followup", on=inc_boot_cols, variable_name="bootID", value_name="inc_val") - .group_by("followup").agg([ - pl.col("inc_val").std().alias("inc_SE"), - pl.col("inc_val").quantile(lci).alias("inc_LCI"), - pl.col("inc_val").quantile(uci).alias("inc_UCI") - ]) + .unpivot( + index="followup", + on=inc_boot_cols, + variable_name="bootID", + value_name="inc_val", + ) + .group_by("followup") + .agg( + [ + pl.col("inc_val").std().alias("inc_SE"), + pl.col("inc_val").quantile(lci).alias("inc_LCI"), + pl.col("inc_val").quantile(uci).alias("inc_UCI"), + ] + ) .join(TxDT.select(["followup", "inc"]), on="followup") ) risk = risk.join(inc_risk, on="followup") - final_cols = ["followup", main_col, "SE", "LCI", "UCI", "inc", "inc_SE", "inc_LCI", "inc_UCI"] - risk = risk.select(final_cols).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - "SE": [0.0], - "LCI": [fup0_val], - "UCI": [fup0_val], - "inc": [0.0], - "inc_SE": [0.0], - "inc_LCI": [0.0], - "inc_UCI": [0.0], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) + final_cols = [ + "followup", + main_col, + "SE", + "LCI", + "UCI", + "inc", + "inc_SE", + "inc_LCI", + "inc_UCI", + ] + risk = risk.select(final_cols).with_columns( + pl.lit(treatment_val).alias(self.treatment_col) + ) + + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + "SE": [0.0], + "LCI": [fup0_val], 
+ "UCI": [fup0_val], + "inc": [0.0], + "inc_SE": [0.0], + "inc_LCI": [0.0], + "inc_UCI": [0.0], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) else: - risk = risk.select(["followup", main_col, "SE", "LCI", "UCI"]).with_columns( - pl.lit(treatment_val) - .alias(self.treatment_col) - ) - fup0 = pl.DataFrame({ + risk = risk.select( + ["followup", main_col, "SE", "LCI", "UCI"] + ).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + "SE": [0.0], + "LCI": [fup0_val], + "UCI": [fup0_val], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + else: + risk = risk.select( + ["followup", main_col, "SE", "LCI", "UCI"] + ).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) + fup0 = pl.DataFrame( + { "followup": [0], main_col: [fup0_val], "SE": [0.0], "LCI": [fup0_val], "UCI": [fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) - else: - risk = risk.select(["followup", main_col, "SE", "LCI", "UCI"]).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - fup0 = pl.DataFrame({ - "followup": [0], - main_col: [fup0_val], - "SE": [0.0], - "LCI": [fup0_val], - "UCI": [fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32) - ]) + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) else: fup0_val = 1.0 if self.compevent_colname else 0.0 - risk = TxDT.select(["followup", main_col]).with_columns(pl.lit(treatment_val).alias(self.treatment_col)) - fup0 = pl.DataFrame({ - "followup": [0], - 
main_col: [fup0_val], - self.treatment_col: [treatment_val] - }).with_columns([ - pl.col("followup").cast(pl.Int64), - pl.col(self.treatment_col).cast(pl.Int32)]) - + risk = TxDT.select(["followup", main_col]).with_columns( + pl.lit(treatment_val).alias(self.treatment_col) + ) + fup0 = pl.DataFrame( + { + "followup": [0], + main_col: [fup0_val], + self.treatment_col: [treatment_val], + } + ).with_columns( + [ + pl.col("followup").cast(pl.Int64), + pl.col(self.treatment_col).cast(pl.Int32), + ] + ) + if self.compevent_colname is not None: risk = risk.join(TxDT.select(["followup", "inc"]), on="followup") fup0 = fup0.with_columns([pl.lit(0.0).alias("inc")]) - + risks.append(pl.concat([fup0, risk])) out = pl.concat(risks) - + if self.compevent_colname is not None: has_ci = "SE" in out.columns - + surv_cols = ["followup", self.treatment_col, "surv"] if has_ci: surv_cols.extend(["SE", "LCI", "UCI"]) - surv_out = out.select(surv_cols).rename({"surv": "pred"}).with_columns(pl.lit("survival").alias("estimate")) - + surv_out = ( + out.select(surv_cols) + .rename({"surv": "pred"}) + .with_columns(pl.lit("survival").alias("estimate")) + ) + risk_cols = ["followup", self.treatment_col, (1 - pl.col("surv")).alias("pred")] if has_ci: - risk_cols.extend([pl.col("SE"), (1 - pl.col("UCI")).alias("LCI"), (1 - pl.col("LCI")).alias("UCI")]) + risk_cols.extend( + [ + pl.col("SE"), + (1 - pl.col("UCI")).alias("LCI"), + (1 - pl.col("LCI")).alias("UCI"), + ] + ) risk_out = out.select(risk_cols).with_columns(pl.lit("risk").alias("estimate")) - + inc_cols = ["followup", self.treatment_col, pl.col("inc").alias("pred")] if has_ci: - inc_cols.extend([pl.col("inc_SE").alias("SE"), pl.col("inc_LCI").alias("LCI"), pl.col("inc_UCI").alias("UCI")]) - inc_out = out.select(inc_cols).with_columns(pl.lit("incidence").alias("estimate")) - + inc_cols.extend( + [ + pl.col("inc_SE").alias("SE"), + pl.col("inc_LCI").alias("LCI"), + pl.col("inc_UCI").alias("UCI"), + ] + ) + inc_out = 
out.select(inc_cols).with_columns( + pl.lit("incidence").alias("estimate") + ) + out = pl.concat([surv_out, risk_out, inc_out]) else: - out = out.rename({"pred_outcome": "pred"}).with_columns(pl.lit("risk").alias("estimate")) - + out = out.rename({"pred_outcome": "pred"}).with_columns( + pl.lit("risk").alias("estimate") + ) + if val is not None: out = out.with_columns(pl.lit(val).alias(self.subgroup_colname)) - + return out + def _calculate_survival(self, risk_data): if self.bootstrap_nboot > 0: - surv = risk_data.with_columns([ - (1 - pl.col(col)).alias(col) for col in ["pred", "LCI", "UCI"] - ]).with_columns(pl.lit("survival").alias("estimate")) - else: - surv = risk_data.with_columns([ - (1 - pl.col("pred")).alias("pred"), - pl.lit("survival").alias("estimate") - ]) + surv = risk_data.with_columns( + [(1 - pl.col(col)).alias(col) for col in ["pred", "LCI", "UCI"]] + ).with_columns(pl.lit("survival").alias("estimate")) + else: + surv = risk_data.with_columns( + [(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")] + ) return surv diff --git a/pySEQTarget/data/__init__.py b/pySEQTarget/data/__init__.py index 50e5493..37c47d3 100644 --- a/pySEQTarget/data/__init__.py +++ b/pySEQTarget/data/__init__.py @@ -1,6 +1,7 @@ from importlib.resources import files import polars as pl + def load_data(name: str = "SEQdata") -> pl.DataFrame: loc = files("pySEQ.data") if name in ["SEQdata", "SEQdata_multitreatment", "SEQdata_LTFU"]: @@ -12,4 +13,6 @@ def load_data(name: str = "SEQdata") -> pl.DataFrame: data_path = loc.joinpath("SEQdata_LTFU.csv") return pl.read_csv(data_path) else: - raise ValueError(f"Dataset '{name}' not available. Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']") \ No newline at end of file + raise ValueError( + f"Dataset '{name}' not available. 
Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']" + ) diff --git a/pySEQTarget/docs/source/conf.py b/pySEQTarget/docs/source/conf.py index 994268a..6b536b6 100644 --- a/pySEQTarget/docs/source/conf.py +++ b/pySEQTarget/docs/source/conf.py @@ -6,23 +6,22 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'SEQuential' +project = "SEQuential" copyright = "2024, Ryan O'Dea, Alejandro Szmulewicz" author = "Ryan O'Dea, Alejandro Szmulewicz" -release = '0.1.0' +release = "0.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'alabaster' -html_static_path = ['_static'] +html_theme = "alabaster" +html_static_path = ["_static"] diff --git a/pySEQTarget/error/__init__.py b/pySEQTarget/error/__init__.py index c9ee5ee..3aa4a33 100644 --- a/pySEQTarget/error/__init__.py +++ b/pySEQTarget/error/__init__.py @@ -1,2 +1,2 @@ from ._param_checker import _param_checker -from ._datachecker import _datachecker \ No newline at end of file +from ._datachecker import _datachecker diff --git a/pySEQTarget/error/_datachecker.py b/pySEQTarget/error/_datachecker.py index da8277f..054c581 100644 --- a/pySEQTarget/error/_datachecker.py +++ b/pySEQTarget/error/_datachecker.py @@ -1,11 +1,11 @@ import polars as pl + def _datachecker(self): - check = self.data.group_by(self.id_col).agg([ - pl.len().alias("row_count"), - pl.col(self.time_col).max().alias("max_time") - ]) - + check = self.data.group_by(self.id_col).agg( + [pl.len().alias("row_count"), 
pl.col(self.time_col).max().alias("max_time")] + ) + invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1) if len(invalid) > 0: raise ValueError( @@ -13,17 +13,26 @@ def _datachecker(self): f"This suggests invalid times" f"Invalid IDs:\n{invalid}" ) - + for col in self.excused_colnames: - violations = self.data.sort([self.id_col, self.time_col]).group_by(self.id_col).agg([ - ((pl.col(col).cum_sum().shift(1, fill_value=0) > 0) & (pl.col(col) == 0)) - .any() - .alias("has_violation") - ]).filter(pl.col("has_violation")) - + violations = ( + self.data.sort([self.id_col, self.time_col]) + .group_by(self.id_col) + .agg( + [ + ( + (pl.col(col).cum_sum().shift(1, fill_value=0) > 0) + & (pl.col(col) == 0) + ) + .any() + .alias("has_violation") + ] + ) + .filter(pl.col("has_violation")) + ) + if len(violations) > 0: raise ValueError( f"Column '{col}' violates 'once one, always one' rule for excusing treatment " f"{len(violations)} ID(s) have zeros after ones." ) - \ No newline at end of file diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index 2d90168..3a96448 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -1,32 +1,43 @@ from ..helpers import _pad + def _param_checker(self): - if self.subgroup_colname is not None and self.subgroup_colname not in self.fixed_cols: + if ( + self.subgroup_colname is not None + and self.subgroup_colname not in self.fixed_cols + ): raise ValueError("subgroup_colname must be included in fixed_cols.") - + if self.followup_max is None: - self.followup_max = self.data.select(self.time_col).to_series().max() - + self.followup_max = self.data.select(self.time_col).to_series().max() + if len(self.excused_colnames) == 0 and self.excused: self.excused = False - raise Warning("Excused column names not provided but excused is set to True. 
Automatically set excused to False") - + raise Warning( + "Excused column names not provided but excused is set to True. Automatically set excused to False" + ) + if len(self.excused_colnames) > 0 and not self.excused: self.excused = True - raise Warning("Excused column names provided but excused is set to False. Automatically set excused to True") - + raise Warning( + "Excused column names provided but excused is set to False. Automatically set excused to True" + ) + if self.km_curves and self.hazard_estimate: raise ValueError("km_curves and hazard cannot both be set to True.") - + if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1: - raise ValueError("Only one of followup_class or followup_include can be set to True.") - + raise ValueError( + "Only one of followup_class or followup_include can be set to True." + ) + if self.weighted and self.method == "ITT" and self.cense_colname is None: raise ValueError("For weighted ITT analyses, cense_colname must be provided.") - + if self.excused: _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames) - _, self.weight_eligible_colnames = _pad(self.treatment_level, self.weight_eligible_colnames) - + _, self.weight_eligible_colnames = _pad( + self.treatment_level, self.weight_eligible_colnames + ) + return - \ No newline at end of file diff --git a/pySEQTarget/expansion/__init__.py b/pySEQTarget/expansion/__init__.py index 60947bb..32598ec 100644 --- a/pySEQTarget/expansion/__init__.py +++ b/pySEQTarget/expansion/__init__.py @@ -2,4 +2,4 @@ from ._dynamic import _dynamic from ._mapper import _mapper from ._selection import _random_selection -from ._diagnostics import _diagnostics \ No newline at end of file +from ._diagnostics import _diagnostics diff --git a/pySEQTarget/expansion/_binder.py b/pySEQTarget/expansion/_binder.py index 389ebac..c459002 100644 --- a/pySEQTarget/expansion/_binder.py +++ b/pySEQTarget/expansion/_binder.py @@ -1,71 +1,97 @@ import polars as pl from 
._mapper import _mapper + def _binder(self, kept_cols): """ Internal function to bind data to the map created by __mapper """ - excluded = {"dose", - f"dose{self.indicator_squared}", - "followup", - f"followup{self.indicator_squared}", - "tx_lag", - "trial", - f"trial{self.indicator_squared}", - self.time_col, - f"{self.time_col}{self.indicator_squared}"} - + excluded = { + "dose", + f"dose{self.indicator_squared}", + "followup", + f"followup{self.indicator_squared}", + "tx_lag", + "trial", + f"trial{self.indicator_squared}", + self.time_col, + f"{self.time_col}{self.indicator_squared}", + } + cols = kept_cols.union({self.eligible_col, self.outcome_col, self.treatment_col}) cols = {col for col in cols if col is not None} - regular = {col for col in cols if not (self.indicator_baseline in col or self.indicator_squared in col) and col not in excluded} - - baseline = {col for col in cols if self.indicator_baseline in col and col not in excluded} + regular = { + col + for col in cols + if not (self.indicator_baseline in col or self.indicator_squared in col) + and col not in excluded + } + + baseline = { + col for col in cols if self.indicator_baseline in col and col not in excluded + } bas_kept = {col.replace(self.indicator_baseline, "") for col in baseline} - - squared = {col for col in cols if self.indicator_squared in col and col not in excluded} + + squared = { + col for col in cols if self.indicator_squared in col and col not in excluded + } sq_kept = {col.replace(self.indicator_squared, "") for col in squared} - + kept = list(regular.union(bas_kept).union(sq_kept)) - + if self.selection_first_trial: - DT = self.data.sort([self.id_col, self.time_col]) \ - .with_columns([ - pl.col(self.time_col).alias("period"), - pl.col(self.time_col).alias("followup"), - pl.lit(0).alias("trial") - ]).drop(self.time_col) + DT = ( + self.data.sort([self.id_col, self.time_col]) + .with_columns( + [ + pl.col(self.time_col).alias("period"), + pl.col(self.time_col).alias("followup"), + 
pl.lit(0).alias("trial"), + ] + ) + .drop(self.time_col) + ) else: - DT = _mapper(self.data, self.id_col, self.time_col, self.followup_min, self.followup_max) + DT = _mapper( + self.data, self.id_col, self.time_col, self.followup_min, self.followup_max + ) DT = DT.join( self.data.select([self.id_col, self.time_col] + kept), left_on=[self.id_col, "period"], right_on=[self.id_col, self.time_col], - how="left" - ) - DT = DT.sort([self.id_col, "trial", "followup"]) \ - .with_columns([ + how="left", + ) + DT = DT.sort([self.id_col, "trial", "followup"]).with_columns( + [ (pl.col("trial") ** 2).alias(f"trial{self.indicator_squared}"), - (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}") - ]) - + (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"), + ] + ) + if squared: squares = [] for sq in squared: col = sq.replace(self.indicator_squared, "") squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}")) DT = DT.with_columns(squares) - + baseline_cols = {bas.replace(self.indicator_baseline, "") for bas in baseline} needed = {self.eligible_col, self.treatment_col} baseline_cols.update({c for c in needed}) - + bas = [ - pl.col(c).first().over([self.id_col, "trial"]).alias(f"{c}{self.indicator_baseline}") + pl.col(c) + .first() + .over([self.id_col, "trial"]) + .alias(f"{c}{self.indicator_baseline}") for c in baseline_cols ] - - DT = DT.with_columns(bas).filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1) \ - .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col]) - - return DT \ No newline at end of file + + DT = ( + DT.with_columns(bas) + .filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1) + .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col]) + ) + + return DT diff --git a/pySEQTarget/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py index 726312f..178062a 100644 --- a/pySEQTarget/expansion/_diagnostics.py +++ 
b/pySEQTarget/expansion/_diagnostics.py @@ -1,56 +1,53 @@ import polars as pl + def _diagnostics(self): - unique_out = _outcome_diag(self, unique = True) - nonunique_out = _outcome_diag(self, unique = False) - out = {"unique_outcomes": unique_out, - "nonunique_outcomes": nonunique_out} - + unique_out = _outcome_diag(self, unique=True) + nonunique_out = _outcome_diag(self, unique=False) + out = {"unique_outcomes": unique_out, "nonunique_outcomes": nonunique_out} + if self.method == "censoring": - unique_switch = _switch_diag(self, unique = True) - nonunique_switch = _switch_diag(self, unique = False) - out.update({"unique_switches": unique_switch, - "nonunique_switches": nonunique_switch}) - + unique_switch = _switch_diag(self, unique=True) + nonunique_switch = _switch_diag(self, unique=False) + out.update( + {"unique_switches": unique_switch, "nonunique_switches": nonunique_switch} + ) + self.diagnostics = out - + + def _outcome_diag(self, unique): if unique: - data = self.DT.select([ - self.id_col, - self.treatment_col, - self.outcome_col - ]).group_by(self.id_col).last() - else: + data = ( + self.DT.select([self.id_col, self.treatment_col, self.outcome_col]) + .group_by(self.id_col) + .last() + ) + else: data = self.DT - out = data.group_by([self.treatment_col, - self.outcome_col]).len() - + out = data.group_by([self.treatment_col, self.outcome_col]).len() + return out + def _switch_diag(self, unique): if not self.excused: - data = self.DT.with_columns(pl.lit(False) - .alias("isExcused")) + data = self.DT.with_columns(pl.lit(False).alias("isExcused")) else: data = self.DT - + if unique: - data = data.select([ - self.id_col, - self.treatment_col, - "switch", - "isExcused" - ]).with_columns( - pl.when((pl.col("switch") == 0) & - (pl.col("isExcused"))) - .then(1) - .otherwise(pl.col("switch")) - .alias("switch") - ).group_by(self.id_col).last() - - out = data.group_by([self.treatment_col, - "isExcused", - "switch"]).len() + data = ( + data.select([self.id_col, 
self.treatment_col, "switch", "isExcused"]) + .with_columns( + pl.when((pl.col("switch") == 0) & (pl.col("isExcused"))) + .then(1) + .otherwise(pl.col("switch")) + .alias("switch") + ) + .group_by(self.id_col) + .last() + ) + + out = data.group_by([self.treatment_col, "isExcused", "switch"]).len() return out - \ No newline at end of file diff --git a/pySEQTarget/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py index 8c8b53d..d2a99c5 100644 --- a/pySEQTarget/expansion/_dynamic.py +++ b/pySEQTarget/expansion/_dynamic.py @@ -1,5 +1,6 @@ import polars as pl + def _dynamic(self): """ Handles special cases for the data from the __mapper -> __binder pipeline @@ -10,12 +11,9 @@ def _dynamic(self): .cum_sum() .over([self.id_col, "trial"]) .alias("dose") - ).with_columns([ - (pl.col("dose") ** 2) - .alias(f"dose{self.indicator_squared}") - ]) + ).with_columns([(pl.col("dose") ** 2).alias(f"dose{self.indicator_squared}")]) self.DT = DT - + elif self.method == "censoring": DT = self.DT.sort([self.id_col, "trial", "followup"]).with_columns( pl.col(self.treatment_col) @@ -23,13 +21,13 @@ def _dynamic(self): .over([self.id_col, "trial"]) .alias("tx_lag") ) - + switch = ( pl.when(pl.col("followup") == 0) .then(pl.lit(False)) .otherwise( - (pl.col("tx_lag").is_not_null()) & - (pl.col("tx_lag") != pl.col(self.treatment_col)) + (pl.col("tx_lag").is_not_null()) + & (pl.col("tx_lag") != pl.col(self.treatment_col)) ) ) is_excused = pl.lit(False) @@ -39,35 +37,25 @@ def _dynamic(self): colname = self.excused_colnames[i] if colname is not None: conditions.append( - (pl.col(colname) == 1) & - (pl.col(self.treatment_col) == self.treatment_level[i]) + (pl.col(colname) == 1) + & (pl.col(self.treatment_col) == self.treatment_level[i]) ) - + if conditions: excused = pl.any_horizontal(conditions) is_excused = switch & excused - switch = ( - pl.when(excused) - .then(pl.lit(False)) - .otherwise(switch) - ) - - DT = DT.with_columns([ - switch.alias("switch"), - 
is_excused.alias("isExcused") - ]).sort([self.id_col, "trial", "followup"]) \ + switch = pl.when(excused).then(pl.lit(False)).otherwise(switch) + + DT = ( + DT.with_columns([switch.alias("switch"), is_excused.alias("isExcused")]) + .sort([self.id_col, "trial", "followup"]) .filter( - ( - pl.col("switch") - .cum_max() - .shift(1, fill_value=False) + (pl.col("switch").cum_max().shift(1, fill_value=False)).over( + [self.id_col, "trial"] ) - .over([self.id_col, "trial"]) == 0 - ).with_columns( - pl.col("switch") - .cast(pl.Int8) - .alias("switch") + ) + .with_columns(pl.col("switch").cast(pl.Int8).alias("switch")) ) - - self.DT = DT.drop(["tx_lag"]) \ No newline at end of file + + self.DT = DT.drop(["tx_lag"]) diff --git a/pySEQTarget/expansion/_mapper.py b/pySEQTarget/expansion/_mapper.py index e2cfe80..b743a81 100644 --- a/pySEQTarget/expansion/_mapper.py +++ b/pySEQTarget/expansion/_mapper.py @@ -1,30 +1,43 @@ import polars as pl import math + def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf): """ Internal function to create the expanded map to bind data to. 
""" - + DT = ( - data.select([pl.col(id_col), pl.col(time_col)]) - .with_columns([ - pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial") - ]) - .with_columns([ - pl.struct([pl.col(time_col), pl.col(time_col).max().over(id_col).alias("max_time")]) - .map_elements(lambda x: list(range(x[time_col], x["max_time"] + 1)), - return_dtype=pl.List(pl.Int64)) - .alias("period") - ]) + data.select([pl.col(id_col), pl.col(time_col)]) + .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")]) + .with_columns( + [ + pl.struct( + [ + pl.col(time_col), + pl.col(time_col).max().over(id_col).alias("max_time"), + ] + ) + .map_elements( + lambda x: list(range(x[time_col], x["max_time"] + 1)), + return_dtype=pl.List(pl.Int64), + ) + .alias("period") + ] + ) .explode("period") .drop(pl.col(time_col)) - .with_columns([ - pl.col(id_col).cum_count().over([id_col, "trial"]).sub(1).alias("followup") - ]) + .with_columns( + [ + pl.col(id_col) + .cum_count() + .over([id_col, "trial"]) + .sub(1) + .alias("followup") + ] + ) .filter( - (pl.col("followup") >= min_followup) & - (pl.col("followup") <= max_followup) + (pl.col("followup") >= min_followup) & (pl.col("followup") <= max_followup) ) ) return DT diff --git a/pySEQTarget/expansion/_selection.py b/pySEQTarget/expansion/_selection.py index 52e39f5..9f58cd7 100644 --- a/pySEQTarget/expansion/_selection.py +++ b/pySEQTarget/expansion/_selection.py @@ -1,29 +1,31 @@ import polars as pl + + def _random_selection(self): """ - Handles the case where random selection is applied for data from + Handles the case where random selection is applied for data from the __mapper -> __binder -> optionally __dynamic pipeline """ - UIDs = self.DT.select([ - self.id_col, - "trial", - f"{self.treatment_col}{self.indicator_baseline}"]) \ - .with_columns( - (pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID")) \ - .filter( - pl.col(f"{self.treatment_col}{self.indicator_baseline}") == 0) \ - 
.unique("trialID").to_series().to_list() - + UIDs = ( + self.DT.select( + [self.id_col, "trial", f"{self.treatment_col}{self.indicator_baseline}"] + ) + .with_columns((pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID")) + .filter(pl.col(f"{self.treatment_col}{self.indicator_baseline}") == 0) + .unique("trialID") + .to_series() + .to_list() + ) + NIDs = len(UIDs) sample = self._rng.choice( - UIDs, - size=int(self.selection_probability * NIDs), - replace=False + UIDs, size=int(self.selection_probability * NIDs), replace=False + ) + + self.DT = ( + self.DT.with_columns( + (pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID") + ) + .filter(pl.col("trialID").is_in(sample)) + .drop("trialID") ) - - self.DT = self.DT.with_columns( - (pl.col(self.id_col) + "_" + pl.col("trial")).alias("trialID") - ).filter( - pl.col("trialID").is_in(sample) - ).drop("trialID") - \ No newline at end of file diff --git a/pySEQTarget/helpers/__init__.py b/pySEQTarget/helpers/__init__.py index 8f9a23b..4fe0cae 100644 --- a/pySEQTarget/helpers/__init__.py +++ b/pySEQTarget/helpers/__init__.py @@ -3,4 +3,4 @@ from ._format_time import _format_time from ._predict_model import _predict_model from ._prepare_data import _prepare_data -from ._pad import _pad \ No newline at end of file +from ._pad import _pad diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 5e2987e..40447b9 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -7,78 +7,103 @@ import time from ._format_time import _format_time + def _prepare_boot_data(self, data, boot_id): id_counts = self._boot_samples[boot_id] - - counts = pl.DataFrame({ - self.id_col: list(id_counts.keys()), - "count": list(id_counts.values()) - }) - + + counts = pl.DataFrame( + {self.id_col: list(id_counts.keys()), "count": list(id_counts.values())} + ) + bootstrapped = data.join(counts, on=self.id_col, how="inner") - bootstrapped = bootstrapped.with_columns( - 
pl.int_ranges(0, pl.col("count")).alias("replicate") - ).explode("replicate").with_columns( - (pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("replicate").cast(pl.Utf8)).alias(self.id_col) - ).drop("count", "replicate") - + bootstrapped = ( + bootstrapped.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) + .explode("replicate") + .with_columns( + ( + pl.col(self.id_col).cast(pl.Utf8) + + "_" + + pl.col("replicate").cast(pl.Utf8) + ).alias(self.id_col) + ) + .drop("count", "replicate") + ) + return bootstrapped + def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs): obj = copy.deepcopy(obj) - obj._rng = np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() + obj._rng = ( + np.random.RandomState(seed + i) if seed is not None else np.random.RandomState() + ) obj.DT = _prepare_boot_data(obj, original_DT, i) - + # Disable bootstrapping to prevent recursion obj.bootstrap_nboot = 0 - + method = getattr(obj, method_name) result = method(*args, **kwargs) obj._rng = None return result + def bootstrap_loop(method): @wraps(method) def wrapper(self, *args, **kwargs): if not hasattr(self, "outcome_model"): self.outcome_model = [] start = time.perf_counter() - + results = [] full = method(self, *args, **kwargs) results.append(full) - - if getattr(self, "bootstrap_nboot") > 0 and getattr(self, "_boot_samples", None): + + if getattr(self, "bootstrap_nboot") > 0 and getattr( + self, "_boot_samples", None + ): original_DT = self.DT nboot = self.bootstrap_nboot ncores = self.ncores seed = getattr(self, "seed", None) method_name = method.__name__ - + if getattr(self, "parallel", False): original_rng = getattr(self, "_rng", None) self._rng = None - + with ProcessPoolExecutor(max_workers=ncores) as executor: futures = [ - executor.submit(_bootstrap_worker, self, method_name, original_DT, i, seed, args, kwargs) + executor.submit( + _bootstrap_worker, + self, + method_name, + original_DT, + i, + seed, + args, + kwargs, + 
) for i in range(nboot) ] - for j in tqdm(as_completed(futures), total=nboot, desc="Bootstrapping..."): + for j in tqdm( + as_completed(futures), total=nboot, desc="Bootstrapping..." + ): results.append(j.result()) - + self._rng = original_rng else: for i in tqdm(range(nboot), desc="Bootstrapping..."): self.DT = _prepare_boot_data(self, original_DT, i) boot_fit = method(self, *args, **kwargs) results.append(boot_fit) - + self.DT = original_DT - + end = time.perf_counter() self._model_time = _format_time(start, end) - + self.outcome_model = results return results + return wrapper diff --git a/pySEQTarget/helpers/_col_string.py b/pySEQTarget/helpers/_col_string.py index 62b9f16..907ef80 100644 --- a/pySEQTarget/helpers/_col_string.py +++ b/pySEQTarget/helpers/_col_string.py @@ -3,4 +3,4 @@ def _col_string(expressions): for expression in expressions: if expression is not None: cols.update(expression.replace("+", " ").replace("*", " ").split()) - return cols \ No newline at end of file + return cols diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 08ecc46..5ddd731 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -1,8 +1,9 @@ import numpy as np + def _predict_model(self, model, newdata): newdata = newdata.to_pandas() for col in self.fixed_cols: if col in newdata.columns: - newdata[col] = newdata[col].astype("category") + newdata[col] = newdata[col].astype("category") return np.array(model.predict(newdata)) diff --git a/pySEQTarget/helpers/_prepare_data.py b/pySEQTarget/helpers/_prepare_data.py index efc0e3f..682e23d 100644 --- a/pySEQTarget/helpers/_prepare_data.py +++ b/pySEQTarget/helpers/_prepare_data.py @@ -1,14 +1,19 @@ import polars as pl + def _prepare_data(self, DT): - binaries = [self.eligible_col, self.outcome_col, self.cense_colname] # self.excused_colnames + self.weight_eligible_colnames + binaries = [ + self.eligible_col, + self.outcome_col, + 
self.cense_colname, + ] # self.excused_colnames + self.weight_eligible_colnames binary_colnames = [col for col in binaries if col in DT.columns and not None] DT = DT.with_columns( - [ - *[pl.col(col).cast(pl.Categorical) for col in self.fixed_cols], - *[pl.col(col).cast(pl.Int8) for col in binary_colnames], - pl.col(self.id_col).cast(pl.Utf8), - ] - ) + [ + *[pl.col(col).cast(pl.Categorical) for col in self.fixed_cols], + *[pl.col(col).cast(pl.Int8) for col in binary_colnames], + pl.col(self.id_col).cast(pl.Utf8), + ] + ) return DT diff --git a/pySEQTarget/initialization/__init__.py b/pySEQTarget/initialization/__init__.py index 004b0ce..32173e2 100644 --- a/pySEQTarget/initialization/__init__.py +++ b/pySEQTarget/initialization/__init__.py @@ -1,4 +1,4 @@ from ._outcome import _outcome from ._censoring import _cense_numerator, _cense_denominator from ._numerator import _numerator -from ._denominator import _denominator \ No newline at end of file +from ._denominator import _denominator diff --git a/pySEQTarget/initialization/_censoring.py b/pySEQTarget/initialization/_censoring.py index 94ce6a9..828d584 100644 --- a/pySEQTarget/initialization/_censoring.py +++ b/pySEQTarget/initialization/_censoring.py @@ -1,29 +1,53 @@ def _cense_numerator(self) -> str: - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) 
+ if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: out = "+".join(filter(None, ["tx_lag", time, fixed])) else: out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, tv_bas])) - + return out + def _cense_denominator(self) -> str: - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: out = "+".join(filter(None, ["tx_lag", time, fixed, tv])) else: out = "+".join(filter(None, ["tx_lag", trial, followup, fixed, tv, tv_bas])) - - return out + return out diff --git a/pySEQTarget/initialization/_denominator.py b/pySEQTarget/initialization/_denominator.py index 4d74bdc..0a081b3 100644 --- a/pySEQTarget/initialization/_denominator.py +++ b/pySEQTarget/initialization/_denominator.py @@ -1,14 +1,26 @@ def _denominator(self) -> str: if self.method == "ITT": return - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", 
f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - + tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: if self.method == "dose-response": out = "+".join(filter(None, [fixed, tv, time])) @@ -24,4 +36,4 @@ def _denominator(self) -> str: elif self.method == "censoring" and self.excused: out = "+".join(filter(None, [fixed, tv, tv_bas, followup, trial])) - return out \ No newline at end of file + return out diff --git a/pySEQTarget/initialization/_numerator.py b/pySEQTarget/initialization/_numerator.py index 8f232d6..b47ba23 100644 --- a/pySEQTarget/initialization/_numerator.py +++ b/pySEQTarget/initialization/_numerator.py @@ -1,13 +1,25 @@ def _numerator(self) -> str: if self.method == "ITT": return - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - followup = "+".join(["followup", f"followup{self.indicator_squared}"]) if self.followup_include else None + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + followup = ( + "+".join(["followup", f"followup{self.indicator_squared}"]) + if self.followup_include + else None + ) time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"]) - - tv_bas = "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None + + tv_bas = ( + "+".join([f"{v}{self.indicator_baseline}" 
for v in self.time_varying_cols]) + if self.time_varying_cols + else None + ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - + if self.weight_preexpansion: if self.method == "dose-response": out = "+".join(filter(None, [fixed, time])) @@ -22,4 +34,4 @@ def _numerator(self) -> str: out = "+".join(filter(None, [fixed, tv_bas, followup, trial])) elif self.method == "censoring" and self.excused: out = "+".join(filter(None, [fixed, tv_bas, followup, trial])) - return out \ No newline at end of file + return out diff --git a/pySEQTarget/initialization/_outcome.py b/pySEQTarget/initialization/_outcome.py index 8e30e9f..4ec0fc9 100644 --- a/pySEQTarget/initialization/_outcome.py +++ b/pySEQTarget/initialization/_outcome.py @@ -2,28 +2,36 @@ def _outcome(self) -> str: tx_bas = f"{self.treatment_col}{self.indicator_baseline}" dose = "+".join(["dose", f"dose{self.indicator_squared}"]) interaction = f"{tx_bas}*followup" - interaction_dose = "+".join(["followup*dose", f"followup*dose{self.indicator_squared}"]) - + interaction_dose = "+".join( + ["followup*dose", f"followup*dose{self.indicator_squared}"] + ) + if self.hazard or not self.km_curves: interaction = interaction_dose = None - + tv_bas = ( - "+".join([f"{v}_bas" for v in self.time_varying_cols]) if self.time_varying_cols else None + "+".join([f"{v}_bas" for v in self.time_varying_cols]) + if self.time_varying_cols + else None ) fixed = "+".join(self.fixed_cols) if self.fixed_cols else None - trial = "+".join(["trial", f"trial{self.indicator_squared}"]) if self.trial_include else None - + trial = ( + "+".join(["trial", f"trial{self.indicator_squared}"]) + if self.trial_include + else None + ) + if self.followup_include: followup = "+".join(["followup", f"followup{self.indicator_squared}"]) elif (self.followup_spline or self.followup_class) and not self.followup_include: followup = "followup" else: followup = None - + if self.method == "ITT": parts = [tx_bas, followup, trial, fixed, tv_bas, interaction] 
return "+".join(filter(None, parts)) - + if self.weighted: if self.weight_preexpansion: if self.method == "dose-response": @@ -39,7 +47,7 @@ def _outcome(self) -> str: elif self.method == "censoring": parts = [tx_bas, followup, trial, fixed, tv_bas, interaction] return "+".join(filter(None, parts)) - + if self.method == "dose-response": parts = [dose, followup, trial, fixed, tv_bas, interaction_dose] elif self.method == "censoring": diff --git a/pySEQTarget/plot/__init__.py b/pySEQTarget/plot/__init__.py index 6d77c39..fba45b1 100644 --- a/pySEQTarget/plot/__init__.py +++ b/pySEQTarget/plot/__init__.py @@ -1 +1 @@ -from ._survival_plot import _survival_plot \ No newline at end of file +from ._survival_plot import _survival_plot diff --git a/pySEQTarget/plot/_survival_plot.py b/pySEQTarget/plot/_survival_plot.py index 9641da8..f2ed288 100644 --- a/pySEQTarget/plot/_survival_plot.py +++ b/pySEQTarget/plot/_survival_plot.py @@ -3,14 +3,15 @@ import polars as pl import numpy as np + def _survival_plot(self): if self.plot_type == "risk": plot_data = self.km_data.filter(pl.col("estimate") == "risk") - elif self.plot_type == "survival": + elif self.plot_type == "survival": plot_data = self.km_data.filter(pl.col("estimate") == "survival") else: plot_data = self.km_data.filter(pl.col("estimate") == "incidence") - + if self.subgroup_colname is None: _plot_single(self, plot_data) else: @@ -20,10 +21,10 @@ def _survival_plot(self): def _plot_single(self, plot_data): plt.figure(figsize=(10, 6)) _plot_data(self, plot_data, plt.gca()) - + if self.plot_title is None: self.plot_title = f"Cumulative {self.plot_type.title()}" - + plt.xlabel("Followup") plt.ylabel(self.plot_type.title()) plt.title(self.plot_title) @@ -35,60 +36,62 @@ def _plot_single(self, plot_data): def _plot_subgroups(self, plot_data): subgroups = sorted(plot_data[self.subgroup_colname].unique().to_list()) n_subgroups = len(subgroups) - + n_cols = min(3, n_subgroups) n_rows = (n_subgroups + n_cols - 1) // n_cols - 
- fig, axes = plt.subplots(n_rows, n_cols, figsize=(7*n_cols, 6*n_rows)) + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows)) axes = np.atleast_1d(axes).flatten() - + for idx, subgroup_val in enumerate(subgroups): ax = axes[idx] subgroup_data = plot_data.filter(pl.col(self.subgroup_colname) == subgroup_val) _plot_data(self, subgroup_data, ax) - - subgroup_label = str(subgroup_val).title() if isinstance(subgroup_val, str) else subgroup_val - + + subgroup_label = ( + str(subgroup_val).title() if isinstance(subgroup_val, str) else subgroup_val + ) + ax.set_xlabel("Followup") ax.set_ylabel(self.plot_type.title()) - ax.set_title(f"{self.subgroup_colname.title()}: {subgroup_label}", fontsize=10, style="italic") + ax.set_title( + f"{self.subgroup_colname.title()}: {subgroup_label}", + fontsize=10, + style="italic", + ) ax.legend() ax.grid() - + for idx in range(n_subgroups, len(axes)): axes[idx].set_visible(False) - + if self.plot_title: fig.suptitle(self.plot_title, fontsize=14) else: fig.suptitle(f"Cumulative {self.plot_type.title()}", fontsize=14) - + plt.tight_layout() plt.show() def _plot_data(self, plot_data, ax): color_cycle = itertools.cycle(self.plot_colors) if self.plot_colors else None - + for idx, i in enumerate(self.treatment_level): subset = plot_data.filter(pl.col(self.treatment_col) == i) if subset.is_empty(): continue - + label = f"treatment = {i}" if self.plot_labels and idx < len(self.plot_labels): label = self.plot_labels[idx] - + color = next(color_cycle) if color_cycle else None - - line, = ax.plot( - subset["followup"], - subset["pred"], - "-", - label=label, - color=color + + (line,) = ax.plot( + subset["followup"], subset["pred"], "-", label=label, color=color ) - + if "LCI" in subset.columns and "UCI" in subset.columns: ax.fill_between( subset["followup"], @@ -96,5 +99,5 @@ def _plot_data(self, plot_data, ax): subset["UCI"], color=line.get_color(), alpha=0.2, - label="_nolegend_" + label="_nolegend_", ) diff --git 
a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py index 10953df..62fb9f7 100644 --- a/pySEQTarget/weighting/__init__.py +++ b/pySEQTarget/weighting/__init__.py @@ -2,4 +2,4 @@ from ._weight_pred import _weight_predict from ._weight_bind import _weight_bind from ._weight_data import _weight_setup -from ._weight_stats import _weight_stats \ No newline at end of file +from ._weight_stats import _weight_stats diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index be9d188..8f8f4a8 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -1,5 +1,6 @@ import polars as pl + def _weight_bind(self, WDT): if self.weight_preexpansion: join = "inner" @@ -8,51 +9,60 @@ def _weight_bind(self, WDT): else: join = "left" on = [self.id_col, "trial", "followup"] - + WDT = self.DT.join(WDT, on=on, how=join) - + if self.weight_preexpansion and self.excused: trial = (pl.col("trial") == 0) & (pl.col("period") == 0) - excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0 + excused = ( + pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) + > 0 + ) override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-7) + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-7) ) elif not self.weight_preexpansion and self.excused: trial = pl.col("followup") == 0 - excused = pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) > 0 + excused = ( + pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"]) + > 0 + ) override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-7) | - (pl.col("numerator") < 1e-7) + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-7) + | (pl.col("numerator") < 1e-7) ) else: - trial = (pl.col("trial") == 
pl.col("trial").min().over(self.id_col)) & (pl.col("followup") == 0) + trial = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & ( + pl.col("followup") == 0 + ) excused = pl.lit(False) override = ( - trial | - excused | - pl.col(self.outcome_col).is_null() | - (pl.col("denominator") < 1e-15) | - pl.col("numerator").is_null() - ) - - self.DT = WDT.with_columns( - pl.when(override) - .then(pl.lit(1.0)) - .otherwise(pl.col("numerator") / pl.col("denominator")) - .alias("wt") - ).sort( - [self.id_col, "trial", "followup"] - ).with_columns( - pl.col("wt") - .fill_null(1.0) - .cum_prod() - .over([self.id_col, "trial"]) - .alias("weight") + trial + | excused + | pl.col(self.outcome_col).is_null() + | (pl.col("denominator") < 1e-15) + | pl.col("numerator").is_null() + ) + + self.DT = ( + WDT.with_columns( + pl.when(override) + .then(pl.lit(1.0)) + .otherwise(pl.col("numerator") / pl.col("denominator")) + .alias("wt") + ) + .sort([self.id_col, "trial", "followup"]) + .with_columns( + pl.col("wt") + .fill_null(1.0) + .cum_prod() + .over([self.id_col, "trial"]) + .alias("weight") + ) ) - \ No newline at end of file diff --git a/pySEQTarget/weighting/_weight_data.py b/pySEQTarget/weighting/_weight_data.py index 4aa3780..72f8ae7 100644 --- a/pySEQTarget/weighting/_weight_data.py +++ b/pySEQTarget/weighting/_weight_data.py @@ -1,38 +1,47 @@ import polars as pl + def _weight_setup(self): DT = self.DT data = self.data if not self.weight_preexpansion: - baseline_lag = data.select([self.treatment_col, self.id_col, self.time_col]) \ - .sort([self.id_col, self.time_col]) \ - .with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over(self.id_col) - .alias("tx_lag")) \ - .drop(self.treatment_col) \ - .rename({self.time_col : "period"}) - - fup0 = DT.filter(pl.col("followup") == 0) \ - .join( - baseline_lag, - on = [self.id_col, "period"], - how = "inner" + baseline_lag = ( + data.select([self.treatment_col, self.id_col, self.time_col]) + 
.sort([self.id_col, self.time_col]) + .with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over(self.id_col) + .alias("tx_lag") ) - - fup = DT.sort([self.id_col, "trial", "followup"]) \ - .with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over([self.id_col, "trial"]) - .alias("tx_lag") - ).filter(pl.col("followup") != 0) - + .drop(self.treatment_col) + .rename({self.time_col: "period"}) + ) + + fup0 = DT.filter(pl.col("followup") == 0).join( + baseline_lag, on=[self.id_col, "period"], how="inner" + ) + + fup = ( + DT.sort([self.id_col, "trial", "followup"]) + .with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over([self.id_col, "trial"]) + .alias("tx_lag") + ) + .filter(pl.col("followup") != 0) + ) + WDT = pl.concat([fup0, fup]).sort([self.id_col, "trial", "followup"]) else: - WDT = data.with_columns(pl.col(self.treatment_col) - .shift(fill_value=self.treatment_level[0]) - .over(self.id_col) - .alias("tx_lag"), - (pl.col(self.time_col) ** 2).alias(f"{self.time_col}{self.indicator_squared}")) + WDT = data.with_columns( + pl.col(self.treatment_col) + .shift(fill_value=self.treatment_level[0]) + .over(self.id_col) + .alias("tx_lag"), + (pl.col(self.time_col) ** 2).alias( + f"{self.time_col}{self.indicator_squared}" + ), + ) return WDT - \ No newline at end of file diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index ea0ef71..70fc85d 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -1,6 +1,7 @@ import statsmodels.api as sm import statsmodels.formula.api as smf + def _fit_LTFU(self, WDT): if self.cense_colname is None: return @@ -8,20 +9,17 @@ def _fit_LTFU(self, WDT): fits = [] if self.cense_eligible_colname is not None: WDT = WDT[WDT[self.cense_eligible_colname] == 1] - + for i in [self.cense_numerator, self.cense_denominator]: formula = f"{self.cense_colname}~{i}" 
- model = smf.glm( - formula, - WDT, - family=sm.families.Binomial() - ) + model = smf.glm(formula, WDT, family=sm.families.Binomial()) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.cense_numerator = fits[0] self.cense_denominator = fits[1] + def _fit_numerator(self, WDT): if self.weight_preexpansion and self.excused: return @@ -29,7 +27,9 @@ def _fit_numerator(self, WDT): return predictor = "switch" if self.excused else self.treatment_col formula = f"{predictor}~{self.numerator}" - tx_bas = f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" + tx_bas = ( + f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" + ) fits = [] for i, level in enumerate(self.treatment_level): if self.excused and self.excused_colnames[i] is not None: @@ -40,20 +40,22 @@ def _fit_numerator(self, WDT): DT_subset = DT_subset[DT_subset[tx_bas] == level] if self.weight_eligible_colnames[i] is not None: DT_subset = DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - - model = smf.mnlogit( - formula, - DT_subset - ) + + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.numerator_model = fits - + + def _fit_denominator(self, WDT): if self.method == "ITT": return - predictor = "switch" if self.excused and not self.weight_preexpansion else self.treatment_col + predictor = ( + "switch" + if self.excused and not self.weight_preexpansion + else self.treatment_col + ) formula = f"{predictor}~{self.denominator}" fits = [] for i, level in enumerate(self.treatment_level): @@ -62,18 +64,14 @@ def _fit_denominator(self, WDT): else: DT_subset = WDT if self.weight_lag_condition: - DT_subset = DT_subset[DT_subset["tx_lag"] == level] + DT_subset = DT_subset[DT_subset["tx_lag"] == level] if not self.weight_preexpansion and not self.excused: DT_subset = DT_subset[DT_subset["followup"] != 0] if self.weight_eligible_colnames[i] is not None: DT_subset = 
DT_subset[DT_subset[self.weight_eligible_colnames[i]] == 1] - - model = smf.mnlogit( - formula, - DT_subset - ) + + model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0) fits.append(model_fit) - + self.denominator_model = fits - \ No newline at end of file diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index aff0c6f..145865f 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -2,26 +2,25 @@ import polars as pl import numpy as np + def _weight_predict(self, WDT): grouping = [self.id_col] grouping += ["trial"] if not self.weight_preexpansion else [] time = self.time_col if self.weight_preexpansion else "followup" - + if self.method == "ITT": - WDT = WDT.with_columns([ - pl.lit(1.).alias("numerator"), - pl.lit(1.).alias("denominator") - ]) + WDT = WDT.with_columns( + [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] + ) else: - WDT = WDT.with_columns([ - pl.lit(1.).alias("numerator"), - pl.lit(1.).alias("denominator") - ]) - + WDT = WDT.with_columns( + [pl.lit(1.0).alias("numerator"), pl.lit(1.0).alias("denominator")] + ) + for i, level in enumerate(self.treatment_level): mask = pl.col("tx_lag") == level lag_mask = (WDT["tx_lag"] == level).to_numpy() - + if self.denominator_model[i] is not None: pred_denom = np.ones(WDT.height) if lag_mask.sum() > 0: @@ -30,11 +29,13 @@ def _weight_predict(self, WDT): if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] - switched_treatment = (subset[self.treatment_col] != subset["tx_lag"]).to_numpy() - pred_denom[lag_mask] = np.where(switched_treatment, 1. 
- p, p) + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_denom[lag_mask] = np.where(switched_treatment, 1.0 - p, p) else: pred_denom = np.ones(WDT.height) - + if hasattr(self, "numerator_model") and self.numerator_model[i] is not None: pred_num = np.ones(WDT.height) if lag_mask.sum() > 0: @@ -43,34 +44,40 @@ def _weight_predict(self, WDT): if p.ndim == 1: p = p.reshape(-1, 1) p = p[:, i] - switched_treatment = (subset[self.treatment_col] != subset["tx_lag"]).to_numpy() - pred_num[lag_mask] = np.where(switched_treatment, 1. - p, p) + switched_treatment = ( + subset[self.treatment_col] != subset["tx_lag"] + ).to_numpy() + pred_num[lag_mask] = np.where(switched_treatment, 1.0 - p, p) else: pred_num = np.ones(WDT.height) - - WDT = WDT.with_columns([ - pl.when(mask) - .then(pl.Series(pred_num)) - .otherwise(pl.col("numerator")) - .alias("numerator"), - pl.when(mask) - .then(pl.Series(pred_denom)) - .otherwise(pl.col("denominator")) - .alias("denominator") - ]) - + + WDT = WDT.with_columns( + [ + pl.when(mask) + .then(pl.Series(pred_num)) + .otherwise(pl.col("numerator")) + .alias("numerator"), + pl.when(mask) + .then(pl.Series(pred_denom)) + .otherwise(pl.col("denominator")) + .alias("denominator"), + ] + ) + if self.cense_colname is not None: p_num = _predict_model(self, self.cense_numerator, WDT).flatten() p_denom = _predict_model(self, self.cense_denominator, WDT).flatten() - WDT = WDT.with_columns([ - pl.Series("cense_numerator", p_num), - pl.Series("cense_denominator", p_denom) - ]).with_columns( + WDT = WDT.with_columns( + [ + pl.Series("cense_numerator", p_num), + pl.Series("cense_denominator", p_denom), + ] + ).with_columns( (pl.col("cense_numerator") / pl.col("cense_denominator")).alias("cense") ) else: - WDT = WDT.with_columns(pl.lit(1.).alias("cense")) - + WDT = WDT.with_columns(pl.lit(1.0).alias("cense")) + kept = ["numerator", "denominator", "cense", self.id_col, "trial", time, "tx_lag"] exists = [col for col 
in kept if col in WDT.columns] return WDT.select(exists).sort(grouping + [time]) diff --git a/pySEQTarget/weighting/_weight_stats.py b/pySEQTarget/weighting/_weight_stats.py index ddbd190..b6da331 100644 --- a/pySEQTarget/weighting/_weight_stats.py +++ b/pySEQTarget/weighting/_weight_stats.py @@ -1,20 +1,23 @@ import polars as pl + def _weight_stats(self): - stats = self.DT.select([ - pl.col("weight").min().alias("weight_min"), - pl.col("weight").max().alias("weight_max"), - pl.col("weight").mean().alias("weight_mean"), - pl.col("weight").std().alias("weight_std"), - pl.col("weight").quantile(0.01).alias("weight_p01"), - pl.col("weight").quantile(0.25).alias("weight_p25"), - pl.col("weight").quantile(0.50).alias("weight_p50"), - pl.col("weight").quantile(0.75).alias("weight_p75"), - pl.col("weight").quantile(0.99).alias("weight_p99") - ]) - + stats = self.DT.select( + [ + pl.col("weight").min().alias("weight_min"), + pl.col("weight").max().alias("weight_max"), + pl.col("weight").mean().alias("weight_mean"), + pl.col("weight").std().alias("weight_std"), + pl.col("weight").quantile(0.01).alias("weight_p01"), + pl.col("weight").quantile(0.25).alias("weight_p25"), + pl.col("weight").quantile(0.50).alias("weight_p50"), + pl.col("weight").quantile(0.75).alias("weight_p75"), + pl.col("weight").quantile(0.99).alias("weight_p99"), + ] + ) + if self.weight_p99: self.weight_min = stats.select("weight_p01").item() self.weight_max = stats.select("weight_p99").item() - + return stats From 0533f2dca5574d2efd08cf2529aaaea36f62aa29 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:11:28 +0100 Subject: [PATCH 6/9] black on tests --- tests/test_accessor.py | 9 +- tests/test_coefficients.py | 302 +++++++++++++++++++++------------ tests/test_covariates.py | 105 +++++++----- tests/test_followup_options.py | 82 ++++++--- tests/test_hazard.py | 25 +-- tests/test_parallel.py | 30 ++-- tests/test_survival.py | 81 +++++---- 7 
files changed, 398 insertions(+), 236 deletions(-) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 9039030..ef028e6 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -2,9 +2,10 @@ from pySEQTarget.data import load_data import pytest + def test_ITT_collector(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -14,12 +15,12 @@ def test_ITT_collector(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), ) s.expand() s.fit() collector = s.collect() outcomes = collector.retrieve_data("unique_outcomes") with pytest.raises(ValueError): - collector.retrieve_data("km_data") \ No newline at end of file + collector.retrieve_data("km_data") diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index ae76aa5..90494f8 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,9 +1,10 @@ from pySEQTarget import SEQuential, SEQopts from pySEQTarget.data import load_data + def test_ITT_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,21 +14,30 @@ def test_ITT_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.828506035553407, 0.18935003090041902, 0.12717241010542563, - 0.033715156987629266, -0.00014691202235029346, 0.044566165558944326, - 0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771, - 0.20072409918428052] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.828506035553407, + 0.18935003090041902, + 0.12717241010542563, + 0.033715156987629266, + -0.00014691202235029346, + 0.044566165558944326, + 0.0005787770439053261, + 0.0032906669395295026, + 
-0.01339242049205771, + 0.20072409918428052, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_dose_response_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -37,21 +47,28 @@ def test_PreE_dose_response_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-4.842735939069144, 0.14286755770151904, 0.055221018477671927, - -0.000581657931537684, -0.008484541900408258, 0.00021073328759912806, - 0.010537967151467553, 0.0007772316818101141] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -4.842735939069144, + 0.14286755770151904, + 0.055221018477671927, + -0.000581657931537684, + -0.008484541900408258, + 0.00021073328759912806, + 0.010537967151467553, + 0.0007772316818101141, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_dose_response_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -61,22 +78,32 @@ def test_PostE_dose_response_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True) + method="dose-response", + parameters=SEQopts(weighted=True), ) - + s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.265901713761531, 0.14065954021957594, 0.048626017624679704, - -0.0004688287307505834, -0.003975906839775267, 0.00016676441745740924, - 0.03866279977096397, 0.0005928449623613982, 0.0030001459817949844, - -0.02106338184559446, 0.14867250693140854] + matrix = 
s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.265901713761531, + 0.14065954021957594, + 0.048626017624679704, + -0.0004688287307505834, + -0.003975906839775267, + 0.00016676441745740924, + 0.03866279977096397, + 0.0005928449623613982, + 0.0030001459817949844, + -0.02106338184559446, + 0.14867250693140854, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_censoring_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -86,21 +113,27 @@ def test_PreE_censoring_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="censoring", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-4.818288687908951, 0.511665606890965, 0.062028316788368384, - 0.025489681857269905, 0.00018215948440046585, -0.014019017637918164, - 0.001110238926667272] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -4.818288687908951, + 0.511665606890965, + 0.062028316788368384, + 0.025489681857269905, + 0.00018215948440046585, + -0.014019017637918164, + 0.001110238926667272, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_censoring_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -110,21 +143,30 @@ def test_PostE_censoring_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True) + method="censoring", + parameters=SEQopts(weighted=True), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-7.9113179326280445, 0.49092190701455873, 0.08903087485402544, - 
0.026160806382879903, 0.00019078148503570062, 0.04445697224987294, - 0.0007051968052005897, 0.004316239095295115, 0.013762799304812959, - 0.3196331024454665] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -7.9113179326280445, + 0.49092190701455873, + 0.08903087485402544, + 0.026160806382879903, + 0.00019078148503570062, + 0.04445697224987294, + 0.0007051968052005897, + 0.004316239095295115, + 0.013762799304812959, + 0.3196331024454665, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PreE_censoring_excused_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -134,22 +176,31 @@ def test_PreE_censoring_excused_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.175691049418418, 1.3493634846413598, 0.1072284696749134, - -0.003977965364113033, 0.06959432825811135, -0.00034297574787048573] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.175691049418418, + 1.3493634846413598, + 0.1072284696749134, + -0.003977965364113033, + 0.06959432825811135, + -0.00034297574787048573, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_censoring_excused_coefs(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -159,24 +210,35 @@ def test_PostE_censoring_excused_coefs(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - 
parameters=SEQopts(weighted=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"], - weight_max=1) + method="censoring", + parameters=SEQopts( + weighted=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + weight_max=1, + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-7.126398786875262, 0.2632047482928519, 0.13345454814736696, - 0.03967181206032395, -0.00033089446793392585, 0.03763545026332514, - 0.0007588725152627089, 0.0036793093608788923, -0.022372677571544992, - 0.2441842617520696] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -7.126398786875262, + 0.2632047482928519, + 0.13345454814736696, + 0.03967181206032395, + -0.00033089446793392585, + 0.03763545026332514, + 0.0007588725152627089, + 0.0036793093608788923, + -0.022372677571544992, + 0.2441842617520696, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] - + + def test_PreE_LTFU_ITT(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -186,23 +248,32 @@ def test_PreE_LTFU_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - cense_colname="LTFU") + method="ITT", + parameters=SEQopts( + weighted=True, weight_preexpansion=True, cense_colname="LTFU" + ), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-21.640523091572796, 0.0685235184372898, -0.19006360662228572, - 0.028750950193838918, -0.0005762057433736666, 0.28554312978583757, - -0.001373044229623057, 0.006589141394458155, -0.44898959259422394, - 1.3875089788036237] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -21.640523091572796, + 0.0685235184372898, + -0.19006360662228572, + 0.028750950193838918, + 
-0.0005762057433736666, + 0.28554312978583757, + -0.001373044229623057, + 0.006589141394458155, + -0.44898959259422394, + 1.3875089788036237, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_PostE_LTFU_ITT(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -212,22 +283,30 @@ def test_PostE_LTFU_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(weighted=True, - cense_colname="LTFU") + method="ITT", + parameters=SEQopts(weighted=True, cense_colname="LTFU"), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-21.640523091572796, 0.0685235184372898, -0.19006360662228572, - 0.028750950193838918, -0.0005762057433736666, 0.28554312978583757, - -0.001373044229623057, 0.006589141394458155, -0.44898959259422394, - 1.3875089788036237] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -21.640523091572796, + 0.0685235184372898, + -0.19006360662228572, + 0.028750950193838918, + -0.0005762057433736666, + 0.28554312978583757, + -0.001373044229623057, + 0.006589141394458155, + -0.44898959259422394, + 1.3875089788036237, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_ITT_multinomial(): data = load_data("SEQdata_multitreatment") - + s = SEQuential( data, id_col="ID", @@ -237,22 +316,31 @@ def test_ITT_multinomial(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(treatment_level=[1,2]) + method="ITT", + parameters=SEQopts(treatment_level=[1, 2]), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-47.505262164163625, 1.76628017234151, 22.79205044396338, - 0.14473536056627245, -0.003725499516376173, 0.2893070991930884, - -0.004266608123938117, 0.05574429164512122, 
0.7847862691929901, - 1.4703411759229423] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -47.505262164163625, + 1.76628017234151, + 22.79205044396338, + 0.14473536056627245, + -0.003725499516376173, + 0.2893070991930884, + -0.004266608123938117, + 0.05574429164512122, + 0.7847862691929901, + 1.4703411759229423, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] - + + def test_weighted_multinomial(): - data = load_data("SEQdata_multitreatment") - - s = SEQuential( + data = load_data("SEQdata_multitreatment") + + s = SEQuential( data, id_col="ID", time_col="time", @@ -261,15 +349,21 @@ def test_weighted_multinomial(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted = True, - weight_preexpansion=True, - treatment_level=[1,2]) + method="censoring", + parameters=SEQopts( + weighted=True, weight_preexpansion=True, treatment_level=[1, 2] + ), ) - s.expand() - s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-111.35419661939163, -12.571187230338328, 9.234157699403015, - -0.6336774763031923, 0.016754692338530056, 5.8240772329087225, - -0.08598454090661659] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] \ No newline at end of file + s.expand() + s.fit() + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -111.35419661939163, + -12.571187230338328, + 9.234157699403015, + -0.6336774763031923, + 0.016754692338530056, + 5.8240772329087225, + -0.08598454090661659, + ] + assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] diff --git a/tests/test_covariates.py b/tests/test_covariates.py index 31cc727..aedf54d 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -1,9 +1,10 @@ from pySEQTarget.data import load_data from pySEQTarget import SEQuential, SEQopts + def 
test_ITT_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,18 +14,22 @@ def test_ITT_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts() + method="ITT", + parameters=SEQopts(), + ) + + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator is None assert s.denominator is None return + def test_PreE_dose_response_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -34,19 +39,19 @@ def test_PreE_dose_response_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="dose-response", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) - + assert s.covariates == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex" assert s.numerator == "sex+time+time_sq" assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_dose_response_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -56,17 +61,24 @@ def test_PostE_dose_response_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "dose-response", - parameters=SEQopts(weighted=True) + method="dose-response", + parameters=SEQopts(weighted=True), + ) + assert ( + s.covariates + == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "dose+dose_sq+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + assert ( + s.denominator + == 
"sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) return + def test_PreE_censoring_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -76,18 +88,18 @@ def test_PreE_censoring_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True) + method="censoring", + parameters=SEQopts(weighted=True, weight_preexpansion=True), ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex" assert s.numerator == "sex+time+time_sq" assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_censoring_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -97,18 +109,25 @@ def test_PostE_censoring_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True) + method="censoring", + parameters=SEQopts(weighted=True), + ) + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - + assert ( + s.denominator + == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) + return + def test_PreE_censoring_excused_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -118,20 +137,23 @@ def test_PreE_censoring_excused_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - weight_preexpansion=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + 
weighted=True, + weight_preexpansion=True, + excused=True, + excused_colnames=["excusedZero", "excusedOne"], + ), ) assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq" assert s.numerator is None assert s.denominator == "sex+N+L+P+time+time_sq" return + def test_PostE_censoring_excused_covariates(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -141,13 +163,18 @@ def test_PostE_censoring_excused_covariates(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "censoring", - parameters=SEQopts(weighted=True, - excused=True, - excused_colnames=["excusedZero", "excusedOne"]) + method="censoring", + parameters=SEQopts( + weighted=True, excused=True, excused_colnames=["excusedZero", "excusedOne"] + ), + ) + assert ( + s.covariates + == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" ) - assert s.covariates == "tx_init_bas+followup+followup_sq+trial+trial_sq+sex+N_bas+L_bas+P_bas" assert s.numerator == "sex+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" - assert s.denominator == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + assert ( + s.denominator + == "sex+N+L+P+N_bas+L_bas+P_bas+followup+followup_sq+trial+trial_sq" + ) return - \ No newline at end of file diff --git a/tests/test_followup_options.py b/tests/test_followup_options.py index 6bc0c08..3e2102d 100644 --- a/tests/test_followup_options.py +++ b/tests/test_followup_options.py @@ -1,9 +1,10 @@ from pySEQTarget import SEQuential, SEQopts from pySEQTarget.data import load_data + def test_followup_class(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,24 +14,33 @@ def test_followup_class(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_class=True, - followup_include=False, - followup_max=5) + method="ITT", + parameters=SEQopts(followup_class=True, followup_include=False, 
followup_max=5), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.6000834193414635, 0.36024705241286203, 0.04326409573404126, - 0.07627958175273072, 0.11375627612408938, 0.14496108664292745, - 0.1798424095611678, 0.09066206802273916, 0.015738693166264354, - 0.0009258560187318309, 0.011267393559242982, 0.022194521411244304, - 0.19115237121222872] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.6000834193414635, + 0.36024705241286203, + 0.04326409573404126, + 0.07627958175273072, + 0.11375627612408938, + 0.14496108664292745, + 0.1798424095611678, + 0.09066206802273916, + 0.015738693166264354, + 0.0009258560187318309, + 0.011267393559242982, + 0.022194521411244304, + 0.19115237121222872, + ] assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_followup_spline(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -40,22 +50,31 @@ def test_followup_spline(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_spline=True, - followup_include=False) + method="ITT", + parameters=SEQopts(followup_spline=True, followup_include=False), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.264817962084417, 0.20125056343026881, 0.12568743032952776, - 0.03823426390103046, 0.0006607691746414019, 0.003343365539743267, - -0.01319460158923785, 0.19601796921732118, -0.5186462478511427, - 0.37598656666756103, 1.6553848469346044] + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.264817962084417, + 0.20125056343026881, + 0.12568743032952776, + 0.03823426390103046, + 0.0006607691746414019, + 0.003343365539743267, + -0.01319460158923785, + 0.19601796921732118, + -0.5186462478511427, + 0.37598656666756103, + 1.6553848469346044, + ] 
assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] + def test_no_followup(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -65,13 +84,20 @@ def test_no_followup(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(followup_include=False) + method="ITT", + parameters=SEQopts(followup_include=False), ) s.expand() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - expected = [-6.062350570326165, 0.17748844870984498, 0.11209431124681817, - 0.03344595751001804, 0.0005457002039545119, 0.0032236473201563585, - -0.014463448024337773, 0.20398559747503964] - assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] \ No newline at end of file + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + expected = [ + -6.062350570326165, + 0.17748844870984498, + 0.11209431124681817, + 0.03344595751001804, + 0.0005457002039545119, + 0.0032236473201563585, + -0.014463448024337773, + 0.20398559747503964, + ] + assert [round(x, 3) for x in matrix] == [round(x, 3) for x in expected] diff --git a/tests/test_hazard.py b/tests/test_hazard.py index 2244bed..335ed5f 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,9 +1,10 @@ from pySEQTarget.data import load_data from pySEQTarget import SEQuential, SEQopts + def test_ITT_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,16 +14,17 @@ def test_ITT_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True) + method="ITT", + parameters=SEQopts(hazard_estimate=True), ) s.expand() s.fit() s.hazard() + def test_bootstrap_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -32,18 +34,18 @@ def test_bootstrap_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], 
fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True, - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2), ) s.expand() s.bootstrap() s.fit() s.hazard() - + + def test_subgroup_hazard(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -53,9 +55,8 @@ def test_subgroup_hazard(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(hazard_estimate=True, - subgroup_colname="sex") + method="ITT", + parameters=SEQopts(hazard_estimate=True, subgroup_colname="sex"), ) s.expand() s.bootstrap() diff --git a/tests/test_parallel.py b/tests/test_parallel.py index b3e63db..9a5ef26 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -3,13 +3,13 @@ import os import pytest + @pytest.mark.skipif( - os.getenv('CI') == 'true', - reason="Parallelism test hangs in CI environment" + os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" ) def test_parallel_ITT(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -19,16 +19,22 @@ def test_parallel_ITT(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(parallel=True, - bootstrap_nboot=2, - ncores=1) + method="ITT", + parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1), ) s.expand() s.bootstrap() s.fit() - matrix = s.outcome_model[0]['outcome'].summary2().tables[1]["Coef."].to_list() - assert matrix == [-6.828506035553407, 0.18935003090041902, 0.12717241010542563, - 0.033715156987629266, -0.00014691202235029346, 0.044566165558944326, - 0.0005787770439053261, 0.0032906669395295026, -0.01339242049205771, - 0.20072409918428052] \ No newline at end of file + matrix = s.outcome_model[0]["outcome"].summary2().tables[1]["Coef."].to_list() + assert matrix == [ + -6.828506035553407, + 0.18935003090041902, + 0.12717241010542563, + 0.033715156987629266, + 
-0.00014691202235029346, + 0.044566165558944326, + 0.0005787770439053261, + 0.0032906669395295026, + -0.01339242049205771, + 0.20072409918428052, + ] diff --git a/tests/test_survival.py b/tests/test_survival.py index abb89a7..a85d17c 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,9 +1,10 @@ from pySEQTarget import SEQuential, SEQopts from pySEQTarget.data import load_data + def test_regular_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -13,17 +14,18 @@ def test_regular_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True) + method="ITT", + parameters=SEQopts(km_curves=True), ) s.expand() s.fit() s.survival() return - + + def test_bootstrapped_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -33,9 +35,8 @@ def test_bootstrapped_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(km_curves=True, bootstrap_nboot=2), ) s.expand() s.bootstrap() @@ -43,9 +44,10 @@ def test_bootstrapped_survival(): s.survival() return + def test_subgroup_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -55,18 +57,18 @@ def test_subgroup_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - subgroup_colname="sex") + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex"), ) s.expand() s.fit() s.survival() return - + + def test_subgroup_bootstrapped_survival(): data = load_data("SEQdata") - + s = SEQuential( data, id_col="ID", @@ -76,10 +78,8 @@ def test_subgroup_bootstrapped_survival(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - 
subgroup_colname="sex", - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2), ) s.expand() s.bootstrap() @@ -87,9 +87,10 @@ def test_subgroup_bootstrapped_survival(): s.survival() return + def test_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -99,19 +100,20 @@ def test_compevent(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence") + method="ITT", + parameters=SEQopts( + km_curves=True, compevent_colname="LTFU", plot_type="incidence" + ), ) s.expand() s.fit() s.survival() return - + + def test_bootstrapped_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -121,21 +123,24 @@ def test_bootstrapped_compevent(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence", - bootstrap_nboot=2) + method="ITT", + parameters=SEQopts( + km_curves=True, + compevent_colname="LTFU", + plot_type="incidence", + bootstrap_nboot=2, + ), ) s.expand() s.bootstrap() s.fit() s.survival() return - + + def test_subgroup_compevent(): data = load_data("SEQdata_LTFU") - + s = SEQuential( data, id_col="ID", @@ -145,11 +150,13 @@ def test_subgroup_compevent(): outcome_col="outcome", time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], - method = "ITT", - parameters=SEQopts(km_curves=True, - compevent_colname="LTFU", - plot_type = "incidence", - subgroup_colname = "sex") + method="ITT", + parameters=SEQopts( + km_curves=True, + compevent_colname="LTFU", + plot_type="incidence", + subgroup_colname="sex", + ), ) s.expand() s.bootstrap() From 15837fbd275c6c10bdd21e98cee92745597f6de9 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 
20:29:00 +0100 Subject: [PATCH 7/9] lint --- .github/workflows/lint.yml | 4 +-- README.md | 10 +++--- pySEQTarget/SEQopts.py | 2 +- pySEQTarget/SEQoutput.py | 10 +++--- pySEQTarget/SEQuential.py | 48 +++++++++----------------- pySEQTarget/__init__.py | 4 +-- pySEQTarget/analysis/__init__.py | 13 ++++--- pySEQTarget/analysis/_hazard.py | 5 +-- pySEQTarget/analysis/_outcome_fit.py | 7 ++-- pySEQTarget/analysis/_subgroup_fit.py | 1 + pySEQTarget/data/__init__.py | 3 +- pySEQTarget/error/__init__.py | 4 +-- pySEQTarget/expansion/__init__.py | 10 +++--- pySEQTarget/expansion/_binder.py | 1 + pySEQTarget/expansion/_mapper.py | 3 +- pySEQTarget/helpers/__init__.py | 12 +++---- pySEQTarget/helpers/_bootstrap.py | 10 +++--- pySEQTarget/initialization/__init__.py | 9 ++--- pySEQTarget/plot/__init__.py | 2 +- pySEQTarget/plot/_survival_plot.py | 3 +- pySEQTarget/weighting/__init__.py | 12 ++++--- pySEQTarget/weighting/_weight_pred.py | 5 +-- pyproject.toml | 2 +- tests/test_accessor.py | 7 ++-- tests/test_coefficients.py | 2 +- tests/test_covariates.py | 2 +- tests/test_followup_options.py | 2 +- tests/test_hazard.py | 2 +- tests/test_parallel.py | 6 ++-- tests/test_survival.py | 2 +- 30 files changed, 103 insertions(+), 100 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d573f35..74b3dc9 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,6 +29,4 @@ jobs: - name: Run Ruff run: ruff check . - - - name: Run mypy - run: mypy . --ignore-missing-imports \ No newline at end of file + \ No newline at end of file diff --git a/README.md b/README.md index 0a48102..7e20d3d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# pySEQ - Sequentially Nested Target Trial Emulation +# pySEQTarget - Sequentially Nested Target Trial Emulation Implementation of sequential trial emulation for the analysis of observational databases. 
The ‘SEQTaRget’ software accommodates @@ -9,13 +9,13 @@ intention-to-treat and per-protocol effects, and can adjust for potential selection bias. ## Installation -You can install the development version of pySEQ from github with: +You can install the development version of pySEQTarget from github with: ```shell -pip install git+https://github.com/CausalInference/pySEQ +pip install git+https://github.com/CausalInference/pySEQTarget ``` Or from pypi iwth ```shell -pip install pySEQ +pip install pySEQTarget ``` ## Setting up your Analysis @@ -25,7 +25,7 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe ```python import polars as pl -from pySEQ import SEQuential, SEQopts +from pySEQTarget import SEQuential, SEQopts data = pl.from_pandas(SEQdata) options = SEQopts(km_curves = True) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 944c86b..4a59de6 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -1,6 +1,6 @@ import multiprocessing from dataclasses import dataclass, field -from typing import List, Optional, Literal +from typing import List, Literal, Optional @dataclass diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index ca22a20..ed1ea74 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from typing import List, Optional, Literal -from .SEQopts import SEQopts -from statsmodels.base.wrapper import ResultsWrapper -import polars as pl +from typing import List, Literal, Optional + import matplotlib.figure +import polars as pl +from statsmodels.base.wrapper import ResultsWrapper + +from .SEQopts import SEQopts @dataclass diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 05158e2..0079891 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -1,41 +1,25 @@ -from typing import Optional, List, Literal +import datetime import time -from dataclasses import asdict from collections 
import Counter -import polars as pl +from dataclasses import asdict +from typing import List, Literal, Optional + import numpy as np -import datetime +import polars as pl +from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit, + _pred_risk, _risk_estimates, _subgroup_fit) +from .error import _datachecker, _param_checker +from .expansion import _binder, _diagnostics, _dynamic, _random_selection +from .helpers import _col_string, _format_time, bootstrap_loop +from .initialization import (_cense_denominator, _cense_numerator, + _denominator, _numerator, _outcome) +from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .error import _param_checker, _datachecker -from .helpers import _col_string, bootstrap_loop, _format_time -from .initialization import ( - _outcome, - _numerator, - _denominator, - _cense_numerator, - _cense_denominator, -) -from .expansion import _binder, _dynamic, _random_selection, _diagnostics -from .weighting import ( - _weight_setup, - _fit_LTFU, - _fit_numerator, - _fit_denominator, - _weight_bind, - _weight_predict, - _weight_stats, -) -from .analysis import ( - _outcome_fit, - _pred_risk, - _calculate_survival, - _subgroup_fit, - _calculate_hazard, - _risk_estimates, -) -from .plot import _survival_plot +from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, + _weight_bind, _weight_predict, _weight_setup, + _weight_stats) class SEQuential: diff --git a/pySEQTarget/__init__.py b/pySEQTarget/__init__.py index 540b3eb..d8c687e 100644 --- a/pySEQTarget/__init__.py +++ b/pySEQTarget/__init__.py @@ -1,5 +1,5 @@ -from .SEQuential import SEQuential from .SEQopts import SEQopts from .SEQoutput import SEQoutput +from .SEQuential import SEQuential -__all__ = ["SEQuential", "SEQopts"] +__all__ = ["SEQuential", "SEQopts", "SEQoutput"] diff --git a/pySEQTarget/analysis/__init__.py b/pySEQTarget/analysis/__init__.py index ec9e032..6799dfd 100644 --- 
a/pySEQTarget/analysis/__init__.py +++ b/pySEQTarget/analysis/__init__.py @@ -1,5 +1,8 @@ -from ._outcome_fit import _outcome_fit -from ._survival_pred import _get_outcome_predictions, _pred_risk, _calculate_survival -from ._subgroup_fit import _subgroup_fit -from ._hazard import _calculate_hazard -from ._risk_estimates import _risk_estimates +from ._hazard import _calculate_hazard as _calculate_hazard +from ._outcome_fit import _outcome_fit as _outcome_fit +from ._risk_estimates import _risk_estimates as _risk_estimates +from ._subgroup_fit import _subgroup_fit as _subgroup_fit +from ._survival_pred import _calculate_survival as _calculate_survival +from ._survival_pred import \ + _get_outcome_predictions as _get_outcome_predictions +from ._survival_pred import _pred_risk as _pred_risk diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 195120e..4c667c9 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -1,7 +1,8 @@ -import polars as pl +import warnings + import numpy as np +import polars as pl from lifelines import CoxPHFitter -import warnings def _calculate_hazard(self): diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index 451dc43..7ed823f 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -1,7 +1,8 @@ +import re + +import polars as pl import statsmodels.api as sm import statsmodels.formula.api as smf -import polars as pl -import re def _outcome_fit( @@ -32,7 +33,7 @@ def _outcome_fit( df_pd[squared_col] = df_pd[squared_col].astype("category") if self.followup_spline: - spline = f"cr(followup, df=3)" + spline = "cr(followup, df=3)" formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula) formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula) diff --git a/pySEQTarget/analysis/_subgroup_fit.py b/pySEQTarget/analysis/_subgroup_fit.py index b6a6c04..fd481cf 100644 --- 
a/pySEQTarget/analysis/_subgroup_fit.py +++ b/pySEQTarget/analysis/_subgroup_fit.py @@ -1,4 +1,5 @@ import polars as pl + from ._outcome_fit import _outcome_fit diff --git a/pySEQTarget/data/__init__.py b/pySEQTarget/data/__init__.py index 37c47d3..e65d31d 100644 --- a/pySEQTarget/data/__init__.py +++ b/pySEQTarget/data/__init__.py @@ -1,9 +1,10 @@ from importlib.resources import files + import polars as pl def load_data(name: str = "SEQdata") -> pl.DataFrame: - loc = files("pySEQ.data") + loc = files("pySEQTarget.data") if name in ["SEQdata", "SEQdata_multitreatment", "SEQdata_LTFU"]: if name == "SEQdata": data_path = loc.joinpath("SEQdata.csv") diff --git a/pySEQTarget/error/__init__.py b/pySEQTarget/error/__init__.py index 3aa4a33..f51f084 100644 --- a/pySEQTarget/error/__init__.py +++ b/pySEQTarget/error/__init__.py @@ -1,2 +1,2 @@ -from ._param_checker import _param_checker -from ._datachecker import _datachecker +from ._datachecker import _datachecker as _datachecker +from ._param_checker import _param_checker as _param_checker diff --git a/pySEQTarget/expansion/__init__.py b/pySEQTarget/expansion/__init__.py index 32598ec..1262af8 100644 --- a/pySEQTarget/expansion/__init__.py +++ b/pySEQTarget/expansion/__init__.py @@ -1,5 +1,5 @@ -from ._binder import _binder -from ._dynamic import _dynamic -from ._mapper import _mapper -from ._selection import _random_selection -from ._diagnostics import _diagnostics +from ._binder import _binder as _binder +from ._diagnostics import _diagnostics as _diagnostics +from ._dynamic import _dynamic as _dynamic +from ._mapper import _mapper as _mapper +from ._selection import _random_selection as _random_selection diff --git a/pySEQTarget/expansion/_binder.py b/pySEQTarget/expansion/_binder.py index c459002..727e0e6 100644 --- a/pySEQTarget/expansion/_binder.py +++ b/pySEQTarget/expansion/_binder.py @@ -1,4 +1,5 @@ import polars as pl + from ._mapper import _mapper diff --git a/pySEQTarget/expansion/_mapper.py 
b/pySEQTarget/expansion/_mapper.py index b743a81..c169669 100644 --- a/pySEQTarget/expansion/_mapper.py +++ b/pySEQTarget/expansion/_mapper.py @@ -1,6 +1,7 @@ -import polars as pl import math +import polars as pl + def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf): """ diff --git a/pySEQTarget/helpers/__init__.py b/pySEQTarget/helpers/__init__.py index 4fe0cae..f621544 100644 --- a/pySEQTarget/helpers/__init__.py +++ b/pySEQTarget/helpers/__init__.py @@ -1,6 +1,6 @@ -from ._col_string import _col_string -from ._bootstrap import bootstrap_loop -from ._format_time import _format_time -from ._predict_model import _predict_model -from ._prepare_data import _prepare_data -from ._pad import _pad +from ._bootstrap import bootstrap_loop as bootstrap_loop +from ._col_string import _col_string as _col_string +from ._format_time import _format_time as _format_time +from ._pad import _pad as _pad +from ._predict_model import _predict_model as _predict_model +from ._prepare_data import _prepare_data as _prepare_data diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index 40447b9..7aefef6 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -1,10 +1,12 @@ -from functools import wraps +import copy +import time from concurrent.futures import ProcessPoolExecutor, as_completed -import polars as pl +from functools import wraps + import numpy as np +import polars as pl from tqdm import tqdm -import copy -import time + from ._format_time import _format_time diff --git a/pySEQTarget/initialization/__init__.py b/pySEQTarget/initialization/__init__.py index 32173e2..4f026ca 100644 --- a/pySEQTarget/initialization/__init__.py +++ b/pySEQTarget/initialization/__init__.py @@ -1,4 +1,5 @@ -from ._outcome import _outcome -from ._censoring import _cense_numerator, _cense_denominator -from ._numerator import _numerator -from ._denominator import _denominator +from ._censoring import 
_cense_denominator as _cense_denominator +from ._censoring import _cense_numerator as _cense_numerator +from ._denominator import _denominator as _denominator +from ._numerator import _numerator as _numerator +from ._outcome import _outcome as _outcome diff --git a/pySEQTarget/plot/__init__.py b/pySEQTarget/plot/__init__.py index fba45b1..f417f55 100644 --- a/pySEQTarget/plot/__init__.py +++ b/pySEQTarget/plot/__init__.py @@ -1 +1 @@ -from ._survival_plot import _survival_plot +from ._survival_plot import _survival_plot as _survival_plot diff --git a/pySEQTarget/plot/_survival_plot.py b/pySEQTarget/plot/_survival_plot.py index f2ed288..0592036 100644 --- a/pySEQTarget/plot/_survival_plot.py +++ b/pySEQTarget/plot/_survival_plot.py @@ -1,7 +1,8 @@ import itertools + import matplotlib.pyplot as plt -import polars as pl import numpy as np +import polars as pl def _survival_plot(self): diff --git a/pySEQTarget/weighting/__init__.py b/pySEQTarget/weighting/__init__.py index 62fb9f7..4874865 100644 --- a/pySEQTarget/weighting/__init__.py +++ b/pySEQTarget/weighting/__init__.py @@ -1,5 +1,7 @@ -from ._weight_fit import _fit_LTFU, _fit_numerator, _fit_denominator -from ._weight_pred import _weight_predict -from ._weight_bind import _weight_bind -from ._weight_data import _weight_setup -from ._weight_stats import _weight_stats +from ._weight_bind import _weight_bind as _weight_bind +from ._weight_data import _weight_setup as _weight_setup +from ._weight_fit import _fit_denominator as _fit_denominator +from ._weight_fit import _fit_LTFU as _fit_LTFU +from ._weight_fit import _fit_numerator as _fit_numerator +from ._weight_pred import _weight_predict as _weight_predict +from ._weight_stats import _weight_stats as _weight_stats diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 145865f..5a858a8 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -1,6 +1,7 @@ -from ..helpers import 
_predict_model -import polars as pl import numpy as np +import polars as pl + +from ..helpers import _predict_model def _weight_predict(self, WDT): diff --git a/pyproject.toml b/pyproject.toml index b33e142..9b83517 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=60", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "pySEQ" +name = "pySEQTarget" version = "0.9.0" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" diff --git a/tests/test_accessor.py b/tests/test_accessor.py index ef028e6..ab9b796 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -1,7 +1,8 @@ -from pySEQTarget import SEQuential, SEQopts -from pySEQTarget.data import load_data import pytest +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + def test_ITT_collector(): data = load_data("SEQdata") @@ -21,6 +22,6 @@ def test_ITT_collector(): s.expand() s.fit() collector = s.collect() - outcomes = collector.retrieve_data("unique_outcomes") + collector.retrieve_data("unique_outcomes") with pytest.raises(ValueError): collector.retrieve_data("km_data") diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 90494f8..07435c7 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -1,4 +1,4 @@ -from pySEQTarget import SEQuential, SEQopts +from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data diff --git a/tests/test_covariates.py b/tests/test_covariates.py index aedf54d..9863f8f 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -1,5 +1,5 @@ +from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data -from pySEQTarget import SEQuential, SEQopts def test_ITT_covariates(): diff --git a/tests/test_followup_options.py b/tests/test_followup_options.py index 3e2102d..74d5451 100644 --- a/tests/test_followup_options.py +++ b/tests/test_followup_options.py @@ -1,4 +1,4 
@@ -from pySEQTarget import SEQuential, SEQopts +from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data diff --git a/tests/test_hazard.py b/tests/test_hazard.py index 335ed5f..f057dfd 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,5 +1,5 @@ +from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data -from pySEQTarget import SEQuential, SEQopts def test_ITT_hazard(): diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 9a5ef26..2ed0351 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -1,8 +1,10 @@ -from pySEQTarget import SEQuential, SEQopts -from pySEQTarget.data import load_data import os + import pytest +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + @pytest.mark.skipif( os.getenv("CI") == "true", reason="Parallelism test hangs in CI environment" diff --git a/tests/test_survival.py b/tests/test_survival.py index a85d17c..0e4ddfe 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -1,4 +1,4 @@ -from pySEQTarget import SEQuential, SEQopts +from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data From 47a62c9fd2850e3d2d2d94b6380bee34da7422f6 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:32:26 +0100 Subject: [PATCH 8/9] lint should get caught by autoformatter --- .github/workflows/lint.yml | 32 -------------------------------- pySEQTarget/SEQuential.py | 31 ++++++++++++++++++++++++------- pySEQTarget/analysis/__init__.py | 3 +-- 3 files changed, 25 insertions(+), 41 deletions(-) delete mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml deleted file mode 100644 index 74b3dc9..0000000 --- a/.github/workflows/lint.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Lint and Format - -on: - push: - branches: [main, develop] - pull_request: - branches: [main, develop] - -jobs: 
- lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - - name: Install dependencies - run: | - pip install ruff black isort mypy - - - name: Run Black - run: black --check . - - - name: Run isort - run: isort --check-only . - - - name: Run Ruff - run: ruff check . - \ No newline at end of file diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 0079891..1485939 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -7,19 +7,36 @@ import numpy as np import polars as pl -from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit, - _pred_risk, _risk_estimates, _subgroup_fit) +from .analysis import ( + _calculate_hazard, + _calculate_survival, + _outcome_fit, + _pred_risk, + _risk_estimates, + _subgroup_fit, +) from .error import _datachecker, _param_checker from .expansion import _binder, _diagnostics, _dynamic, _random_selection from .helpers import _col_string, _format_time, bootstrap_loop -from .initialization import (_cense_denominator, _cense_numerator, - _denominator, _numerator, _outcome) +from .initialization import ( + _cense_denominator, + _cense_numerator, + _denominator, + _numerator, + _outcome, +) from .plot import _survival_plot from .SEQopts import SEQopts from .SEQoutput import SEQoutput -from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator, - _weight_bind, _weight_predict, _weight_setup, - _weight_stats) +from .weighting import ( + _fit_denominator, + _fit_LTFU, + _fit_numerator, + _weight_bind, + _weight_predict, + _weight_setup, + _weight_stats, +) class SEQuential: diff --git a/pySEQTarget/analysis/__init__.py b/pySEQTarget/analysis/__init__.py index 6799dfd..e35ceb7 100644 --- a/pySEQTarget/analysis/__init__.py +++ b/pySEQTarget/analysis/__init__.py @@ -3,6 +3,5 @@ from ._risk_estimates import _risk_estimates as _risk_estimates from ._subgroup_fit import 
_subgroup_fit as _subgroup_fit from ._survival_pred import _calculate_survival as _calculate_survival -from ._survival_pred import \ - _get_outcome_predictions as _get_outcome_predictions +from ._survival_pred import _get_outcome_predictions as _get_outcome_predictions from ._survival_pred import _pred_risk as _pred_risk From ba2a9ad84baeb054f1442adf45421882074da727 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 24 Nov 2025 20:43:56 +0100 Subject: [PATCH 9/9] move docs --- {pySEQTarget/docs => docs}/Makefile | 0 {pySEQTarget/docs => docs}/make.bat | 0 {pySEQTarget/docs => docs}/source/conf.py | 15 ++++++++------- docs/source/index.rst | 17 +++++++++++++++++ pySEQTarget/docs/source/index.rst | 20 -------------------- 5 files changed, 25 insertions(+), 27 deletions(-) rename {pySEQTarget/docs => docs}/Makefile (100%) rename {pySEQTarget/docs => docs}/make.bat (100%) rename {pySEQTarget/docs => docs}/source/conf.py (72%) create mode 100644 docs/source/index.rst delete mode 100644 pySEQTarget/docs/source/index.rst diff --git a/pySEQTarget/docs/Makefile b/docs/Makefile similarity index 100% rename from pySEQTarget/docs/Makefile rename to docs/Makefile diff --git a/pySEQTarget/docs/make.bat b/docs/make.bat similarity index 100% rename from pySEQTarget/docs/make.bat rename to docs/make.bat diff --git a/pySEQTarget/docs/source/conf.py b/docs/source/conf.py similarity index 72% rename from pySEQTarget/docs/source/conf.py rename to docs/source/conf.py index 6b536b6..c0630db 100644 --- a/pySEQTarget/docs/source/conf.py +++ b/docs/source/conf.py @@ -6,22 +6,23 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = "SEQuential" -copyright = "2024, Ryan O'Dea, Alejandro Szmulewicz" -author = "Ryan O'Dea, Alejandro Szmulewicz" -release = "0.1.0" +project = 'pySEQTarget' +copyright = "2025, Ryan O'Dea, 
Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" +author = "Ryan O'Dea, Alejandro Szmulewicz, Tom Palmer, Miguel Hernan" +release = '0.9.0' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [] -templates_path = ["_templates"] +templates_path = ['_templates'] exclude_patterns = [] + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = "alabaster" -html_static_path = ["_static"] +html_theme = 'alabaster' +html_static_path = ['_static'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..cee549f --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,17 @@ +.. pySEQTarget documentation master file, created by + sphinx-quickstart on Mon Nov 24 20:43:34 2025. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +pySEQTarget documentation +========================= + +Add your content using ``reStructuredText`` syntax. See the +`reStructuredText <https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html>`_ +documentation for details. + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + diff --git a/pySEQTarget/docs/source/index.rst b/pySEQTarget/docs/source/index.rst deleted file mode 100644 index 167b37b..0000000 --- a/pySEQTarget/docs/source/index.rst +++ /dev/null @@ -1,20 +0,0 @@ -.. SEQuential documentation master file, created by - sphinx-quickstart on Fri Nov 15 14:02:09 2024. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -Welcome to SEQuential's documentation! -====================================== - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search`