A comprehensive framework for Learning-from-Teaching (LoT) regularization research, combining teacher-student learning with emergent communication. This project explores how bidirectional knowledge transfer between teacher and student models can improve generalization across multiple domains: Natural Language Processing (NLP), Computer Vision (CV), and Reinforcement Learning (RL).
Paper: Learning from Teaching Regularization: Generalizable Correlations Should be Easy to Imitate
Authors: Can Jin, Tong Che, Hongwu Peng, Yiyuan Li, Marco Pavone
- Overview
- Project Structure
- Architecture & Flow Diagrams
- Research Directions
- Installation
- Quick Start
- Reproducibility Guide
- Datasets
- Results
- Citation
- License
Generalization remains a central challenge in machine learning. In this work, we propose Learning from Teaching (LoT), a novel regularization technique for deep neural networks to enhance generalization. Inspired by the human ability to capture concise and abstract patterns, we hypothesize that generalizable correlations are expected to be easier to teach. LoT operationalizes this concept to improve the generalization of the main model with auxiliary student learners.
Traditional knowledge distillation is unidirectional: teacher → student. LoT introduces bidirectional learning:
- Forward KL: Student learns to imitate the teacher's predictions
- Reverse KL: Teacher learns to be more "teachable"
- Emergent Language: Teacher encodes knowledge into symbolic messages that the student decodes
- Better Generalization: Both teacher and student achieve lower perplexity/error rates
- Flexible Architecture: Works across domains (NLP, CV, RL)
- Interpretable Communication: Emergent messages reveal what the teacher "teaches"
- Curriculum Learning: Gradual integration of messages for stable training
```
LoT/
├── model/                          # Model architectures
│   ├── rnn.py                      # LSTM/GRU language models
│   ├── mem_transformer.py          # Transformer-XL implementation
│   ├── preresnet.py                # ResNet for image classification
│   ├── message_channel.py          # Emergent language communication (Phase 4)
│   └── emergent_language_models.py # Teacher/Student with message encoding/decoding
│
├── trainer/                        # Training scripts for different experiments
│   ├── Lot_Emergent.py             # LoT + Emergent Language
│   ├── feedback_variants.py        # Positive/Negative/Mixed feedback
│   ├── size_ratios.py              # Teacher-Student size ratios
│   ├── alternative_metrics.py      # KL/JS/L2/Cosine teaching metrics
│   └── image_classification.py     # CV: Image classification with LoT
│
├── analysis/                       # Analysis and visualization scripts
│   ├── analyze_phase4_full.py      # Emergent language analysis
│   ├── analyze_phase1.py           # Feedback variants analysis
│   ├── analyze_phase2.py           # Size ratio analysis
│   └── visualize_results.py        # General result visualization
│
├── utils/                          # Utility functions
│   ├── data_utils.py               # Data loading utilities
│   ├── corpus.py                   # NLP corpus handling (PTB, WikiText)
│   ├── vocabulary.py               # Vocabulary management
│   └── exp_utils.py                # Experiment utilities
│
├── run/                            # Shell scripts for experiments
│   ├── Lot_Emergent.sh             # Run emergent language experiments
│   ├── feedback_variants.sh        # Run feedback variant experiments
│   ├── size_ratios.sh              # Run size ratio experiments
│   └── getdata.sh                  # Download datasets
│
├── data/                           # Datasets (auto-downloaded)
│   ├── ptb/                        # Penn TreeBank
│   ├── wikitext-103/               # WikiText-103
│   └── cifar-10-batches-py/        # CIFAR-10
│
├── ckpt/                           # Model checkpoints
│   ├── Phase1/                     # Feedback variant checkpoints
│   ├── Phase2/                     # Size ratio checkpoints
│   ├── Phase3/                     # Alternative metric checkpoints
│   └── Phase4/                     # Emergent language checkpoints
│
├── logs/                           # Training logs
└── result/                         # Analysis results and visualizations
```
```
┌──────────────────────────────────────────────────────────────────────┐
│                Learning-from-Teaching (LoT) Framework                │
└──────────────────────────────────────────────────────────────────────┘

Input (x) ───┬───► Teacher Model ───► Teacher Prediction (ŷ_t)
             │                                  │
             │                                  ▼
             │                        ┌─────────────────────┐
             │                        │   LoT Loss Engine   │
             │                        │   ───────────────   │
             └───► Student Model ────►│  • CE Loss (both)   │
                        │             │  • KL(ŷ_s || ŷ_t)   │
                        │             │  • KL(ŷ_t || ŷ_s)   │
                        ▼             └─────────────────────┘
              Student Prediction                │
                     (ŷ_s)                      │
                        │                       │
                        └───────────┬───────────┘
                                    ▼
                            Ground Truth (y)
```
Key Components:
- Cross-Entropy (CE) Loss: Both models learn from ground truth
- Forward KL, KL(ŷ_s || ŷ_t): Student imitates the teacher
- Reverse KL, KL(ŷ_t || ŷ_s): Teacher learns to be teachable
- Regularization Weight α: Controls LoT strength
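Taken together, these components can be sketched as a single training objective. A minimal PyTorch sketch, where `lot_loss` and the exact detach placement are assumptions of this illustration (the repo's trainers may differ):

```python
import torch
import torch.nn.functional as F

def kl(p_log, q_log):
    """KL(p || q) from log-probabilities, averaged over the batch."""
    return (p_log.exp() * (p_log - q_log)).sum(-1).mean()

def lot_loss(teacher_logits, student_logits, targets, alpha=1.0):
    """CE for both models plus bidirectional KL, weighted by alpha."""
    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    log_p_s = F.log_softmax(student_logits, dim=-1)

    # Both models learn from the ground truth.
    ce = F.cross_entropy(teacher_logits, targets) \
       + F.cross_entropy(student_logits, targets)

    # Forward KL: only the student receives gradient (teacher detached).
    forward_kl = kl(log_p_s, log_p_t.detach())
    # Reverse KL: only the teacher receives gradient (student detached),
    # rewarding the teacher for being easy to imitate.
    reverse_kl = kl(log_p_t, log_p_s.detach())

    return ce + alpha * (forward_kl + reverse_kl)
```

When teacher and student agree exactly, both KL terms vanish and the loss reduces to the two cross-entropy terms.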
```
┌───────────────────────────────────────────────────────────────────────┐
│                    LoT with Emergent Communication                    │
└───────────────────────────────────────────────────────────────────────┘

Input (x) ───┬───► Teacher Encoder ───► Message (m) ───┐
             │            │                            │
             │            ▼                            ▼
             │     Teacher Predictor            Student Decoder ───► ŷ_s
             │            │                            │
             │            ▼                            │
             │  Teacher Prediction (ŷ_t)               │
             │            │                            │
             │            └───────────┬────────────────┘
             │                        │
             └────────────────────────┤
                                      ▼
                         ┌─────────────────────┐
                         │    LoT + Language   │
                         │   Loss Computation  │
                         │   ───────────────   │
                         │  • CE Loss (both)   │
                         │  • KL(ŷ_s || ŷ_t)   │
                         │  • KL(ŷ_t || ŷ_s)   │
                         │  • Entropy Reg      │
                         │  • Compositionality │
                         └─────────────────────┘
```
Message Channel Details:

```
┌──────────────────────────────────────────────────────────────────┐
│                    Message Channel (Discrete)                    │
└──────────────────────────────────────────────────────────────────┘

Teacher Hidden State (h_t)
            │
            ▼
    Feature Projection
            │
            ▼
     Message Encoder
            │
            ▼
Logits: [batch, msg_length, vocab_size]
            │
            ▼
  Gumbel-Softmax Sampling
            │
            ▼
Message: [batch, msg_length, vocab_size]
            │
            ├──► (Hard symbols for analysis)
            │
            └──► Student Message Decoder
                          │
                          ▼
             Fusion with Student Hidden
                          │
                          ▼
            Enhanced Student Prediction
```
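The encoder half of this pipeline can be sketched as a small PyTorch module. This is an illustrative sketch, not the repo's `message_channel.py`; the defaults mirror the PTB configuration (650 hidden units, 8 symbols, vocabulary 32) and are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MessageChannel(nn.Module):
    """Discrete message channel sketch: hidden state -> per-slot logits
    -> Gumbel-Softmax symbols."""

    def __init__(self, hidden_dim=650, msg_length=8, vocab_size=32):
        super().__init__()
        self.msg_length = msg_length
        self.vocab_size = vocab_size
        # Feature projection + message encoder collapsed into one linear
        # layer producing one logit vector per message slot.
        self.encoder = nn.Linear(hidden_dim, msg_length * vocab_size)

    def forward(self, teacher_hidden, temperature=1.0, hard=False):
        logits = self.encoder(teacher_hidden)
        logits = logits.view(-1, self.msg_length, self.vocab_size)
        # Gumbel-Softmax keeps the channel differentiable; hard=True uses
        # the straight-through estimator to emit exact one-hot symbols
        # (the "hard symbols for analysis" branch above).
        return F.gumbel_softmax(logits, tau=temperature, hard=hard)
```

Lowering `temperature` over training (the 1.0 → 0.5 annealing mentioned below) sharpens the samples toward discrete symbols.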
Training Flow with Curriculum Learning:

```
Epoch 0 ─────────────────────────────────────────────► Epoch 60
Message    0% ──────────► 10% ──────► 50% ────► 100%
Weight   (Pure LoT)     (Warmup)    (Ramp)    (Full)

Phase:  [────── LoT Foundation ──────][── Message Integration ──]
               Epoch 0-30              Epoch 30-50   Epoch 50-60
```
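A schedule like the timeline above can be expressed as a small piecewise function; the breakpoints and percentages here are taken from the diagram and are illustrative, not the repo's exact implementation:

```python
def message_weight(epoch, foundation_end=30, ramp_end=50, total=60):
    """Curriculum weight for the message channel:
    0% during the LoT foundation phase, 10% -> 50% during
    warmup/ramp, then 50% -> 100% in the final epochs."""
    if epoch < foundation_end:
        return 0.0  # pure LoT, no messages
    if epoch < ramp_end:
        frac = (epoch - foundation_end) / (ramp_end - foundation_end)
        return 0.1 + 0.4 * frac  # warmup/ramp: 10% -> 50%
    frac = (epoch - ramp_end) / (total - ramp_end)
    return min(1.0, 0.5 + 0.5 * frac)  # full integration: 50% -> 100%
```

Multiplying the message-dependent loss terms by this weight keeps early training identical to plain LoT, which stabilizes the later message integration.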
Goal: Compare positive, negative, and mixed teaching feedback
| Variant | Student → Teacher | Teacher → Student | Use Case |
|---|---|---|---|
| Positive | ✓ KL(ŷ_t \|\| ŷ_s) | ✓ KL(ŷ_s \|\| ŷ_t) | Bidirectional learning |
| Negative | ✗ None | ✓ KL(ŷ_s \|\| ŷ_t) | Traditional distillation |
| Mixed | ✓ Half strength | ✓ Full strength | Asymmetric feedback |
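The variants can be mapped onto loss terms with a simple selector; this is an illustrative sketch (`feedback_losses` and the 0.5 mixing factor are assumptions, not the repo's exact code):

```python
def feedback_losses(kl_s_to_t, kl_t_to_s, variant="positive"):
    """Combine the two KL terms according to the feedback variant.

    kl_s_to_t: KL(ŷ_s || ŷ_t), trained into the student.
    kl_t_to_s: KL(ŷ_t || ŷ_s), trained into the teacher.
    """
    if variant == "positive":  # bidirectional: both terms active
        return kl_s_to_t + kl_t_to_s
    if variant == "negative":  # traditional distillation: student only
        return kl_s_to_t
    if variant == "mixed":     # asymmetric: teacher term at half strength
        return kl_s_to_t + 0.5 * kl_t_to_s
    raise ValueError(f"unknown variant: {variant}")
```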
Key Finding: Positive (bidirectional) feedback achieves the best generalization
Goal: Study impact of model capacity ratios
| Configuration | Teacher Size | Student Size | Capacity Ratio |
|---|---|---|---|
| Small-Small | 650 hidden | 650 hidden | 1:1 |
| Medium-Small | 1300 hidden | 650 hidden | 2:1 |
| Large-Small | 2600 hidden | 650 hidden | 4:1 |
Key Finding: Larger teachers provide richer teaching signals, but require careful tuning
Goal: Explore different divergence measures for teaching
| Metric | Formula | Properties |
|---|---|---|
| KL Divergence | KL(P \|\| Q) | Asymmetric, mode-seeking |
| JS Divergence | ½KL(P\|\|M) + ½KL(Q\|\|M) | Symmetric, bounded |
| L2 Distance | \|\|P − Q\|\|² | Euclidean, simple |
| Cosine | 1 − cos(P, Q) | Angle-based, scale-invariant |
Key Finding: KL divergence remains most effective for probability distributions
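All four metrics operate directly on probability vectors; a minimal sketch (the repo's `alternative_metrics.py` may differ, and the `eps` smoothing is an assumption for numerical safety):

```python
import torch

def kl_div(p, q, eps=1e-12):
    """KL(P || Q) over the last dimension, for probability tensors."""
    return (p * ((p + eps) / (q + eps)).log()).sum(-1)

def js_div(p, q):
    """Jensen-Shannon divergence: symmetric, bounded by ln 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)

def l2_dist(p, q):
    """Squared Euclidean distance ||P - Q||^2."""
    return ((p - q) ** 2).sum(-1)

def cosine_dist(p, q):
    """1 - cos(P, Q): angle-based, scale-invariant."""
    return 1 - torch.nn.functional.cosine_similarity(p, q, dim=-1)
```

For identical distributions all four return 0; for disjoint one-hot distributions the JS divergence hits its ln 2 bound.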
Goal: Enable structured knowledge transfer through emergent communication
Components:
1. Message Channel: Discrete symbolic communication (Gumbel-Softmax)
   - Vocabulary size: 32-100 symbols
   - Message length: 8-20 tokens
   - Temperature annealing: 1.0 → 0.5
2. Language Regularization:
   - Entropy: Encourage vocabulary diversity
   - Compositionality: Promote structured messages
   - Curriculum Learning: Gradual message integration
3. Evaluation Metrics:
   - Perplexity (PPL): Lower is better
   - Vocabulary coverage: % of symbols used
   - Positional entropy: Measures compositionality
   - Topographic similarity: Input-message correlation
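Two of these metrics are straightforward to compute from sampled symbol sequences; an illustrative sketch (function names are assumptions, not the repo's API):

```python
import math
from collections import Counter

def positional_entropy(messages):
    """Shannon entropy (bits) of the symbol distribution at each
    position, over a list of equal-length symbol-id sequences.
    Higher entropy at a position means it carries more information."""
    entropies = []
    for pos in range(len(messages[0])):
        counts = Counter(msg[pos] for msg in messages)
        total = len(messages)
        h = -sum((c / total) * math.log2(c / total)
                 for c in counts.values())
        entropies.append(h)
    return entropies

def vocabulary_coverage(messages, vocab_size):
    """Fraction of the symbol vocabulary that appears at least once."""
    used = {sym for msg in messages for sym in msg}
    return len(used) / vocab_size
```

Position-dependent entropy profiles (some slots diverse, others fixed) are one signal of compositional structure in the emergent language.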
Expected Results:
- Teacher PPL: ~110-130 (similar to baseline)
- Student PPL: ~120-140 (5-15 points better than baseline ~145)
- Message benefit: Structured knowledge transfer
- Python 3.8 or higher
- CUDA-capable GPU (recommended for training)
- 16GB+ RAM (for larger datasets)
```bash
git clone https://github.com/rajeev-sr/Learning-from-Teaching-Regularization.git
cd Learning-from-Teaching-Regularization
```

Create an environment:

```bash
# Using conda (recommended)
conda create -n lot python=3.9
conda activate lot

# Or using venv
python3 -m venv env
source env/bin/activate    # Linux/Mac
# env\Scripts\activate     # Windows
```

Install dependencies:

```bash
pip install -r requirements.txt
```

Core Dependencies:
- `torch>=1.8.0`: PyTorch framework
- `numpy>=1.20.0`: Numerical computing
- `matplotlib>=3.5.0`: Visualization
- `wandb`: Experiment tracking (optional)
- `datasets`: Hugging Face datasets
Download the datasets:

```bash
bash run/getdata.sh
```

This downloads:
- Penn TreeBank (PTB): ~1M tokens for language modeling
- WikiText-103: ~100M tokens for language modeling
- CIFAR-10: 60K images for image classification

Configure the WANDB `USER_NAME` and `API_KEY` in the `key.config` file for experiment tracking.
```bash
# Run with default configuration
bash run/Lot_Emergent.sh
```

What this does:
- Trains the teacher model with a message encoder
- Trains the student model with a message decoder
- Applies LoT regularization + language regularization
- Uses curriculum learning (gradual message integration)
- Saves the checkpoint to `ckpt/Phase4/Lot_Emergent_ptb.pt`
- Logs results to `logs/Phase4/Lot_Emergent_ptb.log`

Analyze the emergent language:

```bash
python analysis/analyze_phase4_full.py \
    --checkpoint ckpt/Phase4/Lot_Emergent_ptb.pt \
    --output_dir result/Phase4
```

Generated visualizations:
- `positional_entropy.png`: Compositionality analysis
- `symbol_frequency.png`: Vocabulary usage distribution
- `vocabulary_usage.png`: Coverage statistics
- `analysis_summary.txt`: Detailed metrics
Train a single variant:

```bash
python trainer/feedback_variants.py \
    --data ptb \
    --feedback positive \
    --alpha 1.0 \
    --epochs 60 \
    --save ckpt/Phase1/baseline.pt
```

Objective: Compare positive, negative, and mixed feedback strategies

```bash
# 1. Positive feedback (bidirectional)
bash run/feedback_variants.sh positive

# 2. Negative feedback (unidirectional)
bash run/feedback_variants.sh negative

# 3. Mixed feedback (asymmetric)
bash run/feedback_variants.sh mixed

# 4. Analyze results
python analysis/analyze_phase1.py --result_dir result/feedback_variants
```

Objective: Achieve a 5-15 PPL improvement over the baseline via emergent communication
Train the LoT baseline (no messages), then the emergent-language model:

```bash
python trainer/feedback_variants.py \
    --data ptb \
    --feedback positive \
    --alpha 1.0 \
    --epochs 60 \
    --save ckpt/baseline_lot.pt \
    2>&1 | tee logs/baseline_lot.log

bash run/Lot_Emergent.sh
```

Configuration Details:
- Dataset: PTB
- Message vocab: 32 symbols
- Message length: 8 tokens
- LoT alpha: 1.0
- Language reg (β): 0.0001
- Epochs: 60
- Batch size: 20
- Learning rate: 30

Run the emergent-language analysis:

```bash
python analysis/analyze_phase4_full.py \
    --checkpoint ckpt/Phase4/Lot_Emergent_ptb.pt \
    --output_dir result/Phase4 \
    --data ptb \
    --batch_size 10
```

Analysis Outputs:
- Compositionality Score: Measures whether different positions encode different meanings
- Symbol-Meaning Correlation: Maps symbols to semantic clusters
- Vocabulary Coverage: % of the vocabulary actively used
- Positional Entropy: How much information each position carries

Compare final perplexities:

```bash
# Extract the final PPL from the logs
echo "=== Baseline LoT ==="
grep "Teacher PPL" logs/baseline_lot.log | tail -1
grep "Student PPL" logs/baseline_lot.log | tail -1

echo "=== LoT + Emergent Language ==="
grep "Teacher PPL" logs/Phase4/Lot_Emergent_ptb.log | tail -1
grep "Student PPL" logs/Phase4/Lot_Emergent_ptb.log | tail -1
```

Penn TreeBank (PTB):
- Size: ~1M tokens
- Vocabulary: ~10K words
- Splits: Train (42K sentences), Valid (3.3K), Test (3.7K)
- Use Case: Standard language modeling benchmark
WikiText-103:
- Size: ~103M tokens
- Vocabulary: ~268K words
- Splits: Train (1.8M sentences), Valid (3.7K), Test (4.3K)
- Use Case: Large-scale language modeling
CIFAR-10:
- Size: 60K images (32×32 RGB)
- Classes: 10 (airplane, car, bird, cat, deer, dog, frog, horse, ship, truck)
- Splits: Train (50K), Test (10K)
- Use Case: Image classification with LoT
We welcome contributions! Please:
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit changes (`git commit -m 'Add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
Areas for contribution:
- New message channel architectures (e.g., Transformer-based)
- Additional datasets (e.g., GLUE, ImageNet)
- Improved language metrics
- Multi-modal emergent communication
If you use this code in your research, please cite:

```bibtex
@article{jin2024learning,
  title={Learning from Teaching Regularization: Generalizable Correlations Should be Easy to Imitate},
  author={Jin, Can and Che, Tong and Peng, Hongwu and Li, Yiyuan and Pavone, Marco},
  journal={arXiv preprint arXiv:2402.02769},
  year={2024}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.
- PyTorch Team: For the deep learning framework
- Penn TreeBank & WikiText: For language modeling datasets
- OpenAI Gym: For reinforcement learning environments
- Community Contributors: For valuable feedback and improvements
For questions, issues, or collaborations:
- GitHub Issues: Create an issue
- Discussions: GitHub Discussions
Star this repository if you find it useful!

Found a bug? Report it here
Have a suggestion? Open a discussion