diff --git a/src/inference/__init__.py b/src/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/inference/predict.py b/src/inference/predict.py new file mode 100644 index 0000000..8b5bd74 --- /dev/null +++ b/src/inference/predict.py @@ -0,0 +1,111 @@ +import numpy as np +import tempfile +import torch +import os +import nibabel as nib +from pathlib import Path +from fastapi import FastAPI, UploadFile, File, Request, HTTPException, BackgroundTasks +from fastapi.responses import FileResponse +from contextlib import asynccontextmanager +from monai.inferers import sliding_window_inference +from src.training.model import get_model +from src.preprocessing.preprocess import preprocess_array +from src.training.transforms import get_inference_transforms + +CHECKPOINT_DIR = Path(os.getenv("CHECKPOINT_DIR", "checkpoints")) +@asynccontextmanager +async def lifespan(app): + # Run on startup + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + app.state.device = device + model = get_model().to(device) + app.state.model = model + checkpoint = torch.load(CHECKPOINT_DIR / "best_model.pth", map_location=device) + model.load_state_dict(checkpoint["model"]) + model.eval() + + yield + # Run on shutdown + +app = FastAPI(lifespan=lifespan) +@app.post("/segment") +async def segment(request: Request, + background_tasks: BackgroundTasks, + t1c: UploadFile = File(...), + t1n: UploadFile = File(...), + t2f: UploadFile = File(...), + t2w: UploadFile = File(...)): + + try: + model = request.app.state.model + device = request.app.state.device + + with tempfile.NamedTemporaryFile(suffix=".nii", delete=False) as tmp: + tmp.write(await t1c.read()) + t1c_path = tmp.name + + with tempfile.NamedTemporaryFile(suffix=".nii", delete=False) as tmp: + tmp.write(await t1n.read()) + t1n_path = tmp.name + + with tempfile.NamedTemporaryFile(suffix=".nii", delete=False) as tmp: + tmp.write(await t2f.read()) + t2f_path = tmp.name + + with tempfile.NamedTemporaryFile(suffix=".nii", delete=False) as tmp: + tmp.write(await t2w.read()) + t2w_path = tmp.name + + affine = nib.load(t1c_path) + modality_arrays = {"t1c": affine.get_fdata(), + "t1n": nib.load(t1n_path).get_fdata(), + "t2f": nib.load(t2f_path).get_fdata(), + "t2w": nib.load(t2w_path).get_fdata() + } + + modality_arrays = preprocess_array(modality_arrays) + + image = np.stack((modality_arrays["t1c"], + modality_arrays["t1n"], + modality_arrays["t2f"], + modality_arrays["t2w"]), + axis=0 + ) + + tensors = {"image": image} + tensors = get_inference_transforms()(tensors) + image = tensors["image"] + + image = image.float().unsqueeze(0).to(device) + + output = sliding_window_inference(image, + roi_size=(128, 128, 128), + sw_batch_size=1, + predictor=model + ) + + # Convert to segmentation mask + seg_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy().astype(np.uint8) + + # Save as NIfTI using the affine from the input + seg_nifti = nib.Nifti1Image(seg_mask, affine.affine) + + # Save to a temp file + with tempfile.NamedTemporaryFile(suffix=".nii", delete=False) as tmp: + nib.save(seg_nifti, tmp.name) + output_path = tmp.name + + # Temp file cleanup + for path in [t1c_path, t1n_path, t2f_path, t2w_path]: + os.unlink(path) + + background_tasks.add_task(os.unlink, output_path) + + # Return the file + return FileResponse(output_path, media_type="application/octet-stream", filename="segmentation.nii") + except Exception as e: + raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}") + +@app.get("/health") +async def health(): + return {"status": "ok"} \ No newline at end of file diff --git a/src/preprocessing/preprocess.py b/src/preprocessing/preprocess.py index b9984e5..d219d07 100644 --- a/src/preprocessing/preprocess.py +++ b/src/preprocessing/preprocess.py @@ -5,7 +5,7 @@ def preprocess_case(case_dir: Path, output_dir: Path): """ - Preprocesses a case to prepare it for training. + Preprocesses a case file to prepare it for training. Args: case_dir (Path): Directory of the target case @@ -42,6 +42,28 @@ def preprocess_case(case_dir: Path, output_dir: Path): # Delete case_files dictionary to save memory del case_files + # Process arrays + modality_arrays = preprocess_array(modality_arrays) + + # Save the preprocessed modalities + case_output_dir = output_dir / case_name + case_output_dir.mkdir(parents=True, exist_ok=True) + + for modality in modality_arrays: + nib.save(nib.Nifti1Image(modality_arrays[modality], affine), case_output_dir / f"{case_name}-{modality}.nii.gz") + + return None + +def preprocess_array(modality_arrays): + """ + Preprocesses the dictionary containing target arrays. + + Args: + modality_arrays: Dictionary of each modality to be processed + + Returns: + modality_arrays: Processed modality_arrays dictionary + """ # Instantiate bias field corrector corrector = sitk.N4BiasFieldCorrectionImageFilter() @@ -83,11 +105,4 @@ def preprocess_case(case_dir: Path, output_dir: Path): for modality in modality_arrays: modality_arrays[modality] = modality_arrays[modality][x_min:x_max, y_min:y_max, z_min:z_max] - # Save the preprocessed modalities - case_output_dir = output_dir / case_name - case_output_dir.mkdir(parents=True, exist_ok=True) - - for modality in modality_arrays: - nib.save(nib.Nifti1Image(modality_arrays[modality], affine), case_output_dir / f"{case_name}-{modality}.nii.gz") - - return None \ No newline at end of file + return modality_arrays \ No newline at end of file diff --git a/src/training/transforms.py b/src/training/transforms.py index 2fe4ff0..15207aa 100644 --- a/src/training/transforms.py +++ b/src/training/transforms.py @@ -47,3 +47,17 @@ def get_val_transforms(): Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), SpatialPadd(keys=["image", "label"], spatial_size=(128, 128, 128)) ]) + + +def get_inference_transforms(): + """ + Apply deterministic transforms for inference — no augmentation. + + Returns: + Compose: Transform pipeline for inference + """ + return Compose([ + EnsureTyped(keys=["image"]), + Orientationd(keys=["image"], axcodes="RAS"), + Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear") + ]) \ No newline at end of file