
Learning-from-Teaching (LoT) Regularization


A comprehensive framework for Learning-from-Teaching (LoT) regularization research, combining teacher-student learning with emergent communication. This project explores how bidirectional knowledge transfer between teacher and student models can improve generalization across multiple domains: Natural Language Processing (NLP), Computer Vision (CV), and Reinforcement Learning (RL).

Paper: Learning from Teaching Regularization: Generalizable Correlations Should be Easy to Imitate
Authors: Can Jin, Tong Che, Hongwu Peng, Yiyuan Li, Marco Pavone


🎯 Overview

Generalization remains a central challenge in machine learning. In this work, we propose Learning from Teaching (LoT), a novel regularization technique for deep neural networks that enhances generalization. Inspired by the human ability to capture concise and abstract patterns, we hypothesize that generalizable correlations are easier to teach. LoT operationalizes this concept by using auxiliary student learners to improve the generalization of the main model.

Key Innovation

Traditional knowledge distillation is unidirectional: teacher → student. LoT introduces bidirectional learning:

  • Forward KL: Student learns to imitate teacher's predictions
  • Reverse KL: Teacher learns to be more "teachable"
  • Emergent Language: Teacher encodes knowledge into symbolic messages that student decodes

Advantages

✅ Better Generalization: Both teacher and student achieve lower perplexity/error rates
✅ Flexible Architecture: Works across domains (NLP, CV, RL)
✅ Interpretable Communication: Emergent messages reveal what the teacher "teaches"
✅ Curriculum Learning: Gradual integration of messages for stable training


πŸ“ Project Structure

LoT/
├── model/                          # Model architectures
│   ├── rnn.py                      # LSTM/GRU language models
│   ├── mem_transformer.py          # Transformer-XL implementation
│   ├── preresnet.py                # ResNet for image classification
│   ├── message_channel.py          # Emergent language communication (Phase 4)
│   └── emergent_language_models.py # Teacher/Student with message encoding/decoding
│
├── trainer/                        # Training scripts for different experiments
│   ├── Lot_Emergent.py             # LoT + Emergent Language
│   ├── feedback_variants.py        # Positive/Negative/Mixed feedback
│   ├── size_ratios.py              # Teacher-Student size ratios
│   ├── alternative_metrics.py      # KL/JS/L2/Cosine teaching metrics
│   └── image_classification.py     # CV: Image classification with LoT
│
├── analysis/                       # Analysis and visualization scripts
│   ├── analyze_phase4_full.py      # Emergent language analysis
│   ├── analyze_phase1.py           # Feedback variants analysis
│   ├── analyze_phase2.py           # Size ratio analysis
│   └── visualize_results.py        # General result visualization
│
├── utils/                          # Utility functions
│   ├── data_utils.py               # Data loading utilities
│   ├── corpus.py                   # NLP corpus handling (PTB, WikiText)
│   ├── vocabulary.py               # Vocabulary management
│   └── exp_utils.py                # Experiment utilities
│
├── run/                            # Shell scripts for experiments
│   ├── Lot_Emergent.sh             # Run emergent language experiments
│   ├── feedback_variants.sh        # Run feedback variant experiments
│   ├── size_ratios.sh              # Run size ratio experiments
│   └── getdata.sh                  # Download datasets
│
├── data/                           # Datasets (auto-downloaded)
│   ├── ptb/                        # Penn TreeBank
│   ├── wikitext-103/               # WikiText-103
│   └── cifar-10-batches-py/        # CIFAR-10
│
├── ckpt/                           # Model checkpoints
│   ├── Phase1/                     # Feedback variant checkpoints
│   ├── Phase2/                     # Size ratio checkpoints
│   ├── Phase3/                     # Alternative metric checkpoints
│   └── Phase4/                     # Emergent language checkpoints
│
├── logs/                           # Training logs
└── result/                         # Analysis results and visualizations

πŸ—οΈ Architecture & Flow Diagrams

Core LoT Architecture

┌──────────────────────────────────────────────────────────────────────┐
│                 Learning-from-Teaching (LoT) Framework               │
└──────────────────────────────────────────────────────────────────────┘

Input (x) ───┬───► Teacher Model ───► Teacher Prediction (ŷ_t)
             │                              │
             │                              │
             │                              ▼
             │                    ┌─────────────────────┐
             │                    │   LoT Loss Engine   │
             │                    │  ─────────────────  │
             └───► Student Model ──►  • CE Loss (both)  │
                        │         │  • KL(ŷ_s || ŷ_t)   │
                        │         │  • KL(ŷ_t || ŷ_s)   │
                        ▼         └──────────┬──────────┘
              Student Prediction             │
                    (ŷ_s)                    │
                        │                    │
                        └────────────────────┤
                                             │
                                             ▼
                                     Ground Truth (y)

Key Components:

  • Cross-Entropy (CE) Loss: Both models learn from ground truth
  • Forward KL KL(ŷ_s || ŷ_t): Student imitates teacher
  • Reverse KL KL(ŷ_t || ŷ_s): Teacher learns to be teachable
  • Regularization Weight α: Controls LoT strength
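The components above can be sketched as a single loss function. This is an illustrative PyTorch sketch, not the repository's exact code: the function name and the detach-based gradient routing (so the forward KL only updates the student and the reverse KL only updates the teacher) are assumptions.

```python
import torch
import torch.nn.functional as F

def lot_loss(teacher_logits, student_logits, targets, alpha=1.0):
    """Hypothetical combined LoT objective: CE for both models plus
    the two KL terms, scaled by the regularization weight alpha."""
    # Both models learn from ground truth.
    ce = F.cross_entropy(teacher_logits, targets) + \
         F.cross_entropy(student_logits, targets)

    log_p_t = F.log_softmax(teacher_logits, dim=-1)
    log_p_s = F.log_softmax(student_logits, dim=-1)

    # Forward KL(ŷ_s || ŷ_t): student imitates the (frozen) teacher.
    kl_fwd = F.kl_div(log_p_t.detach(), log_p_s.exp(), reduction="batchmean")
    # Reverse KL(ŷ_t || ŷ_s): teacher moves toward the (frozen) student.
    kl_rev = F.kl_div(log_p_s.detach(), log_p_t.exp(), reduction="batchmean")

    return ce + alpha * (kl_fwd + kl_rev)
```

Note that `F.kl_div(input, target)` computes KL(target || input) with `input` in log space, which is why the arguments look swapped relative to the math.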

LoT + Emergent Language Architecture

┌───────────────────────────────────────────────────────────────────────┐
│                    LoT with Emergent Communication                    │
└───────────────────────────────────────────────────────────────────────┘

Input (x) ───┬───► Teacher Encoder ───► Message (m) ───┐
             │           │                             │
             │           │                             │
             │           ▼                             ▼
             │    Teacher Predictor           Student Decoder ───► ŷ_s
             │           │                             │
             │           ▼                             │
             │    Teacher Prediction (ŷ_t)             │
             │           │                             │
             │           └──────────┬──────────────────┘
             │                      │
             └──────────────────────┤
                                    ▼
                          ┌─────────────────────┐
                          │  LoT + Language     │
                          │  Loss Computation   │
                          │  ─────────────────  │
                          │  • CE Loss (both)   │
                          │  • KL(ŷ_s || ŷ_t)   │
                          │  • KL(ŷ_t || ŷ_s)   │
                          │  • Entropy Reg      │
                          │  • Compositionality │
                          └─────────────────────┘

Message Channel Details:

┌──────────────────────────────────────────────────────────────────┐
│                    Message Channel (Discrete)                    │
└──────────────────────────────────────────────────────────────────┘

Teacher Hidden State (h_t)
         │
         ▼
    Feature Projection
         │
         ▼
    Message Encoder
         │
         ▼
    Logits: [batch, msg_length, vocab_size]
         │
         ▼
    Gumbel-Softmax Sampling
         │
         ▼
    Message: [batch, msg_length, vocab_size]
         │
         ├───► (Hard symbols for analysis)
         │
         └───► Student Message Decoder
                      │
                      ▼
              Fusion with Student Hidden
                      │
                      ▼
              Enhanced Student Prediction
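A minimal PyTorch sketch of the encoder half of this pipeline (class and parameter names are illustrative assumptions; the repository's model/message_channel.py may differ in its details):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MessageChannel(nn.Module):
    """Illustrative discrete message channel: teacher hidden state ->
    per-position symbol logits -> Gumbel-Softmax message samples."""

    def __init__(self, hidden_dim, msg_length=8, vocab_size=32):
        super().__init__()
        self.msg_length = msg_length
        self.vocab_size = vocab_size
        # Feature projection + message encoder collapsed into one linear map.
        self.encoder = nn.Linear(hidden_dim, msg_length * vocab_size)

    def forward(self, h_t, temperature=1.0, hard=False):
        # Logits: [batch, msg_length, vocab_size]
        logits = self.encoder(h_t).view(-1, self.msg_length, self.vocab_size)
        # Gumbel-Softmax gives differentiable samples; hard=True returns
        # one-hot symbols via the straight-through estimator (useful for
        # the "hard symbols for analysis" branch above).
        return F.gumbel_softmax(logits, tau=temperature, hard=hard, dim=-1)
```

Annealing `temperature` from 1.0 toward 0.5 over training makes the samples progressively closer to discrete symbols.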

Training Flow with Curriculum Learning:

Epoch 0  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━►  Epoch 60
         
Message   0% ────────────► 10% ─────► 50% ───► 100%
Weight    (Pure LoT)     (Warmup)   (Ramp)  (Full)
          
Phase:    [────── LoT Foundation ──────][─ Message Integration ─]
          Epoch 0-30              Epoch 30-50        Epoch 50-60
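The schedule above could be implemented as a simple piecewise function; the exact breakpoints below are read off the diagram and should be treated as an assumption about the actual training script:

```python
def message_weight(epoch, total_epochs=60):
    """Hypothetical curriculum for the message-loss weight: pure LoT
    for the foundation phase, then a linear ramp to full integration."""
    if epoch < 30:                  # LoT foundation: messages off
        return 0.0
    if epoch < 50:                  # warmup/ramp: 10% -> 50% linearly
        return 0.1 + 0.4 * (epoch - 30) / 20
    # final phase: ramp 50% -> 100%
    return min(1.0, 0.5 + 0.5 * (epoch - 50) / 10)
```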

🔬 Research Directions

Feedback Balance Variants

Goal: Compare positive, negative, and mixed teaching feedback

| Variant  | Student → Teacher | Teacher → Student | Use Case                 |
|----------|-------------------|-------------------|--------------------------|
| Positive | ✅ KL(ŷ_t || ŷ_s) | ✅ KL(ŷ_s || ŷ_t) | Bidirectional learning   |
| Negative | ❌ None           | ✅ KL(ŷ_s || ŷ_t) | Traditional distillation |
| Mixed    | ✅ Half strength  | ✅ Full strength  | Asymmetric feedback      |

Key Finding: Positive (bidirectional) feedback achieves best generalization
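The variants amount to different weightings of the two KL terms. A hedged sketch (the 0.5 "half strength" factor and the function name are illustrative, not taken from trainer/feedback_variants.py):

```python
def feedback_loss(kl_forward, kl_reverse, variant="positive"):
    """Combine the student-imitation term (kl_forward) and the
    teacher-teachability term (kl_reverse) per feedback variant."""
    if variant == "positive":   # bidirectional: both terms, full strength
        return kl_forward + kl_reverse
    if variant == "negative":   # traditional distillation: student-only
        return kl_forward
    if variant == "mixed":      # asymmetric: half-strength teacher signal
        return kl_forward + 0.5 * kl_reverse
    raise ValueError(f"unknown variant: {variant}")
```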


Teacher-Student Size Ratios

Goal: Study impact of model capacity ratios

| Configuration | Teacher Size | Student Size | Capacity Ratio |
|---------------|--------------|--------------|----------------|
| Small-Small   | 650 hidden   | 650 hidden   | 1:1            |
| Medium-Small  | 1300 hidden  | 650 hidden   | 2:1            |
| Large-Small   | 2600 hidden  | 650 hidden   | 4:1            |

Key Finding: Larger teachers provide richer teaching signals, but require careful tuning


Alternative Teaching Metrics

Goal: Explore different divergence measures for teaching

| Metric        | Formula               | Properties                   |
|---------------|-----------------------|------------------------------|
| KL Divergence | KL(P || Q)            | Asymmetric, mode-seeking     |
| JS Divergence | ½KL(P||M) + ½KL(Q||M) | Symmetric, bounded           |
| L2 Distance   | ||P − Q||²            | Euclidean, simple            |
| Cosine        | 1 − cos(P, Q)         | Angle-based, scale-invariant |

Key Finding: KL divergence remains most effective for probability distributions
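For concreteness, the four metrics can be sketched for batched probability vectors as follows (a minimal illustration, not the repository's trainer/alternative_metrics.py):

```python
import torch
import torch.nn.functional as F

def teaching_distance(p, q, metric="kl", eps=1e-8):
    """Distance between probability tensors p, q of shape [batch, classes],
    averaged over the batch. `eps` guards the logs against zeros."""
    if metric == "kl":        # KL(P || Q): asymmetric, mode-seeking
        return (p * ((p + eps).log() - (q + eps).log())).sum(-1).mean()
    if metric == "js":        # ½KL(P||M) + ½KL(Q||M) with M = ½(P+Q)
        m = 0.5 * (p + q)
        return 0.5 * teaching_distance(p, m, "kl") \
             + 0.5 * teaching_distance(q, m, "kl")
    if metric == "l2":        # squared Euclidean distance
        return (p - q).pow(2).sum(-1).mean()
    if metric == "cosine":    # 1 - cosine similarity
        return (1 - F.cosine_similarity(p, q, dim=-1)).mean()
    raise ValueError(f"unknown metric: {metric}")
```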


LoT + Emergent Language ⭐

Goal: Enable structured knowledge transfer through emergent communication

Components:

  1. Message Channel: Discrete symbolic communication (Gumbel-Softmax)

    • Vocabulary size: 32-100 symbols
    • Message length: 8-20 tokens
    • Temperature annealing: 1.0 → 0.5
  2. Language Regularization:

    • Entropy: Encourage vocabulary diversity
    • Compositionality: Promote structured messages
    • Curriculum Learning: Gradual message integration
  3. Evaluation Metrics:

    • Perplexity (PPL): Lower is better
    • Vocabulary coverage: % of symbols used
    • Positional entropy: Measures compositionality
    • Topographic similarity: Input-message correlation
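The vocabulary-coverage and positional-entropy metrics can be sketched in a few lines of NumPy (function names are illustrative; the repository's analysis scripts may compute them differently):

```python
import numpy as np

def positional_entropy(messages, vocab_size):
    """Per-position Shannon entropy (bits) of symbol usage.
    `messages` is an integer array of shape [n_msgs, msg_length];
    higher entropy means the position carries more information."""
    entropies = []
    for pos in range(messages.shape[1]):
        counts = np.bincount(messages[:, pos], minlength=vocab_size)
        p = counts / counts.sum()
        p = p[p > 0]                      # 0 * log 0 := 0
        entropies.append(-(p * np.log2(p)).sum())
    return np.array(entropies)

def vocabulary_coverage(messages, vocab_size):
    """Fraction of the symbol vocabulary that appears at least once."""
    return np.unique(messages).size / vocab_size
```

For example, a position that always emits the same symbol has entropy 0, while a position using four symbols uniformly has entropy 2 bits.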

Expected Results:

  • Teacher PPL: ~110-130 (similar to baseline)
  • Student PPL: ~120-140 (5-15 points better than baseline ~145)
  • Message benefit: Structured knowledge transfer

🚀 Installation

Prerequisites

  • Python 3.8 or higher
  • CUDA-capable GPU (recommended for training)
  • 16GB+ RAM (for larger datasets)

Step 1: Clone Repository

git clone https://github.com/rajeev-sr/Learning-from-Teaching-Regularization.git
cd Learning-from-Teaching-Regularization

Step 2: Create Virtual Environment

# Using conda (recommended)
conda create -n lot python=3.9
conda activate lot

# Or using venv
python3 -m venv env
source env/bin/activate  # Linux/Mac
# env\Scripts\activate   # Windows

Step 3: Install Dependencies

pip install -r requirements.txt

Core Dependencies:

  • torch>=1.8.0: PyTorch framework
  • numpy>=1.20.0: Numerical computing
  • matplotlib>=3.5.0: Visualization
  • wandb: Experiment tracking (optional)
  • datasets: Hugging Face datasets

Step 4: Download Datasets

bash run/getdata.sh

This downloads:

  • Penn TreeBank (PTB): ~1M tokens for language modeling
  • WikiText-103: ~100M tokens for language modeling
  • CIFAR-10: 60K images for image classification

Step 5: Configure WANDB (Optional)

Configure WANDB USER_NAME and API_KEY in the key.config file for experiment tracking.


⚡ Quick Start

Example 1: Train LoT with Emergent Language

# Run with default configuration
bash run/Lot_Emergent.sh

What this does:

  1. Trains teacher model with message encoder
  2. Trains student model with message decoder
  3. Applies LoT regularization + language regularization
  4. Uses curriculum learning (gradual message integration)
  5. Saves checkpoint to ckpt/Phase4/Lot_Emergent_ptb.pt
  6. Logs results to logs/Phase4/Lot_Emergent_ptb.log

Example 2: Analyze Emergent Language

python analysis/analyze_phase4_full.py \
    --checkpoint ckpt/Phase4/Lot_Emergent_ptb.pt \
    --output_dir result/Phase4

Generated visualizations:

  • positional_entropy.png: Compositionality analysis
  • symbol_frequency.png: Vocabulary usage distribution
  • vocabulary_usage.png: Coverage statistics
  • analysis_summary.txt: Detailed metrics

Example 3: Train Baseline LoT (No Messages)

python trainer/feedback_variants.py \
    --data ptb \
    --feedback positive \
    --alpha 1.0 \
    --epochs 60 \
    --save ckpt/Phase1/baseline.pt

📖 Reproducibility Guide

Experiment 1: Feedback Variants

Objective: Compare positive, negative, and mixed feedback strategies

# 1. Positive feedback (bidirectional)
bash run/feedback_variants.sh positive

# 2. Negative feedback (unidirectional)
bash run/feedback_variants.sh negative

# 3. Mixed feedback (asymmetric)
bash run/feedback_variants.sh mixed

# 4. Analyze results
python analysis/analyze_phase1.py --result_dir result/feedback_variants

Experiment 2: LoT with Emergent Language ⭐

Objective: Achieve 5-15 PPL improvement over baseline via emergent communication

Step 1: Train Baseline LoT (for comparison)

python trainer/feedback_variants.py \
    --data ptb \
    --feedback positive \
    --alpha 1.0 \
    --epochs 60 \
    --save ckpt/baseline_lot.pt \
    2>&1 | tee logs/baseline_lot.log

Step 2: Train LoT with Emergent Messages

bash run/Lot_Emergent.sh

Configuration Details:

Dataset:              PTB
Message vocab:        32 symbols
Message length:       8 tokens
LoT alpha:            1.0
Language reg (β):     0.0001
Epochs:               60
Batch size:           20
Learning rate:        30

Step 3: Analyze Emergent Language

python analysis/analyze_phase4_full.py \
    --checkpoint ckpt/Phase4/Lot_Emergent_ptb.pt \
    --output_dir result/Phase4 \
    --data ptb \
    --batch_size 10

Analysis Outputs:

  1. Compositionality Score: Measures if different positions encode different meanings
  2. Symbol-Meaning Correlation: Maps symbols to semantic clusters
  3. Vocabulary Coverage: % of vocabulary actively used
  4. Positional Entropy: How much information each position carries

Step 4: Compare Results

# Extract final PPL from logs
echo "=== Baseline LoT ==="
grep "Teacher PPL" logs/baseline_lot.log | tail -1
grep "Student PPL" logs/baseline_lot.log | tail -1

echo "=== LoT + Emergent Language ==="
grep "Teacher PPL" logs/Phase4/Lot_Emergent_ptb.log | tail -1
grep "Student PPL" logs/Phase4/Lot_Emergent_ptb.log | tail -1

📊 Datasets

Penn TreeBank (PTB)

  • Size: ~1M tokens
  • Vocabulary: ~10K words
  • Splits: Train (42K sentences), Valid (3.3K), Test (3.7K)
  • Use Case: Standard language modeling benchmark

WikiText-103

  • Size: ~103M tokens
  • Vocabulary: ~268K words
  • Splits: Train (1.8M sentences), Valid (3.7K), Test (4.3K)
  • Use Case: Large-scale language modeling

CIFAR-10

  • Size: 60K images (32×32 RGB)
  • Classes: 10 (airplane, car, bird, cat, deer, dog, frog, horse, ship, truck)
  • Splits: Train (50K), Test (10K)
  • Use Case: Image classification with LoT

🤝 Contributing

We welcome contributions! Please:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit changes (git commit -m 'Add amazing feature')
  4. Push to branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Areas for contribution:

  • New message channel architectures (e.g., Transformer-based)
  • Additional datasets (e.g., GLUE, ImageNet)
  • Improved language metrics
  • Multi-modal emergent communication

πŸ“ Citation

If you use this code in your research, please cite:

@article{jin2024learning,
  title={Learning from Teaching Regularization: Generalizable Correlations Should be Easy to Imitate},
  author={Jin, Can and Che, Tong and Peng, Hongwu and Li, Yiyuan and Pavone, Marco},
  journal={arXiv preprint arXiv:2402.02769},
  year={2024}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


🙏 Acknowledgments

  • PyTorch Team: For the deep learning framework
  • Penn TreeBank & WikiText: For language modeling datasets
  • OpenAI Gym: For reinforcement learning environments
  • Community Contributors: For valuable feedback and improvements

📞 Contact

For questions, issues, or collaborations:


⭐ Star this repository if you find it useful!

πŸ› Found a bug? Report it here

💡 Have a suggestion? Open a discussion


