diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6679ba2 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: "brats-tumor-segmentation-pipeline" +on: + push: + branches: [main] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: "Checkout code" + uses: actions/checkout@v4 + - name: "Setup python 3.11 environment" + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: "Install python libraries" + run: | + pip install -r requirements.txt + pip install -e . + - name: "Run ruff linting" + run: ruff check src/ + - name: "Run pytest" + run: pytest tests/ -v + diff --git a/src/preprocessing/create_splits.py b/src/preprocessing/create_splits.py index 685979f..96d1b2d 100644 --- a/src/preprocessing/create_splits.py +++ b/src/preprocessing/create_splits.py @@ -26,7 +26,7 @@ def create_splits(data_dir: Path, output_dir: Path): for case in case_names: seg = nib.load(data_dir / case / f"{case}-seg.nii.gz").get_fdata() labels = np.unique(seg) - tumor_labels = [str(l) for l in [1, 2, 3] if l in labels] + tumor_labels = [str(label_id) for label_id in [1, 2, 3] if label_id in labels] strat_labels.append("_".join(tumor_labels)) label_counts = Counter(strat_labels) diff --git a/src/training/train.py b/src/training/train.py index 75bb2e4..e918d2a 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -212,7 +212,7 @@ def main(): "epoch": epoch + 1, "best_val_loss": best_val_loss, "scheduler": scheduler.state_dict() - }, CHECKPOINT_DIR / f"best_model.pth") + }, CHECKPOINT_DIR / "best_model.pth") else: epochs_no_improvement += 1 diff --git a/tests/test_dataset.py b/tests/test_dataset.py index eed179b..94e07ef 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ from pathlib import Path from src.training.dataset import BraTSDataset +@pytest.mark.skipif(not Path("data/processed").exists(), reason="No data available") def test_dataset_output(): PROJECT_ROOT = Path(__file__).resolve().parents[1] DATA_DIR = PROJECT_ROOT / "data" / "processed"