A deep learning image captioning system that generates natural language descriptions from images using a custom Transformer decoder, trained on the MS COCO 2017 dataset. The project explores and compares two visual encoder architectures — ResNet-50 (CNN) and ViT-B/16 (Vision Transformer) — to study how different image representations affect caption quality.
Try it now: huggingface.co/spaces/mostafahagali/vit-image-captioning
- Encoder comparison study — trained and evaluated with both ResNet-50 (CNN) and ViT-B/16 to compare how convolutional vs. attention-based visual features impact captioning performance
- Custom Transformer blocks — hand-built multi-head attention encoder and decoder layers (not an `nn.Transformer` wrapper)
- Beam search decoding with length normalization (greedy decoding also available)
- Backbone fine-tuning with learning rate scheduling and warm-up frozen epochs
- Standard NLP metrics evaluation via pycocoevalcap (BLEU, METEOR, ROUGE-L, CIDEr, SPICE)
- Gradio web app for interactive captioning
A core goal of this project was to experimentally compare how different visual backbones affect image captioning:
| | CNN (ResNet-50) | ViT (ViT-B/16) |
|---|---|---|
| How it sees | Local features via sliding convolutions → 7×7 = 49 patches | Global context via self-attention over 14×14 = 196 patches |
| Strengths | Strong local feature extraction, fewer patches = faster | Captures long-range spatial relationships, richer representations |
| Caption quality | Good baseline — solid BLEU/CIDEr scores | Better at describing complex scenes and spatial relationships |
| CIDEr Score | 0.924 | 1.073 (+14.9 points) |
Both encoders share the exact same Transformer encoder + decoder pipeline, making this a controlled comparison of visual representations.
```
                     Image Captain Pipeline

  Input Image
(3 × 224 × 224)
       │
       │  CNN Encoder
       │  ┌───────────────┐       ┌─────────────────────┐
       ├─►│   ResNet-50   │──────►│ Transformer Encoder │
       │  │ (49 patches)  │       │ (6 layers, 8 heads) │
       │  └───────────────┘       │                     │
       │        OR                │   Self-Attention    │
       │  ┌───────────────┐       │   + Feed-Forward    │
       └─►│   ViT-B/16    │──────►│   + LayerNorm       │
          │ (196 patches) │       └──────────┬──────────┘
          └───────────────┘                  │
          ViT Encoder           Encoded Image Features
                                             │
                                ┌────────────▼──────────┐
 <bos> token ──────────────────►│ Transformer Decoder   │
                                │ (6 layers, 8 heads)   │
                                │                       │
                                │ Masked Self-Attention │
                                │ Cross-Attention       │
                                │ Feed-Forward          │
                                └────────────┬──────────┘
                                             │
                                 Beam Search / Greedy
                                       Decoding
                                             │
                                             ▼
                                  "a cat sitting on
                                   a wooden bench"
```
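The decoder's masked self-attention relies on a causal mask so each position can attend only to earlier tokens. A minimal sketch of how such a mask is typically built and applied (illustrative, not the project's exact code):

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True entries above the diagonal mark future positions to be blocked
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

scores = torch.randn(1, 8, 5, 5)           # (batch, heads, query, key) attention logits
mask = causal_mask(5)
scores = scores.masked_fill(mask, float("-inf"))
attn = scores.softmax(dim=-1)              # each row attends only to positions <= itself

print(attn[0, 0])                          # row i has zero weight after column i
```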
| Metric | CNN (ResNet-50) | ViT (ViT-B/16) | Δ (points) |
|---|---|---|---|
| BLEU-1 | 0.705 | 0.738 | +3.3 |
| BLEU-2 | 0.532 | 0.575 | +4.3 |
| BLEU-3 | 0.388 | 0.433 | +4.5 |
| BLEU-4 | 0.282 | 0.324 | +4.2 |
| METEOR | 0.243 | 0.266 | +2.3 |
| ROUGE-L | 0.523 | 0.552 | +2.9 |
| CIDEr | 0.924 | 1.073 | +14.9 |
| SPICE | 0.178 | 0.201 | +2.3 |
Key findings:
- ViT-B/16 outperforms ResNet-50 across all 8 metrics
- The largest gap is in CIDEr (+14.9 points), indicating ViT generates captions that better match human descriptions
- BLEU improvements (+3–5 points) confirm ViT produces more accurate n-gram overlaps
- The consistent improvement across all metrics suggests that attention-based visual features provide richer representations for captioning than convolutional features
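One note on reading the deltas: they are absolute score differences (score × 100), not relative percentage improvements. For CIDEr, the relative gain is in fact larger than the absolute one:

```python
cnn_cider, vit_cider = 0.924, 1.073  # scores from the results table

absolute_gain = (vit_cider - cnn_cider) * 100      # points, as reported above
relative_gain = (vit_cider / cnn_cider - 1) * 100  # percent improvement

print(f"+{absolute_gain:.1f} points, +{relative_gain:.1f}% relative")
# prints "+14.9 points, +16.1% relative"
```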
```
Image-Captain/
├── app.py                             # Gradio web interface
├── inference.py                       # CaptionGenerator class (greedy + beam search)
├── requirements.txt
├── configs/
│   ├── config.yaml                    # Paths, hyperparameters, encoder selection
│   └── logger.py                      # Rotating file + console logger
├── models/
│   ├── cnn_encoder.py                 # ResNet-50 backbone (frozen/unfrozen)
│   ├── vit_encoder.py                 # ViT-B/16 backbone (frozen/unfrozen)
│   ├── transformer_encoder.py         # Custom multi-head self-attention encoder
│   ├── transformer_decoder.py         # Custom cross-attention decoder + positional encoding
│   └── image_captioning_model.py      # Full model: encoder → transformer → decoder
├── src/
│   ├── data_processing_vocabulary.py  # Vocabulary builder, tokenizer, COCO dataset
│   ├── evaluate_metrics.py            # Validation loss + BLEU/CIDEr/METEOR evaluation
│   └── train.py                       # Training loop with fine-tuning & checkpointing
└── images/
    └── HuggingFace_live_demo.png
```
- Python 3.10+
- CUDA-capable GPU (recommended for training)
```shell
git clone https://github.com/yourusername/Image-Captain.git
cd Image-Captain
pip install -r requirements.txt
```

The pre-trained ViT model weights are hosted on Kaggle:
Option 1 — KaggleHub (Python):
```shell
pip install kagglehub
```

```python
import kagglehub

path = kagglehub.model_download(
    "mustafamohamed22/vit-image-captioning-coco/pyTorch/default/1"
)
print("Path to model files:", path)
```

Option 2 — Manual download:
- Visit mustafamohamed22/vit-image-captioning-coco
- Click Download and extract the model file
- Place `best_model.pth` in the project root
Edit `configs/config.yaml` to set your dataset paths and training parameters:
```yaml
paths:
  train_img_dir: /path/to/coco2017/train2017
  train_ann_file: /path/to/coco2017/annotations/captions_train2017.json
  val_img_dir: /path/to/coco2017/val2017
  val_ann_file: /path/to/coco2017/annotations/captions_val2017.json

training:
  batch_size: 32
  d_model: 512
  n_head: 8
  num_layers: 6
  learning_rate: 0.0001
  num_epochs: 55
  use_vit: true  # true = ViT-B/16, false = ResNet-50
```

Start training with:

```shell
python src/train.py
```

Training features:
- Automatic backbone unfreezing at epoch 6 with reduced learning rate
- `ReduceLROnPlateau` scheduler (factor=0.5, patience=2)
- Gradient clipping (max_norm=1.0)
- Checkpoint saving every 5 epochs (+ every epoch after epoch 40)
- BLEU/CIDEr evaluation during training
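The two-stage freeze/unfreeze schedule can be sketched as follows. This is a simplified illustration using toy stand-in modules — the project's actual attribute names and training loop may differ:

```python
import torch
import torch.nn as nn

class ToyCaptioner(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 10)  # stand-in for ResNet-50 / ViT-B/16
        self.head = nn.Linear(10, 10)      # stand-in for the Transformer head

model = ToyCaptioner()

# Stage 1 (epochs 1-5): freeze the backbone, train only the head at lr=1e-4
for p in model.backbone.parameters():
    p.requires_grad = False
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)

# Stage 2 (epoch 6+): unfreeze and fine-tune end-to-end, backbone at 10x lower lr
for p in model.backbone.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW([
    {"params": model.backbone.parameters(), "lr": 1e-5},
    {"params": model.head.parameters(), "lr": 1e-4},
])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=2
)
```

Separate parameter groups let the scheduler halve both learning rates together on a validation-loss plateau while preserving the 10x ratio between backbone and head.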
```python
from inference import CaptionGenerator

generator = CaptionGenerator(
    model_path="best_model.pth",
    vocab_path="vocab.pkl",
    use_vit=True
)

# Beam search (default, higher quality)
caption = generator.generate("photo.jpg")

# Greedy decoding (faster)
caption = generator.generate("photo.jpg", decoding="greedy")

# Custom beam width
caption = generator.generate("photo.jpg", beam_width=10)
```

Launch the web demo with:

```shell
python app.py
```

This opens an interactive web UI where you can upload images and receive AI-generated captions.
| Parameter | Value |
|---|---|
| Dataset | MS COCO 2017 (~118K train / ~5K val images) |
| Embedding Dim | 512 |
| Attention Heads | 8 |
| Encoder Layers | 6 |
| Decoder Layers | 6 |
| Feed-Forward Dim | 2048 |
| Dropout | 0.1 |
| Optimizer | AdamW |
| Learning Rate | 1e-4 (1e-5 after unfreeze) |
| Label Smoothing | 0.01 |
| Vocab Threshold | 5 (min word frequency) |
| Max Caption Length | 50 tokens |
| Beam Width | 5 |
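The loss configuration in the table maps directly onto PyTorch's built-in cross entropy. A sketch (`PAD_IDX` is a hypothetical padding-token index; the project's vocabulary may assign it differently):

```python
import torch
import torch.nn as nn

PAD_IDX = 0  # hypothetical padding token index, excluded from the loss

criterion = nn.CrossEntropyLoss(label_smoothing=0.01, ignore_index=PAD_IDX)

# logits: (batch * seq_len, vocab_size), targets: (batch * seq_len,)
logits = torch.randn(32 * 50, 1000)
targets = torch.randint(1, 1000, (32 * 50,))
loss = criterion(logits, targets)
print(loss.item())
```

The mild smoothing (0.01) spreads a small amount of probability mass over the vocabulary, discouraging overconfident single-token predictions without noticeably blurring the target distribution.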
- Custom Transformer layers instead of PyTorch's `nn.Transformer` — demonstrates hands-on understanding of multi-head attention, causal masking, and cross-attention mechanisms.
- Two-stage training — the backbone is frozen for the first 5 epochs to train the Transformer head, then unfrozen with a 10x lower learning rate for end-to-end fine-tuning.
- Beam search with length normalization — uses log-probability scoring normalized by `length^α` (α=0.7) to avoid favoring short captions.
- One sample per image strategy — each `__getitem__` picks a random caption from the 5 available, providing natural data augmentation across epochs.
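The length-normalization rule can be shown in isolation. A minimal scoring function (illustrative only, not the project's exact implementation) demonstrates why raw log-probability sums favor short hypotheses and how dividing by `length^α` corrects that:

```python
import math

def normalized_score(log_probs, alpha: float = 0.7) -> float:
    """Beam hypothesis score: total log-probability divided by length**alpha."""
    return sum(log_probs) / (len(log_probs) ** alpha)

# A short, mediocre hypothesis vs. a longer, more confident one
short = [math.log(0.5)] * 4   # 4 tokens at p=0.5 each
long = [math.log(0.7)] * 8    # 8 tokens at p=0.7 each

print(sum(short), sum(long))                             # raw sums: short wins
print(normalized_score(short), normalized_score(long))   # normalized: long wins
```

With α=0 this degenerates to the raw sum (strong short-caption bias); with α=1 it becomes a pure per-token average; α=0.7 sits between the two.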
- PyTorch — model, training, inference
- torchvision — ResNet-50, ViT-B/16, image transforms
- pycocoevalcap — BLEU, METEOR, ROUGE-L, CIDEr, SPICE evaluation
- Gradio — interactive web demo
- PyYAML — configuration management
- Attention visualization — overlay heatmaps showing which image regions the model attends to for each word
- Nucleus (top-p) sampling — add stochastic decoding for more diverse and creative captions
- Multi-language captioning — extend the decoder to generate captions in Arabic, French, and other languages
- Larger ViT backbones — experiment with ViT-L/16 and ViT-H/14 for improved visual representations
- CLIP-guided re-ranking — use CLIP scores to re-rank beam search candidates for better image-text alignment
- ONNX / TorchScript export — optimize inference for production deployment
- Comprehensive test suite — add unit and integration tests for data pipeline, model, and evaluation
This project is licensed under the MIT License.
