desilo-doyoung/modified_bert

FHE-friendly BERT (CGF softmax) on MRPC

BERT-base fine-tuned on GLUE/MRPC with the attention softmax replaced by a division-free "CGF" softmax. Ordinary softmax needs an input-dependent reciprocal of the exp-sum, which is the hardest operation to evaluate under fully homomorphic encryption (FHE). The CGF variant instead estimates the log-sum-exp normalizer from the mean and variance of the attention scores, so it needs no input-dependent division (and no comparison or max-subtraction).
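
A minimal sketch of the idea, assuming "CGF" refers to a second-order cumulant-generating-function estimate, i.e. log-mean-exp(x) ≈ mean(x) + c · var(x), which is exact for Gaussian scores at c = 0.5. The function name `cgf_softmax_sketch` and its signature are illustrative and are not taken from cgf_bert/softmax.py; plain Python lists stand in for attention-score tensors.

```python
import math


def exact_softmax(scores):
    """Reference softmax: needs a max (comparison) and a data-dependent division."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cgf_softmax_sketch(scores, coeff=0.5):
    """Division-free approximation (assumed form, not the repo's exact code).

    log-sum-exp(x) = log(n) + log-mean-exp(x) ~= log(n) + mean + coeff * var.
    The only division is by the row length n, a constant known in advance.
    """
    n = len(scores)
    inv_n = 1.0 / n  # plaintext constant, not input-dependent
    mean = sum(scores) * inv_n
    var = sum((s - mean) ** 2 for s in scores) * inv_n
    log_norm = math.log(n) + mean + coeff * var
    # Subtract the estimated normalizer in log space, then exponentiate.
    return [math.exp(s - log_norm) for s in scores]
```

Because the normalizer is only estimated, the outputs need not sum to exactly 1; the coefficient controls how much variance is subtracted, which is what the sweep in experiments/coeff_sweep.py varies.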

Headline result: under plain fine-tuning the CGF approximation costs about 4.4 pp (0.8194 at c=0.65 vs 0.8630 exact), but knowledge distillation from an exact-softmax teacher closes the gap. At the principled coefficient c=0.5 with per-layer hidden-state matching, the CGF student comes within 0.35 pp of exact softmax (10-seed mean 0.8595 vs exact 0.8630). See summary.md for the full investigation log.

Layout

All code lives under finetuning/:

finetuning/
  cgf_bert/            # the reusable library
    config.py          # BERT repo id + reference-recipe defaults
    softmax.py         # cgf_softmax + exact_softmax
    model.py           # BertMRPC, load_weights, make_exact
    data.py            # MRPC tokenization + DataLoaders
    engine.py          # set_seed/freeze/optimizer + the shared fit() loop
  finetune.py          # step: CE fine-tune (exact teacher OR CGF baseline)
  distill.py           # step: KD exact-teacher -> CGF student  (headline)
  validate.py          # step: score a checkpoint
  experiments/
    exact_vs_cgf.py    # exact-vs-CGF ceiling probe
    coeff_sweep.py     # sweep the CGF variance coefficient

The fit() loop in engine.py drives every script; the per-batch update is a pluggable train_step, so CE fine-tuning and KD share the same epoch / best-checkpoint machinery (the KD step lives in distill.py).
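
The pluggable-step design can be sketched as follows. This is an illustration under assumed names: `fit_sketch`, `evaluate`, and `on_best` are hypothetical, and the real fit() in engine.py may take different arguments.

```python
from typing import Callable, Iterable, Optional, Tuple


def fit_sketch(
    train_step: Callable[[object], float],
    evaluate: Callable[[], float],
    batches: Iterable[object],
    epochs: int,
    on_best: Optional[Callable[[int, float], None]] = None,
) -> Tuple[int, float]:
    """Run `epochs` passes over `batches` and track the best validation score.

    `train_step` performs one update (CE or KD); the loop does not care which.
    """
    batches = list(batches)  # allow re-iteration across epochs
    best_epoch, best_acc = -1, float("-inf")
    for epoch in range(epochs):
        for batch in batches:
            train_step(batch)
        acc = evaluate()
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
            if on_best is not None:
                on_best(epoch, acc)  # e.g. save a checkpoint here
    return best_epoch, best_acc
```

With this shape, a CE script and a KD script differ only in the `train_step` callable they pass in.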

Usage

Compute runs on GPU, so set CUDA_VISIBLE_DEVICES to a free device. Run the step scripts from inside finetuning/, and run the experiments as modules (so the cgf_bert package resolves):

cd finetuning

# 1. Train the exact-softmax teacher (cached for distillation)
uv run python finetune.py --softmax exact -o ../checkpoints_teacher/teacher.pt

# 2. Distill it into a CGF student (the recovery method) — c=0.5, beta=10
uv run python distill.py --coeff 0.5 --beta 10 --seeds 42 7 123

# 3. Score a checkpoint (defaults to the teacher)
uv run python validate.py
uv run python validate.py --checkpoint ../checkpoints/mrpc_best.pt --softmax cgf

# CGF baseline without KD, and the experiments
uv run python finetune.py --softmax cgf --coeff 0.65
uv run python -m experiments.exact_vs_cgf --softmax cgf --coeff 0.65
uv run python -m experiments.coeff_sweep --coeffs 0.6 0.65 0.7 0.75

distill.py auto-loads ../checkpoints_teacher/teacher.pt if it is present; otherwise it trains the teacher once, so step 1 is optional.

Reference recipe

5 epochs · frozen embeddings · linear LR decay (no warmup) · lr 5e-5 · batch 16 · best-checkpoint selection (defaults in cgf_bert/config.py). This is the recipe that takes an exact-softmax BERT to ~86% on MRPC.
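
As an illustration of the schedule, here is a sketch of linear decay without warmup. The helper `linear_decay_lr` is hypothetical, and the step count assumes the standard 3,668-example MRPC training split with the last partial batch kept; the repository may compute it differently.

```python
import math


def linear_decay_lr(step: int, total_steps: int, base_lr: float = 5e-5) -> float:
    """Learning rate at `step`: starts at base_lr and decays linearly to 0."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    remaining = max(0.0, 1.0 - step / total_steps)
    return base_lr * remaining


# Assumed sizes: 3,668 training examples, batch 16, 5 epochs.
steps_per_epoch = math.ceil(3668 / 16)  # 230
total_steps = 5 * steps_per_epoch       # 1150
```

No warmup means the first update already uses the full 5e-5.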

The CGF variance-damping coefficient is VARIANCE_COEFF in cgf_bert/softmax.py. Under CE fine-tuning the robust optimum is 0.65 (the basin is narrow: below 0.65 the model collapses, and above ~0.85 it over-subtracts). Under distillation the principled Gaussian value 0.5 is best.

Results (MRPC validation accuracy)

| model | recovery | mean acc |
| --- | --- | --- |
| Exact softmax | | 0.8630 |
| CGF softmax | KD, c=0.5, beta=10 | 0.8595 (10 seeds) |
| CGF softmax | CE fine-tune, c=0.65 | 0.8194 |
| CGF softmax | CE fine-tune, c=0.5 | 0.7273 |

Plain fine-tuning leaves a ~4.4 pp gap at c=0.65 (and the recipe gains that lift exact softmax do not transfer through CGF). KD with per-layer hidden-state matching (beta>0) is what closes it: on 2-class MRPC the logit channel carries too little signal, so matching the teacher's intermediate representations is the effective lever. With beta=0 the student only reaches 0.799.
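
One plausible shape for such an objective is sketched below. This is an assumption, not the loss in distill.py: the source does not say whether the logit term is a KL divergence, whether a temperature or a label CE term is used, or how the hidden states are compared. Here the logit channel is KL(teacher || student) and the hidden channel is a per-layer mean squared error averaged over layers and weighted by beta. Plain lists stand in for tensors.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def kd_loss_sketch(student_logits, teacher_logits,
                   student_hidden, teacher_hidden, beta=10.0):
    """Hypothetical KD loss: KL on logits + beta * mean per-layer MSE."""
    t_logp = log_softmax(teacher_logits)
    s_logp = log_softmax(student_logits)
    # Logit channel: KL(teacher || student).
    kl = sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))
    # Hidden channel: MSE per layer, then averaged across layers.
    layer_mse = [
        sum((a - b) ** 2 for a, b in zip(hs, ht)) / len(hs)
        for hs, ht in zip(student_hidden, teacher_hidden)
    ]
    hidden = sum(layer_mse) / len(layer_mse)
    return kl + beta * hidden
```

With only two classes the KL term has a single degree of freedom per example, while the hidden term supervises every layer's full representation, which is consistent with the beta=0 result above.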

Reproducibility / fairness

CUDA matmul is nondeterministic by default, so even same-seed runs drift (~2 pp on the 408-example val set). Report the mean across a fixed seed set, and compare configs on the same seeds (paired comparison cancels seed luck). Do not cherry-pick the strongest single seed as "the result."
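
A paired comparison over a fixed seed set can be sketched like this. The helper `paired_summary` is hypothetical, and the accuracies below are placeholder values for illustration only, not measured results from this repository.

```python
from statistics import mean


def paired_summary(acc_a, acc_b):
    """Compare two configs on the same seeds.

    acc_a, acc_b: dicts mapping seed -> validation accuracy.
    Returns per-config means plus the mean per-seed difference (b - a).
    """
    if acc_a.keys() != acc_b.keys():
        raise ValueError("both configs must be run on the same seed set")
    seeds = sorted(acc_a)
    diffs = [acc_b[s] - acc_a[s] for s in seeds]
    return {
        "mean_a": mean(acc_a[s] for s in seeds),
        "mean_b": mean(acc_b[s] for s in seeds),
        "mean_diff": mean(diffs),
        "wins_b": sum(d > 0 for d in diffs),
    }


# Placeholder accuracies (NOT real results), keyed by seed.
config_a = {42: 0.80, 7: 0.82, 123: 0.81}
config_b = {42: 0.84, 7: 0.85, 123: 0.86}
summary = paired_summary(config_a, config_b)
```

Reporting `mean_diff` together with how many seeds favor each config makes it harder for a single lucky seed to drive the conclusion.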
