# SimCLR PyTorch

This repository contains a PyTorch implementation of SimCLR, a self-supervised learning framework for visual representation learning with a contrastive loss.
- Features
- Project Structure
- Installation
- Configuration
- Usage
- Training Details
- Checkpoints & Logging
- Notes
## Features

- Modular SimCLR model with configurable ResNet encoder and MLP projection head
- Custom LARS optimizer implementation
- Data augmentation pipeline for CIFAR-10
- PyTorch Lightning DataModule for easy data handling
- Configurable via YAML file
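The key idea behind the LARS optimizer mentioned above is a layer-wise trust ratio that rescales the global learning rate by `||w|| / (||grad|| + weight_decay * ||w||)`. A minimal single-layer sketch in plain Python (illustrative only — not the implementation in `utils/lars.py`; parameter names and defaults here are assumptions):

```python
import math

def lars_step(w, grad, lr=1.0, eta=1e-3, weight_decay=1e-4):
    """One illustrative LARS update on a flat parameter list.

    The trust ratio scales the global lr so that each layer's step size
    is proportional to that layer's weight norm.
    """
    w_norm = math.sqrt(sum(x * x for x in w))
    # Fold weight decay into the gradient before measuring its norm.
    g = [gi + weight_decay * wi for gi, wi in zip(grad, w)]
    g_norm = math.sqrt(sum(x * x for x in g))
    trust = eta * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
    return [wi - lr * trust * gi for wi, gi in zip(w, g)]
```

With a zero gradient, the update reduces to pure layer-adaptive weight decay: each weight shrinks by the same relative factor regardless of the layer's scale, which is the property that makes LARS stable for very large batch sizes.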
## Project Structure

```
├── main.py            # CLI entry point
├── train.py           # Training and validation logic
├── preprocessing.py   # Data loading and augmentation
├── model/
│   ├── model.py       # SimCLR model definition
│   └── encoder.py     # ResNet encoder
├── utils/
│   └── lars.py        # LARS optimizer
├── config/
│   └── config.yml     # Main configuration file
├── data/              # CIFAR-10 dataset (auto-downloaded)
├── checkpoints/       # Model checkpoints
├── README.md
└── .gitignore
```
## Installation

- Clone the repository:

```bash
git clone https://github.com/tuvv3ct0r/simclr_pytorch.git
cd simclr_pytorch
```

- Create a virtual environment (optional but recommended) and install the required packages:

```bash
python3 -m venv venv
source venv/bin/activate
pip install torch torchvision pytorch-lightning
```
## Configuration

All settings are managed via `config/config.yml`. Key sections:

- `model`: Encoder and projection head architecture
- `training`: Optimizer, learning rate, batch size, epochs, etc.
- `augmentation`: Data augmentation parameters
- `dataset`: Dataset name and path
- `checkpoint`: Checkpoint directory and frequency
- `logging`: Logging directory and interval
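A hypothetical sketch of what `config/config.yml` might look like, covering the sections listed above (the actual keys and values in this repository may differ):

```yaml
# Illustrative config sketch -- field names are assumptions, not the repo's schema.
model:
  projection_dim: 128      # MLP projection head output size
training:
  optimizer: adam          # or "lars" for large-batch training
  lr: 0.001
  batch_size: 256
  epochs: 100
augmentation:
  crop_size: 32
  color_jitter_strength: 0.5
dataset:
  name: cifar10
  path: data/
checkpoint:
  dir: checkpoints/simclr/
  save_freq: 10
logging:
  dir: logs/
  interval: 50
```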
## Usage

Run the main CLI script:

```bash
python main.py --config config/config.yml --train   # Train the model
python main.py --config config/config.yml --eval    # Evaluate on the validation set
python main.py --config config/config.yml --test    # Test mode (WIP)
```

- The CIFAR-10 dataset is downloaded automatically to `data/`.
- Checkpoints are saved in `checkpoints/simclr/`.
## Training Details

- Loss: NT-Xent (Normalized Temperature-scaled Cross Entropy Loss)
- Optimizers: Adam (default), LARS (for large-batch training)
- Augmentations: Random crop, color jitter, grayscale, Gaussian blur, normalization
- Encoder: Custom ResNet (configurable channels)
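The NT-Xent loss above can be sketched in plain Python for clarity (the repository computes it with PyTorch tensors). For a batch of N images augmented twice, the 2N embeddings are arranged so that `z[i]` and `z[i + N]` are the two views of the same image; each embedding's positive partner is pulled together while the remaining 2N - 2 embeddings act as negatives:

```python
import math

def cosine(u, v):
    """Cosine similarity between two plain-list vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def nt_xent(z, temperature=0.5):
    """NT-Xent over 2N embeddings; z[i] and z[i + N] are a positive pair."""
    n2 = len(z)
    n = n2 // 2
    total = 0.0
    for i in range(n2):
        j = (i + n) % n2  # index of i's positive partner
        # Softmax denominator over all other embeddings (positive + negatives).
        denom = sum(math.exp(cosine(z[i], z[k]) / temperature)
                    for k in range(n2) if k != i)
        pos = math.exp(cosine(z[i], z[j]) / temperature)
        total += -math.log(pos / denom)
    return total / n2
```

Lowering the temperature sharpens the softmax, penalizing hard negatives more strongly; SimCLR typically uses values around 0.1–0.5.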
## Checkpoints & Logging

- Checkpoints are saved every `save_freq` epochs, and the best model is tracked by validation loss.
- Logging is printed to stdout; TensorBoard support is planned (see config).

## Notes

- Only CIFAR-10 is supported out of the box, but you can adapt the DataModule for other datasets.
- The `.gitignore` excludes `data/` and `__pycache__/` by default.
- For custom experiments, modify `config/config.yml`.
## Requirements

- Python 3.8+
- torch
- torchvision
- pytorch-lightning
For questions or contributions, please open an issue or pull request.