A text-guided 3D brain tumor segmentation framework built on a unified Mamba architecture. It achieves O(n) sequence complexity with ~60M trainable parameters (embed_dim=48, PubMedBERT last 2 layers unfrozen), leverages clinical diagnostic text to guide volumetric MRI segmentation, and is validated on the BraTS2020 TextBraTS dataset.

The architecture diagram below shows the core data flow. The following modules are implemented and annotated in the code:
- Modality Grouping: T1 + T1ce / T2 + FLAIR
- Per-stage main block: 3D DWConv → CrossScan BiMamba3D (multi-scale)
- Dual-path fusion: FiLM + Pixel-Text Cross-Attention
- Uncertainty Gating within CrossScan blocks to suppress high-uncertainty features
- Deep Supervision for multi-scale supervision in the decoder
ASCII Architecture Diagram (fallback)

```
3D MRI Volume                      Clinical Diagnostic Text
(4ch: T1, T1ce, T2, FLAIR)         "MRI shows left frontal lobe mass..."
       |                                      |
[Modality Grouping]                [PubMedBERT (partially frozen)]
(T1+T1ce / T2+FLAIR)                        ~109.5M params
       |                                      |
[Patch Embed 3D]                    [Projection + Mamba]
       |                                      |
+----------------------+            text_global + text_seq
|  Encoder Stage 1..4  |<----- FiLM -------------+  |
|  3D DWConv →         |                         |  |
|  CrossScan BiMamba3D |<-- Pixel-Text Cross-Attn+  |
|  (multi-scale)       |                            |
+----------+-----------+                            |
           |                                        |
[Causal Mamba Fusion + Uncertainty Gating] <--------+
           |
Decoder (symmetric + Skip + Deep Supervision)
           |
[Final Expand + Conv3D]
           v
Segmentation Output [B, 4, D, H, W]
(background / necrotic / edema / enhancing)
```
- Unified Mamba Architecture — Encoder, text encoder, fusion, and decoder are all built on State Space Models, with lightweight Cross-Attention for cross-modal alignment. O(n) overall sequence complexity.
- CrossScan BiMamba3D — Bidirectional scanning along 3 spatial axes (6 directions total), providing complete volumetric context and resolving the information-propagation blind spots of unidirectional SSMs.
- PubMedBERT Text Encoder — Partially frozen biomedical language model (~109.5M params), with the last 2 layers unfrozen for task adaptation and a lightweight Mamba adapter layer (~18K extra params).
- Multi-scale FiLM Fusion — Text modulates image features via Feature-wise Linear Modulation across all 4 encoder stages, not just the bottleneck.
- Causal Mamba Bottleneck Fusion — Text tokens are prepended to image tokens, leveraging Mamba's causal scan to naturally inject text context into image representations.
- Robust Training (v2) — Manual warmup + cosine LR schedule, gradient clipping (max_norm=1.0), NaN batch skip, bfloat16 AMP, class-weighted Dice + 3D Sobel edge loss + contrastive loss.
- Pixel-Text Cross-Attention — Fine-grained alignment between pixel tokens and text tokens for semantic guidance.
- Uncertainty Gating — Suppresses noisy region features within CrossScan BiMamba3D blocks.
- Deep Supervision — Multi-scale auxiliary losses at intermediate decoder stages.
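The FiLM modulation described above can be sketched in a few lines. This is a minimal illustration only; the module and dimension names are ours, not the project's actual code:

```python
import torch
import torch.nn as nn

class FiLM3D(nn.Module):
    """Feature-wise Linear Modulation: a global text vector predicts a
    per-channel scale (gamma) and shift (beta) for the image features."""
    def __init__(self, text_dim: int, img_channels: int):
        super().__init__()
        self.to_gamma_beta = nn.Linear(text_dim, 2 * img_channels)

    def forward(self, img_feat: torch.Tensor, text_global: torch.Tensor) -> torch.Tensor:
        # img_feat: [B, C, D, H, W], text_global: [B, text_dim]
        gamma, beta = self.to_gamma_beta(text_global).chunk(2, dim=-1)
        gamma = gamma[:, :, None, None, None]  # broadcast over D, H, W
        beta = beta[:, :, None, None, None]
        return gamma * img_feat + beta

x = torch.randn(2, 48, 8, 8, 8)   # stage-1 features, embed_dim=48
t = torch.randn(2, 256)           # global text embedding (dim assumed)
film = FiLM3D(text_dim=256, img_channels=48)
print(film(x, t).shape)  # torch.Size([2, 48, 8, 8, 8])
```

The output keeps the input shape, so the same pattern applies at every encoder stage.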
```bash
# Clone the repository
git clone https://github.com/PlutoLei/TextMamba3D.git
cd TextMamba3D

# Install dependencies
pip install -r requirements.txt

# Download PubMedBERT locally (~440MB)
python scripts/download_pubmedbert.py

# Quick test run (50 samples)
python train.py --config configs/textbrats.yaml --max-samples 50

# Evaluate
python evaluate.py --config configs/default.yaml --checkpoint checkpoints/best.pth
```

- Environment Setup
- Data Preparation
- Model Weights
- Training
- Evaluation
- Inference
- Architecture Details
- Preliminary Results
- Project Structure
- Known Limitations
- FAQ
- Citation
- Acknowledgments
- License
| Item | Requirement |
|---|---|
| Python | >= 3.8 |
| CUDA | >= 11.8 (mamba-ssm requires CUDA compilation) |
| GPU VRAM | >= 8GB (12GB+ recommended) |
| OS | Linux / WSL2 (Windows native does not support mamba-ssm) |
```bash
pip install -r requirements.txt
```

Full Setup (from scratch)
```bash
# Create virtual environment
python -m venv venv
source venv/bin/activate  # Linux / WSL2

# Install PyTorch (choose by CUDA version)
# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Install Mamba SSM (requires CUDA toolchain)
pip install mamba-ssm

# Install remaining dependencies
pip install monai nibabel numpy scipy transformers pyyaml tensorboard tqdm einops pytest

# Verify installation
python scripts/verify_installation.py
```

TextBraTS provides expert-written clinical descriptions for each case, avoiding the information leakage problem of auto-generated text from segmentation masks.
- Source: Kaggle BraTS2020 + HuggingFace TextBraTS
- 369 total cases, split into 220 train / 55 val / 94 test (following the original TextBraTS paper)
```
data/BraTS2020/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/
├── BraTS20_Training_001/
│   ├── BraTS20_Training_001_t1.nii(.gz)
│   ├── BraTS20_Training_001_t1ce.nii(.gz)
│   ├── BraTS20_Training_001_t2.nii(.gz)
│   ├── BraTS20_Training_001_flair.nii(.gz)
│   ├── BraTS20_Training_001_seg.nii(.gz)
│   └── BraTS20_Training_001_flair_text.txt   # Expert text
└── ...
```
Preprocessing is automatic: 4-modality stacking, Z-score normalization (non-zero voxels only), random/center cropping.
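The non-zero Z-score normalization step amounts to the following. This is a minimal sketch for reference; the project's actual `data/transforms.py` may differ in detail:

```python
import numpy as np

def zscore_nonzero(vol: np.ndarray) -> np.ndarray:
    """Z-score normalize using only non-zero (brain) voxels.

    Background voxels stay exactly 0, so padding and air do not
    distort the intensity statistics of the brain region.
    """
    mask = vol != 0
    if not mask.any():
        return vol.astype(np.float32)
    mean, std = vol[mask].mean(), vol[mask].std()
    out = np.zeros_like(vol, dtype=np.float32)
    out[mask] = (vol[mask] - mean) / (std + 1e-8)
    return out
```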
Alternative: BraTS 2021
Download from Synapse or Kaggle.
```bash
python train.py --config configs/default.yaml
```

The text encoder uses PubMedBERT (~440MB). Download it before training:

```bash
python scripts/download_pubmedbert.py
```

Weights are saved to `./pretrained/pubmedbert/`. Set `text_model_path: null` in the config to download from HuggingFace automatically if your environment has internet access.
| Parameter | Value | Notes |
|---|---|---|
| embed_dim | 48 | Reduced from 96 for better param/sample ratio |
| Learning rate | 5e-5 | Manual warmup (10 epochs) + cosine decay |
| Weight decay | 0.01 | Standard AdamW regularization |
| Epochs | 300 | With patience=50 early stopping |
| Gradient clipping | max_norm=1.0 | Prevents gradient explosion |
| AMP | bfloat16 | No GradScaler needed |
| Gradient accumulation | 4 | Effective batch size = 4 |
| Split | 220/55/94 | train/val/test (original paper split) |
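The v2 training recipe from the table (bfloat16 autocast, NaN-batch skip, gradient accumulation, gradient clipping) can be sketched as a minimal, device-agnostic loop. Names here are illustrative, not the project's `train.py`, and the autocast device would be `"cuda"` in practice:

```python
import torch
import torch.nn as nn

def train_cycle(model, batches, optimizer, accum_steps=4, max_norm=1.0):
    """One accumulation cycle: autocast to bfloat16, skip NaN batches,
    scale the loss by accum_steps, clip gradients, then step."""
    optimizer.zero_grad(set_to_none=True)
    for i, (x, y) in enumerate(batches):
        with torch.autocast("cpu", dtype=torch.bfloat16):  # "cuda" in practice
            loss = nn.functional.cross_entropy(model(x), y)
        if torch.isnan(loss):
            continue  # NaN batch skip, as in the v2 recipe
        (loss / accum_steps).backward()
        if (i + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

# Toy usage: a linear "model" stands in for TextMamba3D
model = nn.Linear(8, 4)
opt = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
batches = [(torch.randn(2, 8), torch.randint(0, 4, (2,))) for _ in range(4)]
train_cycle(model, batches, opt)
```

Note that bfloat16 autocast needs no GradScaler, which is why the table lists none.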
Standard training:

```bash
python train.py --config configs/textbrats.yaml
```

Quick test run (50 samples):

```bash
python train.py --config configs/textbrats.yaml --max-samples 50
```

Resume from a checkpoint:

```bash
python train.py --config configs/textbrats.yaml --resume checkpoints/last.pth
```

Monitor with TensorBoard:

```bash
tensorboard --logdir logs
```

Command-line Arguments
| Argument | Default | Description |
|---|---|---|
| `--config` | `configs/default.yaml` | Config file path |
| `--resume` | None | Resume from checkpoint |
| `--no-amp` | False | Disable mixed precision |
| `--no-text-ratio` | 0.1 | Fraction of text-free training samples |
| `--grad-accum` | 4 | Gradient accumulation steps |
| `--max-samples` | None | Limit training samples (for debugging) |
GPU VRAM Guide
| GPU VRAM | Recommended Config |
|---|---|
| 8 GB | batch_size=1, patch_size=[64,64,64], grad_accum=4, AMP on |
| 12 GB | batch_size=1, patch_size=[96,96,96], grad_accum=4, AMP on |
| 24 GB | batch_size=2, patch_size=[96,96,96], grad_accum=2, AMP on |
Gradient checkpointing is on by default (gradient_checkpointing: true), saving ~30-50% VRAM with ~20% speed overhead.
```bash
python evaluate.py --config configs/default.yaml --checkpoint checkpoints/best.pth
```

Reports standard BraTS three-region metrics:
| Region | Definition |
|---|---|
| ET (Enhancing Tumor) | Enhancing tumor (class 3) |
| TC (Tumor Core) | Necrotic + enhancing (class 1 + 3) |
| WT (Whole Tumor) | All tumor classes (class 1 + 2 + 3) |
Metrics: Dice Score (higher is better) and HD95 Hausdorff Distance (lower is better).
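As a concrete illustration, region-level Dice can be computed by merging classes into the three BraTS regions before overlap is measured. This is a minimal NumPy sketch, not the project's `utils/metrics.py`:

```python
import numpy as np

# BraTS class indices: 0 background, 1 necrotic, 2 edema, 3 enhancing
REGIONS = {"ET": {3}, "TC": {1, 3}, "WT": {1, 2, 3}}

def dice(pred: np.ndarray, gt: np.ndarray, labels: set) -> float:
    """Dice over a merged region: 2|P∩G| / (|P| + |G|)."""
    p = np.isin(pred, list(labels))
    g = np.isin(gt, list(labels))
    inter = np.logical_and(p, g).sum()
    denom = p.sum() + g.sum()
    return 2.0 * inter / denom if denom > 0 else 1.0

pred = np.array([0, 1, 2, 3, 3])
gt   = np.array([0, 1, 2, 3, 2])
print(round(dice(pred, gt, REGIONS["WT"]), 3))  # 1.0 — all tumor classes overlap
```

Note how WT Dice is perfect here even though the last voxel's class differs, while ET Dice for the same pair would be penalized; the three regions measure different granularities.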
Single case with test-time augmentation:

```bash
python inference.py \
    --checkpoint checkpoints/best.pth \
    --input /path/to/BraTS20_Training_001 \
    --tta
```

Batch inference over a directory:

```bash
python inference.py \
    --checkpoint checkpoints/best.pth \
    --input /path/to/cases_dir \
    --batch \
    --output ./predictions
```

Custom text prompt:

```bash
python inference.py \
    --checkpoint checkpoints/best.pth \
    --input /path/to/case \
    --text "MRI shows left frontal lobe mass with enhancing component"
```

Outputs NIfTI-format segmentation results with preserved affine matrices.
| Component | Module | Params | Key Design |
|---|---|---|---|
| Modality Grouping + Patch Embed | PatchEmbed3D | ~0.5M | T1+T1ce / T2+FLAIR physical grouping |
| Image Backbone | (3D DWConv → CrossScanBiMamba3D) ×4 | ~12M | Per-stage local conv + multi-scale 6-dir scan |
| Text Encoder | PubMedBERT + Projection + MambaLayer | ~110M (109.5M frozen) | Global/sequential dual text representations |
| Cross-modal Fusion | MultiScaleFiLM + PixelTextCrossAttention | ~2M | Stage-level modulation + pixel-text alignment |
| Uncertainty Gating | UncertaintyGating | ~0.3M | Suppress noisy region features |
| Decoder + Supervision | Decoder ×4 + Deep Supervision Heads | ~5M | Symmetric decoding + multi-scale auxiliary outputs |
| Total | | ~130M | With PubMedBERT |
| Trainable | | ~60M | PubMedBERT frozen except its last 2 layers |
Each encoder stage uses 3D DWConv → CrossScan BiMamba3D: local texture extraction followed by global sequence modeling. CrossScan provides complete 3D context via tri-axial bidirectional scanning:
- D-H-W (depth-first): forward + reverse
- H-W-D (height-first): forward + reverse
- W-D-H (width-first): forward + reverse
Outputs from 6 directions are aggregated through multi-scale branches.
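The tri-axial bidirectional flattening can be illustrated as follows. This sketches only the scan-order mechanics; the real CrossScanBiMamba3D additionally runs an SSM over each sequence and merges the six outputs through its multi-scale branches:

```python
import torch

def cross_scan_6dir(x: torch.Tensor):
    """Flatten a [B, C, D, H, W] volume into 6 1-D sequences:
    three axis orders (D-H-W, H-W-D, W-D-H), each forward and reversed."""
    B, C, D, H, W = x.shape
    orders = [
        x,                         # D-H-W (depth-first)
        x.permute(0, 1, 3, 4, 2),  # H-W-D (height-first)
        x.permute(0, 1, 4, 2, 3),  # W-D-H (width-first)
    ]
    seqs = []
    for v in orders:
        s = v.reshape(B, C, -1)    # [B, C, D*H*W]
        seqs.append(s)             # forward
        seqs.append(s.flip(-1))    # reverse
    return seqs

x = torch.randn(1, 4, 2, 3, 5)
seqs = cross_scan_6dir(x)
print(len(seqs), seqs[0].shape)  # 6 torch.Size([1, 4, 30])
```

Because every voxel appears in all six orderings, each position has at least one scan in which any other position precedes it, which is what removes the unidirectional blind spots.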
- Multi-scale FiLM: Text global features modulate image features across all 4 encoder stages via `output = gamma * features + beta`.
- Pixel-Text Cross-Attention: Fine-grained pixel-text alignment complementing FiLM's global modulation.
- Uncertainty Gating: Uncertainty-based gating within CrossScan blocks to reduce noisy feature aggregation.
- Causal Mamba Fusion: At the bottleneck, text tokens are prepended to image tokens, leveraging causal scanning to inject text context.
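The prepend-and-scan mechanics of the causal fusion can be sketched as below. The Conv1d is only a causal stand-in for the actual Mamba block: a left-trimmed convolution sees a few preceding positions, whereas a real Mamba scan carries text context in its recurrent state to all later image positions:

```python
import torch
import torch.nn as nn

class CausalFusionSketch(nn.Module):
    """Sketch of bottleneck fusion: prepend text tokens, run one causal
    sequence layer over the joint sequence, return only image positions."""
    def __init__(self, dim: int):
        super().__init__()
        # padding=2 then trimming the tail makes the conv strictly causal
        self.seq = nn.Conv1d(dim, dim, kernel_size=3, padding=2)

    def forward(self, text_tok: torch.Tensor, img_tok: torch.Tensor) -> torch.Tensor:
        # text_tok: [B, T, C], img_tok: [B, N, C]
        x = torch.cat([text_tok, img_tok], dim=1).transpose(1, 2)  # [B, C, T+N]
        y = self.seq(x)[..., : x.shape[-1]].transpose(1, 2)        # causal trim
        return y[:, text_tok.shape[1]:]                            # image part only

fuse = CausalFusionSketch(dim=8)
out = fuse(torch.randn(1, 5, 8), torch.randn(1, 10, 8))
print(out.shape)  # torch.Size([1, 10, 8])
```

The key property is that image tokens come after text tokens in scan order, so the causal model injects text context into them without any attention between the two streams.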
```
L_total = L_main + λ_ds * L_deep_supervision + L_edge + λ_c * L_contrastive
L_main  = L_dice + L_ce
```
- L_dice: Class-weighted Dice (weights: [0.25, 3.0, 1.0, 4.0])
- L_ce: Class-weighted cross-entropy
- L_deep_supervision: Multi-scale auxiliary output supervision
- L_edge: 3D Sobel edge-weighted penalty for boundary clarity
- L_contrastive: Bidirectional image-text alignment (active when batch_size > 1)
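The class-weighted Dice term, using the weights listed above, can be sketched as follows. This is illustrative only; the project's `losses/dice_loss.py` may normalize or smooth differently:

```python
import torch
import torch.nn.functional as F

# bg, necrotic, edema, enhancing — weights from the loss description
CLASS_WEIGHTS = torch.tensor([0.25, 3.0, 1.0, 4.0])

def weighted_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-5):
    """Class-weighted soft Dice. logits: [B, 4, D, H, W], target: [B, D, H, W]."""
    probs = logits.softmax(dim=1)
    onehot = F.one_hot(target, num_classes=4).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)                        # reduce over batch and space
    inter = (probs * onehot).sum(dims)
    denom = probs.sum(dims) + onehot.sum(dims)
    per_class_dice = (2 * inter + eps) / (denom + eps)
    w = CLASS_WEIGHTS / CLASS_WEIGHTS.sum()    # normalize so loss stays in [0, 1]
    return 1.0 - (w * per_class_dice).sum()

logits = torch.randn(2, 4, 4, 4, 4)
target = torch.randint(0, 4, (2, 4, 4, 4))
loss = weighted_dice_loss(logits, target)
```

Up-weighting necrotic (1) and enhancing (3) tissue counters their small voxel counts relative to background and edema.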
| Item | Setting |
|---|---|
| Dataset | BraTS2020 TextBraTS |
| Samples & Split | 369 cases: 220 train / 55 val / 94 test |
| Random Seed | 42 |
| Hardware | RTX 4060 Laptop 8GB (WSL2) |
| Input Patch | 64³ |
| Batch Size | 1 |
| Gradient Accumulation | 4 |
| Training Acceleration | bfloat16 AMP + Gradient Checkpointing |
| Status | Training in progress (v2, 300 epochs) |
Note: The model is currently training. Results below are preliminary and do not represent final performance.
Results will be updated upon training completion.
```
TextMamba3D/
├── configs/
│   ├── default.yaml                # BraTS2021 config (96³ patch)
│   └── textbrats.yaml              # TextBraTS config (64³ patch, v2 optimized)
├── data/
│   ├── brats_dataset.py            # BraTS2021 dataloader
│   ├── brats_textbrats_dataset.py  # TextBraTS dataloader (3-way split)
│   ├── text_generator.py           # Auto-generate diagnostic text from masks
│   └── transforms.py               # 3D augmentation (crop/flip/elastic/noise)
├── models/
│   ├── mamba_block.py              # MambaBlock, BiMamba, CrossScanBiMamba3D
│   ├── encoder_3d.py               # PatchEmbed3D, PatchMerging3D, MambaEncoder3D
│   ├── text_encoder.py             # PubMedBERT + Mamba text encoder
│   ├── fusion.py                   # FiLM, MultiScaleFiLM, MambaFusion
│   ├── decoder_3d.py               # PatchExpanding3D, MambaDecoder3D
│   └── textmamba3d.py              # TextMamba3D main model
├── losses/
│   ├── dice_loss.py                # Class-weighted Dice loss
│   ├── edge_loss.py                # 3D Sobel edge loss
│   ├── contrastive_loss.py         # Bidirectional contrastive loss
│   └── __init__.py                 # CombinedLoss
├── utils/
│   ├── metrics.py                  # Dice, HD95 (BraTS region metrics)
│   ├── tta.py                      # Test-time augmentation
│   └── sliding_window.py           # Gaussian-weighted sliding window
├── tests/                          # 214 tests across 7 test files
├── scripts/
│   ├── download_pubmedbert.py      # PubMedBERT download script
│   ├── prepare_brats.py            # Data preprocessing script
│   └── verify_installation.py      # Installation verification
├── train.py                        # Training (bfloat16 AMP + grad clipping + manual LR)
├── evaluate.py                     # Evaluation (BraTS region metrics + TTA)
├── inference.py                    # Inference (sliding window + TTA + custom text)
├── smoke_test.py                   # End-to-end smoke test
├── requirements.txt
├── LICENSE
└── README.md
```
- Currently validated only on BraTS2020 TextBraTS; cross-dataset generalization tests pending.
- Training set limited to 220 cases.
- Text input depends on manually written clinical descriptions; text quality and writing style affect guidance effectiveness.
- mamba-ssm requires Linux/WSL2; Windows native is not supported.
WSL2 cannot connect to HuggingFace / network issues
WSL2 networking may block HuggingFace downloads. Workaround:
- Run the download script on Windows (where networking works):

  ```bash
  cd TextMamba3D
  python scripts/download_pubmedbert.py
  ```

- WSL2 then reads the local files via `/mnt/e/...` (no internet needed).
To fix WSL2 networking directly:
```bash
cat /etc/resolv.conf
sudo sh -c 'echo "nameserver 8.8.8.8" > /etc/resolv.conf'
```

mamba-ssm installation fails
mamba-ssm requires CUDA toolchain. Verify:
```bash
nvcc --version  # Need CUDA 11.8+
```

Try building from source:

```bash
pip install mamba-ssm --no-build-isolation
```

Note: Windows native does not support mamba-ssm. Use Linux or WSL2.
Out of memory (OOM)
- Reduce `patch_size`: 96 → 64 (in the config YAML)
- Ensure `gradient_checkpointing: true` (on by default)
- Increase accumulation: `--grad-accum 8`
- Ensure AMP is enabled (on by default; `--no-amp` disables it)
ImportError: GreedySearchDecoderOnlyOutput
Caused by a transformers version that is too new. Downgrade:

```bash
pip install transformers==4.38.0
```

Training hangs at start
Mamba CUDA kernels are JIT-compiled on first run, typically taking 1-2 minutes. Training resumes automatically.
If this project helps your research, please cite:
```bibtex
@article{textmamba3d2026,
  title={TextMamba3D: Text-Guided 3D Medical Image Segmentation with Unified State Space Models},
  author={Lei, Yuxuan},
  journal={arXiv preprint},
  year={2026}
}
```

This project builds on the following works:
- Mamba — State Space Model architecture
- U-Mamba — Mamba for medical image segmentation
- PubMedBERT — Biomedical language model
- BraTS 2020 — Brain Tumor Segmentation Challenge
- TextBraTS — Text-guided brain tumor segmentation (MICCAI 2024)
- MONAI — Medical Open Network for AI
This project is licensed under the Apache License 2.0.
