BERT-base fine-tuned on GLUE/MRPC with the attention softmax replaced by a division-free "CGF" softmax. Ordinary softmax needs an input-dependent reciprocal of the exp-sum, which is the hardest operation to evaluate under fully homomorphic encryption (FHE). The CGF variant instead estimates the log-sum-exp normalizer from the mean and variance of the attention scores, so there is no input-dependent division (and no comparison/max-subtraction).
Headline result: the CGF approximation costs ~5 pp under plain fine-tuning,
but knowledge distillation from an exact-softmax teacher closes the gap. At
the principled coefficient c=0.5 with per-layer hidden-state matching, the CGF
student matches exact softmax to within 0.35 pp (10-seed mean 0.8595 vs exact 0.8630). See
summary.md for the full investigation log.
All code lives under finetuning/:

    finetuning/
      cgf_bert/             # the reusable library
        config.py           # BERT repo id + reference-recipe defaults
        softmax.py          # cgf_softmax + exact_softmax
        model.py            # BertMRPC, load_weights, make_exact
        data.py             # MRPC tokenization + DataLoaders
        engine.py           # set_seed/freeze/optimizer + the shared fit() loop
      finetune.py           # step: CE fine-tune (exact teacher OR CGF baseline)
      distill.py            # step: KD exact-teacher -> CGF student (headline)
      validate.py           # step: score a checkpoint
      experiments/
        exact_vs_cgf.py     # exact-vs-CGF ceiling probe
        coeff_sweep.py      # sweep the CGF variance coefficient

The fit() loop in engine.py drives every script; the per-batch update is a
pluggable train_step, so CE fine-tuning and KD share the same epoch /
best-checkpoint machinery (the KD step lives in distill.py).
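
The pluggable-step design can be illustrated with a minimal sketch. The names `fit_sketch`, `toy_step`, and `toy_acc` are hypothetical and the signature is an assumption; the real `fit()` in engine.py may differ. The point is only that the epoch loop and best-checkpoint tracking do not care whether the per-batch update is CE or KD.

```python
from typing import Callable, Iterable, Tuple


def fit_sketch(
    train_step: Callable[[object], float],
    batches: Iterable[object],
    evaluate: Callable[[], float],
    epochs: int,
) -> Tuple[float, int]:
    """Run `epochs` passes over `batches` and track the best validation score."""
    batches = list(batches)
    best_acc, best_epoch = float("-inf"), -1
    for epoch in range(epochs):
        for batch in batches:
            train_step(batch)      # CE or KD update, chosen by the caller
        acc = evaluate()
        if acc > best_acc:         # best-checkpoint selection
            best_acc, best_epoch = acc, epoch
            # a real loop would save the model weights here
    return best_acc, best_epoch


# Toy usage: a step that only records its batches, and scripted accuracies.
calls = []


def toy_step(batch):
    calls.append(batch)
    return 0.0


toy_acc = iter([0.70, 0.82, 0.79])  # made-up numbers for illustration only
best_acc, best_epoch = fit_sketch(toy_step, ["b0", "b1"], lambda: next(toy_acc), epochs=3)
print(best_acc, best_epoch)
```

Swapping `toy_step` for a different callable changes the training objective without touching the loop, which is the property the shared `fit()` relies on.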
Compute runs on GPU; set CUDA_VISIBLE_DEVICES to a free device. Run the step
scripts from inside finetuning/; run experiments as modules (so the
cgf_bert package resolves):

    cd finetuning

    # 1. Train the exact-softmax teacher (cached for distillation)
    uv run python finetune.py --softmax exact -o ../checkpoints_teacher/teacher.pt

    # 2. Distill it into a CGF student (the recovery method): c=0.5, beta=10
    uv run python distill.py --coeff 0.5 --beta 10 --seeds 42 7 123

    # 3. Score a checkpoint (defaults to the teacher)
    uv run python validate.py
    uv run python validate.py --checkpoint ../checkpoints/mrpc_best.pt --softmax cgf

    # CGF baseline without KD, and the experiments
    uv run python finetune.py --softmax cgf --coeff 0.65
    uv run python -m experiments.exact_vs_cgf --softmax cgf --coeff 0.65
    uv run python -m experiments.coeff_sweep --coeffs 0.6 0.65 0.7 0.75

distill.py auto-loads ../checkpoints_teacher/teacher.pt if present, else
trains the teacher once (step 1 is optional).
5 epochs · frozen embeddings · linear LR decay (no warmup) · lr 5e-5 · batch 16 ·
best-checkpoint selection (defaults in cgf_bert/config.py). This is the recipe
that takes an exact-softmax BERT to ~86% on MRPC.
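
For reference, "linear LR decay (no warmup)" can be sketched as below. This assumes the rate decays all the way to zero by the final step; the helper name `linear_decay_lr` and the `total_steps` value are hypothetical, and the repo's actual scheduler may differ.

```python
def linear_decay_lr(step: int, total_steps: int, base_lr: float = 5e-5) -> float:
    """Learning rate at `step`, decaying linearly from base_lr to 0 with no warmup."""
    remaining = max(0, total_steps - step)   # clamp so the rate never goes negative
    return base_lr * remaining / total_steps


# Hypothetical run length of 1000 optimizer steps.
schedule = [linear_decay_lr(s, total_steps=1000) for s in (0, 500, 1000)]
print(schedule)
```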
The CGF variance-damping coefficient is VARIANCE_COEFF in
cgf_bert/softmax.py. Under CE fine-tuning the robust optimum is 0.65 (the
basin is narrow — below it the model collapses, above ~0.85 it over-subtracts).
Under distillation the principled Gaussian value 0.5 is best.
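
To make the role of the coefficient concrete, here is a minimal pure-Python sketch of a mean/variance normalizer. It assumes the estimate has the form `log_z ≈ log(n) + mean + c * var`, which for `c = 0.5` is the Gaussian cumulant generating function evaluated at 1. The function `cgf_softmax_sketch` is illustrative only and is not the code in cgf_bert/softmax.py.

```python
import math


def exact_softmax(scores):
    """Reference softmax with the usual max-subtraction for stability."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cgf_softmax_sketch(scores, coeff=0.5):
    """Softmax with the log-sum-exp replaced by a mean/variance estimate.

    Assumed form: log_z = log(n) + mean + coeff * var. The only division is
    by the fixed length n, so nothing input-dependent is inverted.
    """
    n = len(scores)
    mean = sum(scores) / n
    var = sum((s - mean) ** 2 for s in scores) / n
    log_z = math.log(n) + mean + coeff * var
    return [math.exp(s - log_z) for s in scores]


scores = [0.3, -0.1, 0.2, 0.0]
approx = cgf_softmax_sketch(scores, coeff=0.5)
exact = exact_softmax(scores)
print(sum(approx), max(abs(a - e) for a, e in zip(approx, exact)))
```

In this form a larger coefficient subtracts more from every score, so the output row sums to less than one, which is consistent with the over-subtraction described above for large values. The estimate is only close to exact when the scores are roughly Gaussian with modest variance.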
| model | recovery | mean acc |
|---|---|---|
| Exact softmax | — | 0.8630 |
| CGF softmax | KD, c=0.5, beta=10 | 0.8595 (10 seeds) |
| CGF softmax | CE fine-tune, c=0.65 | 0.8194 |
| CGF softmax | CE fine-tune, c=0.5 | 0.7273 |
Plain fine-tuning leaves a ~5 pp gap (and the recipe gains that lift exact
softmax do not transfer through CGF). KD with per-layer hidden-state
matching (beta>0) is what closes it: on 2-class MRPC the logit channel
carries too little signal, so matching the teacher's intermediate
representations is the effective lever — beta=0 only reaches 0.799.
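
The two distillation channels can be sketched as a single loss. This is an assumed form, a KL term on the logits plus `beta` times the mean per-layer MSE between student and teacher hidden states; the actual loss in distill.py (temperature, layer selection, any label term) is not shown here and may differ. `kd_loss_sketch` is a hypothetical name.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kd_loss_sketch(student_logits, teacher_logits,
                   student_hidden, teacher_hidden, beta=10.0):
    """KL(teacher || student) on logits + beta * mean per-layer hidden-state MSE."""
    log_p_s = log_softmax(student_logits)
    log_p_t = log_softmax(teacher_logits)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # One MSE per layer; each layer's hidden state is a flat list of floats here.
    layer_mse = [
        sum((a - b) ** 2 for a, b in zip(hs, ht)) / len(hs)
        for hs, ht in zip(student_hidden, teacher_hidden)
    ]
    hidden_term = sum(layer_mse) / len(layer_mse)
    return kl + beta * hidden_term


# Identical logits, one layer whose hidden states differ: only the beta term fires.
loss = kd_loss_sketch([0.0, 0.0], [0.0, 0.0], [[1.0, 2.0]], [[1.0, 0.0]], beta=10.0)
print(loss)
```

With only two output classes the KL term has a single degree of freedom per example, whereas the hidden-state term supervises every layer, which is one way to read why `beta > 0` matters on MRPC.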
CUDA matmul is not guaranteed to be deterministic by default, so even same-seed runs drift (~2 pp, about 8 examples, on the 408-example val set). Report the mean across a fixed seed set, and compare configs on the same seeds (a paired comparison cancels seed luck). Do not cherry-pick the strongest single seed as "the result."
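
A paired comparison over a fixed seed set could be summarized as in the sketch below. `paired_summary` is a hypothetical helper, and the accuracies are placeholders for illustration only, not measured results.

```python
from statistics import mean, stdev


def paired_summary(acc_a, acc_b):
    """Compare two configs run on the same seeds; inputs map seed -> accuracy."""
    seeds = sorted(set(acc_a) & set(acc_b))      # only seeds both configs ran
    diffs = [acc_a[s] - acc_b[s] for s in seeds]
    return {
        "seeds": seeds,
        "mean_a": mean(acc_a[s] for s in seeds),
        "mean_b": mean(acc_b[s] for s in seeds),
        "mean_diff": mean(diffs),
        "sd_diff": stdev(diffs) if len(diffs) > 1 else 0.0,
    }


# Placeholder accuracies, NOT results from this repo.
config_a = {42: 0.86, 7: 0.85, 123: 0.87}
config_b = {42: 0.85, 7: 0.84, 123: 0.85}
summary = paired_summary(config_a, config_b)
print(summary)
```

Looking at per-seed differences rather than two independent means removes the seed-to-seed variance that both configs share.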