Interactive proof blueprint: brettkoonce.github.io/lean4-mlir/blueprint/ (or PDF): a clickable dependency DAG for the full VJP proof suite (no sorrys, zero project axioms), from pdiv primitives up to the whole-network VJPs (ViT, ResNet, MobileNetV2, ConvNeXt, EfficientNet).
Lean 4 as a specification language for neural networks. Declare the architecture in Lean, generate StableHLO MLIR (forward, loss, backward, and optimizer all fused into one function), compile it for the GPU via IREE, and train end-to-end. No Python runtime and no autograd library: the gradients are computed at codegen time in Lean.
Companion code for the upcoming book Verified Deep Learning with Lean 4 (follow-up to Convolutional Neural Networks with Swift for TensorFlow, Apress).
Current version: v0.6.0. Object detection joins the framework: a
YOLOv1 person detector on Pascal VOC, bootstrapped from Chapter 6's
ResNet-34 backbone with a 1×1 convolutional detection head. The fully
connected head can't learn localization on a small dataset (diagnosed end
to end in planning/yolo_v5.md); the focal-loss objectness term is the
remaining piece, flagged in the bestiary's new detection intro. The
IREE/Lean training path gains global-norm gradient clipping, env-var
checkpoint resume (LEAN_MLIR_INIT_LOAD / LEAN_MLIR_START_STEP), and
per-step LR warmup. The blueprint adds demo-anchored intros for object
detection (§11.2.2) and diffusion (§11.2.7), redraws the MNIST/CIFAR
training curves as native pgfplots (vector, regenerable from logs/), and
adds a BatchNorm "why normalizing helps" intuition.
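Global-norm clipping rescales every gradient by the same factor, so the update direction is preserved and only its joint magnitude is capped. A minimal Lean sketch of the standard rule, assuming gradients are flattened to `Array Float` per parameter tensor; `globalNorm` and `clipByGlobalNorm` are hypothetical names, not the trainer's actual code path (which emits this logic into MLIR):

```lean
/-- Global L2 norm over all parameter gradients, each flattened to an Array. -/
def globalNorm (grads : Array (Array Float)) : Float :=
  Float.sqrt (grads.foldl (fun acc g => g.foldl (fun a x => a + x * x) acc) 0.0)

/-- Scale every gradient by min(1, maxNorm / ‖g‖): direction kept,
    joint magnitude capped at `maxNorm`. -/
def clipByGlobalNorm (grads : Array (Array Float)) (maxNorm : Float) :
    Array (Array Float) :=
  let n := globalNorm grads
  if n ≤ maxNorm then grads
  else
    let s := maxNorm / n
    grads.map fun g => g.map fun x => x * s

#eval globalNorm #[#[3.0, 4.0], #[12.0]]                         -- 13
#eval globalNorm (clipByGlobalNorm #[#[3.0, 4.0], #[12.0]] 6.5)  -- 6.5
```

Clipping by the global norm, rather than per tensor, keeps the relative scale between layers intact.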
The v0.5.7 headline still holds: two parallel-agent audits are closed.
The "canonical correct := rfl" pattern at non-smooth operators
(ReLU, the composed MLP, MaxPool2) now has machine-checked
smooth-point bridges: relu_codegen_matches_canonical and
maxPool2_codegen_matches_canonical prove the canonical-witness
backward equals the codegen formula wherever every coordinate
avoids the kink. A HasVJPAt pointwise framework provides
smooth-input variants of the three kinked-operator instances
whose correct field is a real chain-rule proof rather than
rfl. The comparator suite grows from 38 to 41 theorems, each
independently kernel-rechecked against [propext, Quot.sound, Classical.choice]. The blueprint gets a half-dozen flow improvements
(GAP defined at first material use in Ch 6, Diffusion split into
its own Bestiary subsection, ResNet entry expanded to the full
standard family including R-18, Tomáš Skřivan's Scientific
Computing in Lean credited at the top of the acknowledgments).
The Android bottom-cutoff bug is fixed (issue #2), and cookieless Umami
analytics replaces the planned GA. The first Zenodo deposit lands with
this release.
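The smooth-point bridge can be pictured in a few lines of Lean over `Float`: the codegen mask `x > 0` is defined everywhere (it silently picks subgradient 0 at the kink), while the classical derivative exists only off the kink. This is a hypothetical illustration (`reluBackCodegen`, `reluDeriv?`, `reluBackCanonical?` are not names from the repo); the actual theorems are stated over ℝ in Proofs/:

```lean
/-- Codegen-style ReLU backward: mask the cotangent by `x > 0`.
    At exactly `x = 0` this picks the subgradient 0. -/
def reluBackCodegen (x dy : List Float) : List Float :=
  List.zipWith (fun xi g => if xi > 0.0 then g else 0.0) x dy

/-- The classical derivative of ReLU exists only off the kink. -/
def reluDeriv? (x : Float) : Option Float :=
  if x > 0.0 then some 1.0 else if x < 0.0 then some 0.0 else none

/-- "Canonical" backward: defined only when every coordinate is differentiable. -/
def reluBackCanonical? (x dy : List Float) : Option (List Float) := do
  let ds ← x.mapM reluDeriv?
  pure (List.zipWith (· * ·) ds dy)

-- Off the kink the two agree coordinate-for-coordinate:
#eval reluBackCanonical? [1.5, -2.0, 0.5] [0.1, 0.2, 0.3]
    == some (reluBackCodegen [1.5, -2.0, 0.5] [0.1, 0.2, 0.3])   -- true
-- On the kink there is nothing to agree with:
#eval reluBackCanonical? [1.5, 0.0] [0.1, 0.2]                    -- none
```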
The v0.5.6 headline still holds: Chapter 9 lands its ConvNeXt-T
worked example (84.94% val on Imagenette, paper-faithful recipe);
Chapter 10 gets a Data Augmentation section with a 9-row ViT
recipe ablation table — CutMix is the load-bearing knob at 9.5K
images, and stacking RandAugment + Random Erasing on top of it
hurts val accuracy. Bestiary gets paper-exact entries for VGG,
ResNet-50/101/152, WRN, and DenseNet, plus the "N new primitives"
claim reframed around the Ch 2-10 reader's toolbox (what's free)
rather than the codebase (what's already in Types.lean). Found
and fixed a long-standing eval-pipeline bug along the way:
centerCrop was running on already-224 val data, reading past
per-image bounds and making heavy-aug runs appear to collapse.
New LEAN_MLIR_EVAL_ONLY=1 mode re-evals saved checkpoints in
~5 sec each.
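One defensive shape for a center crop is to take the side length the data actually has and short-circuit when it is already at target size. A minimal single-channel, row-major Lean sketch; `centerCrop` here is hypothetical and is not the repository's implementation:

```lean
/-- Center-crop a `src × src` row-major image to `dst × dst`.
    `src` must be the side length the data actually has. -/
def centerCrop (img : Array Float) (src dst : Nat) : Option (Array Float) :=
  if img.size ≠ src * src ∨ dst > src then none  -- reject instead of reading out of bounds
  else if dst = src then some img                -- already at target size: identity
  else
    let off := (src - dst) / 2
    some <| Id.run do
      let mut out : Array Float := #[]
      for r in [0:dst] do
        for c in [0:dst] do
          out := out.push img[(r + off) * src + (c + off)]!
      return out

#eval centerCrop ((List.range 16).map Nat.toFloat).toArray 4 2  -- central 2×2: 5, 6, 9, 10
#eval (centerCrop #[1.0, 2.0, 3.0, 4.0] 2 4).isNone             -- true
```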
The v0.5.5 headline still holds: Swish/SiLU as a first-class
activation (forward + backward + proved swish_has_vjp_correct)
plus the independent-kernel comparator re-check covering 38
theorems via public *_has_vjp_correct wrappers, and Ch 2's
"Why VJPs, not Jacobians?" bridge + canonical-pdiv witness
explainer + three-pillar TikZ spine diagram.
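The "VJPs, not Jacobians" point is visible on a single dense layer y = W x + b: the backward pass needs only x̄ = Wᵀȳ, W̄ = ȳxᵀ, and b̄ = ȳ, never the full Jacobian (which for W alone has m·(m·n) entries). A hypothetical Lean `Float` sketch that also checks the adjoint identity ⟨ȳ, W v⟩ = ⟨Wᵀȳ, v⟩ numerically; the repo's proofs do this over ℝ with Mathlib instead:

```lean
abbrev Vec := Array Float
abbrev Mat := Array (Array Float)   -- row-major: m rows, each of length n

def dot (a b : Vec) : Float := (a.zip b).foldl (fun acc (p, q) => acc + p * q) 0.0

/-- Forward: y = W x + bias. -/
def denseFwd (W : Mat) (bias x : Vec) : Vec :=
  (W.zip bias).map fun (row, bi) => dot row x + bi

/-- VJP w.r.t. x: x̄ = Wᵀ ȳ, accumulated as Σᵢ ȳᵢ · Wᵢ (no transpose, no Jacobian). -/
def denseVjpX (W : Mat) (ybar : Vec) (n : Nat) : Vec :=
  (W.zip ybar).foldl (fun acc (row, g) => (acc.zip row).map fun (a, w) => a + g * w)
    (List.replicate n 0.0).toArray

/-- VJP w.r.t. W is the outer product ȳ xᵀ (and w.r.t. the bias, ȳ itself). -/
def denseVjpW (x ybar : Vec) : Mat := ybar.map fun g => x.map (g * ·)

def exW : Mat := #[#[1.0, 2.0, 3.0], #[4.0, 5.0, 6.0]]
#eval dot #[1.0, 2.0] (denseFwd exW #[0.0, 0.0] #[1.0, 0.0, -1.0])  -- ⟨ȳ, W v⟩ = -6
#eval dot (denseVjpX exW #[1.0, 2.0] 3) #[1.0, 0.0, -1.0]           -- ⟨Wᵀȳ, v⟩ = -6
```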
On top of that, a differential-test suite in
tests/vjp_oracle/ uses JAX's value_and_grad
as an oracle for the hand-derived VJPs in LeanMlir/Proofs/. Nine
test cases cover every axiom family — dense, conv, BN, maxPool,
residual (biPath), depthwise, SE (elementwise product), attention,
and the transformer block — each verified to 1–2 ULP of JAX autodiff.
The ULP-floor cross-backend agreement (Lean→IREE→GPU vs Lean→JAX→XLA
on both NVIDIA and AMD) established in v0.5.4 still holds; see
traces/CROSS_BACKEND_RESULTS.md
for the four-corner verification tables.
This project went through three implementations of the same idea — "Lean 4 as a specification language for deep learning" — each shedding more dependencies than the last.
Phase 1 — Pure Lean 4. mnist-lean4/: everything in Lean,
Float64 as the only datatype, hand-written gradients, C FFI to OpenBLAS /
hipBLAS for the matmuls. Worked end-to-end on MNIST through ResNet-34 but
performance was poor — every operation crossed the FFI boundary, no fusion,
no autodiff, no JIT.
Phase 2 — Lean → JAX. jax/: Lean as a metaprogramming layer
that emits idiomatic JAX Python (jax/Jax/Codegen.lean, ~2100 lines). The
generated script gets value_and_grad autodiff and XLA JIT for free, runs
on any JAX-supported device. Trades the pure-Lean story for a working stack
and real GPU performance. See jax/README.md for details.
Phase 3 — Lean → StableHLO → MLIR → device. (this README) No Python
runtime at all. Lean directly emits StableHLO MLIR, IREE compiles it to a
GPU flatbuffer, a thin C FFI loads and runs it. The pure-math version of
phase 2 — autodiff is done at codegen time in Lean (LeanMlir/MlirCodegen.lean,
~7500 lines), not at runtime by a framework. See RESULTS.md
for the per-architecture numbers.
The VJP correctness proofs live in LeanMlir/Proofs/ —
chapter-by-chapter, for tensor ops, MLP, CNN, residual, batch norm,
depthwise, SE, LayerNorm, and attention, up to whole-network backward passes
(ViT, ResNet, MobileNetV2, ConvNeXt, EfficientNet). What they establish: each
reference forward function, written in exact real arithmetic (ℝ), has a
backward pass equal to its Mathlib fderiv Jacobian-transpose — with zero
project axioms (#print axioms closes under the Lean-core triple alone).
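In conventional notation (ours; the Lean statements are phrased through Mathlib's fderiv and may differ in form), "backward pass equal to its fderiv Jacobian-transpose" and the chain rule that lets per-layer results compose read:

```math
\begin{aligned}
\mathrm{vjp}_f(x)(\bar y) = Df(x)^{\top}\bar y
\quad &\iff \quad
\big\langle \mathrm{vjp}_f(x)(\bar y),\, v \big\rangle = \big\langle \bar y,\, Df(x)\,v \big\rangle
\quad \forall v \in \mathbb{R}^n, \\
\mathrm{vjp}_{g \circ f}(x)(\bar z) &= \mathrm{vjp}_f(x)\big(\mathrm{vjp}_g(f(x))(\bar z)\big),
\end{aligned}
```

for $f : \mathbb{R}^n \to \mathbb{R}^m$ differentiable at $x$ and $g$ differentiable at $f(x)$.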
The whole-network results come in two forms, set by the architecture's activations:

- Unconditional: ViT, ConvNeXt, EfficientNet. These use only smooth ops (GELU / Swish / sigmoid, softmax, LayerNorm, convolution; no ReLU, no max-pool), so the VJP holds at every input, with LayerNorm/BatchNorm 0 < ε positivity as the only side conditions (vit_full_has_vjp, convnext_has_vjp, efficientnet_has_vjp : HasVJP …).
- Conditional + concretely instantiated: MLP, MNIST-CNN, ResNet, MobileNetV2. ReLU, ReLU6, and max-pool are genuinely non-differentiable at their kinks, so the generic whole-network VJP is stated at a smooth point (*_has_vjp_at, under per-site "off the kink" hypotheses). Each is then instantiated on a concrete (small, representative) network with every smoothness hypothesis discharged, giving a hypothesis-free correctness theorem (MlpConcrete, Spatial/Mini, CnnConcrete, MobileNetV2Concrete): proof that the kink-avoidance conditions are jointly satisfiable on the real forward, not vacuous.
Axiom closure on every one of these is a CI invariant
(tests/AuditAxioms.lean); the generic headline
theorems are additionally re-checked by the independent
tests/comparator/ kernel pass.
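The primitive behind such an axiom audit is Lean's `#print axioms`. A minimal sketch of what it checks, using a hypothetical stand-in theorem (the actual contents of tests/AuditAxioms.lean are not reproduced here): classical excluded middle already depends on the full Lean-core triple, and an audit fails if any further name, such as a project `axiom` or `sorryAx`, shows up.

```lean
-- Hypothetical stand-in for a headline theorem.
theorem demo_uses_classical (p : Prop) : p ∨ ¬p := Classical.em p

-- Should report exactly propext, Classical.choice, and Quot.sound.
#print axioms demo_uses_classical
```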
These proofs are about the reference ℝ definitions in Proofs/, not the
Float32 StableHLO the codegen emits — the two are written separately and no
Lean theorem currently links them. The connection is instead twofold. (1)
Structural: codegen and proofs were developed independently and arrived at
the same decomposition — every backward pass factors through the standalone
gradient of one new primitive per architecture (softmax for attention, the
spatial reductions for BN, the rank-1 collapse for SE), and everything else is
composition via the chain rule on tools from earlier chapters — and the
codegen cites the matching proof inline in the MLIR it generates. (2)
Numerical: finite-difference checks (LeanMlir/Proofs/check_jacobians.py)
and JAX value_and_grad oracles (tests/vjp_oracle/)
exercise the emitted formulas — including at the ReLU/MaxPool kinks, where the
codegen substitutes the standard subgradient convention. See the "Codegen
trust boundary" section of LeanMlir/Proofs/README.md
for the precise gap. Closing it formally — a forward-extraction lemma tying a
proven *Forward to the codegen's emitted graph — is open future work.
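The numerical leg reduces to comparing a hand-derived derivative against a finite difference. LeanMlir/Proofs/check_jacobians.py is the repository's actual check; the following is an independent Lean restatement of the idea for Swish, with hypothetical names:

```lean
def sigmoid (x : Float) : Float := 1.0 / (1.0 + Float.exp (-x))
def swish (x : Float) : Float := x * sigmoid x

/-- Hand-derived backward: d/dx [x σ(x)] = σ(x) (1 + x (1 − σ(x))). -/
def swishDeriv (x : Float) : Float :=
  let s := sigmoid x
  s * (1.0 + x * (1.0 - s))

/-- Central difference; truncation error is O(h²). -/
def centralDiff (f : Float → Float) (x h : Float) : Float :=
  (f (x + h) - f (x - h)) / (2.0 * h)

/-- Largest |analytic − numeric| over the sample points. -/
def maxGap (f f' : Float → Float) (pts : List Float) (h : Float) : Float :=
  pts.foldl (fun m x =>
    let gap := Float.abs (f' x - centralDiff f x h)
    if gap > m then gap else m) 0.0

-- Should be tiny (roughly h² scale); a wrong backward formula shows up as an O(1) gap.
#eval maxGap swish swishDeriv [-4.0, -1.0, -0.1, 0.0, 0.3, 2.0, 5.0] 1e-5
```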
All proofs are over exact reals (ℝ). The emitted MLIR and GPU execution are
Float32; iree-compile, the IREE runtime, and the FFI are trusted. Within
that boundary, the verification is tiered by dataset / backend:
Tier 1 — MNIST (linear, mlp, cnn): forward + backward bridged. The reference
forward and backward, rendered as StableHLO graphs, are proven faithful to the
Mathlib fderiv math (mlpFwdGraph_faithful, mlpBackGraph_faithful,
cnnFwdGraph_faithful, cnnBackGraph_faithful; for linear, also the param-grad
Jacobians wGrad/bGrad_is*Jacobian and sgdW/sgdB_descends_certified_grad).
All audited to the 3-axiom closure. Caveat: the train-step .mlir is currently
assembled from these proven op-graphs with a hand-written grad/SGD tail (see
linearTrainStepModuleV); folding that tail into the rendered AST so the whole
train-step module is render(provenGraph) is in progress. Tier 1 also now
carries the ℝ→Float32 bridge (below): forward, gradient, and SGD-step
rounding budgets for the linear/MLP nets, and for linear a proven descent
guarantee.
Tier 2 — CIFAR (cifar, cifar-bn): forward bridged, backward WIP.
cifarFwdGraph_faithful / cifarBnFwdGraph_faithful (plus op-level
bnBack_faithful) hold; the whole-net backward graph and the train step are
not yet rendered from a proof.
Tier 3 — Imagenette (ResNet-34, MobileNetV2, ConvNeXt, EfficientNet, ViT):
ℝ whole-net VJP proven; codegen bridge WIP. The whole-network VJP is proven
over ℝ (resnet34_has_vjp_at, vit_full_has_vjp, convnext_has_vjp,
efficientnet_has_vjp, mobilenetv2_has_vjp_at). The rendered-MLIR bridge is
forward-graph-only (resnet/mnv2/convnext) or op-level-only (efficientnet/vit),
and the GPU trainers behind the Imagenette numbers below run the unverified
MlirCodegen.lean path. No theorem links those proofs to that codegen yet.
Tier 4 — ImageNet-1k (phase-2 Lean→JAX bridge): scale baseline, gradients
not Lean-verified. Full 1000-class ImageNet runs use the phase-2 path
(jax/Jax/Codegen.lean, ~1100 lines: NetSpec → idiomatic JAX Python), where
JAX's value_and_grad computes the gradients and XLA does the compilation —
the Lean VJP proofs are not in the loop. The only proof-adjacent Lean artifact is
the shared NetSpec ADT (the same architecture spec whose phase-3 backward is
proven over ℝ); the emitter itself is unverified. This tier exists to (a)
establish scale baselines the verified-IREE codegen can't yet reach — ConvNeXt-T
75.93% / EfficientNet-B0 72.31% / ResNet-34 72.02% / MobileNetV2 68.33% / ViT-Tiny
65.64% top-1, full 50k val (jax/runs/*/RESULTS.md) — and (b) serve
as the differential-test oracle: tests/vjp_oracle/ uses
JAX value_and_grad as ground truth to cross-check the Tier 1–3 Lean-derived VJPs
to 1–2 ULP. So Tier 4 is the least-verified tier by gradient provenance but the
one that empirically anchors the others. Whether phase-3 verified codegen can reach
ImageNet scale is open.
All tier proofs are over exact reals; LeanMlir/Proofs/FloatBridge.lean +
SgdDescent.lean/SgdDescentLinear.lean/SgdDescentMlp.lean close the
rounding gap for the
Tier-1 nets, hypothesis-style (zero project axioms — a FloatModel is any
rounding operator with relative error u; binary32 instantiates it with
u = 2⁻²⁴ on the normal range, subnormals open). The chain, every link in
the 3-axiom audit:
- Forward (mlp_float_close_uniform): dot/dense budgets in the classical compounded form, valid for every summation association (IREE may reassociate reductions freely). ReLU is exact in float: the op that forces the off-the-kink hypotheses over ℝ is the free op here.
- Backward (mlp_{w2,w1,w0,b2,b1,b0}_step_float_close): every rounded SGD parameter entry lies within an explicit budget of θ − lr·(aᵢ·cⱼ), the same emitWeightGrad/emitBiasGrad entries that mlp_render_*_certified prove equal to the pdiv-Jacobian contractions. The ReLU masks need quantitative margins (ez < |zᵢ|: rounding must not flip a sign).
- Loss head (softmax_ce_cot_close): the rounded softmax−onehot cotangent vs the certified gradient, given an exp accuracy hypothesis (|fexp t − exp t| ≤ eexp·exp t; GPU exp has no IEEE spec, and eexp is the constant tests/vjp_oracle/ validates at 1–2 ULP).
- Descent (sgd_descends, linear_sgd_descends, mlp_{output,hidden,input}_sgd_descends): an η-accurate gradient step still decreases the loss, with the smoothness hypothesis proven, not assumed. The explicit constant is 2a²/(1−2aD) for the linear net; through the MLP's ReLU kinks it holds per weight layer under quantitative margins (the step's ℓ1 radius cannot flip a mask sign, so the sign pattern freezes along the segment): 2d₃w₂²a²/(1−2w₂aD) for the hidden layer and 2d₃d₂²w₁²w₂²a²/(1−2w₂d₂w₁aD) for the input layer; the output layer is the linear theorem at the hidden activation, margin-free. No Hessian anywhere (the same softmax ratio sandwich as the float budgets).
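For a feel of the sizes involved, the textbook compounded bound (1 + u)ⁿ − 1 and Higham's γₙ = n·u/(1 − n·u) can be evaluated directly at binary32's u = 2⁻²⁴. This is a sketch of the classical forms only; the FloatBridge.lean statements are hypothesis-style over any FloatModel, and their exact constants may differ:

```lean
/-- Unit roundoff u = 2⁻²⁴ for binary32 round-to-nearest (normal range only). -/
def u32 : Float := 1.0 / 16777216.0

/-- Compounded form (1 + u)ⁿ − 1: worst relative error after n roundings. -/
def compounded (n : Nat) (u : Float) : Float :=
  (List.replicate n (1.0 + u)).foldl (· * ·) 1.0 - 1.0

/-- γₙ = n·u / (1 − n·u), an upper bound on the compounded form while n·u < 1. -/
def gamma (n : Nat) (u : Float) : Float :=
  let nu := n.toFloat * u
  nu / (1.0 - nu)

/-- |fl(x·y) − x·y| ≤ γₙ · Σ |xᵢ yᵢ|, whichever order the n-term sum is evaluated in. -/
def dotBudget (x y : List Float) (u : Float) : Float :=
  gamma x.length u * (List.zipWith (fun a b => Float.abs (a * b)) x y).foldl (· + ·) 0.0

#eval (compounded 784 u32, gamma 784 u32)   -- both ≈ 4.67e-5 for a 784-wide dense row
```

Because the bound holds for any association, it covers whatever reduction tree IREE picks.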
Measured vs proven (scripts/margin_probe.py, an f32/f64 twin of the
97.8% GPU run; numeric capstones instantiated at the trained magnitudes
|W| ≤ 3/5):
| quantity | worst-case theorem | measured |
|---|---|---|
| logit drift | ≤ 5100 (mnist_mlp_float_budget) | 1.6·10⁻⁵ |
| cotangent | ≤ 21/1000 at δ=1/100 (mnist_cot_budget) | 2.2·10⁻⁶ |
| W₂ SGD step | ≤ 5/4 (mnist_w2_step_float_budget) | 7.5·10⁻⁹ |
| ReLU mask flips | 0 under margins | 0 / 29.5M |
The worst-case-vs-measured gap (up to ~10⁸) is the quantitative case for a-posteriori certificates past toy depth; the zero flip count says the margin hypotheses describe real training, not a technicality.
Not yet verified anywhere:
- the ~7500-line MlirCodegen.lean (zero theorems);
- outside Tier 1, the train-step text that iree-compile actually consumes;
- within the float bridge: subnormals (the model is relative-error-only),
  and descent for the bias columns and for all layers jointly (every
  parameter at once, where the logits are no longer affine in the moving
  parameters); the per-weight-layer constants are proven for linear + MLP;
- for the CNN, the conv-layer capstone assembly (planning/sgd_descent_cnn.md),
  although its new ingredients are proven: quantitative max-pool selection
  margins that freeze the argmax routing, pool ℓ1-contraction, and
  conv-kernel drift with the weight-sharing factor;
- any link from the Lean-side FloatModel to IREE's actual kernels beyond the
  empirical probe.
Concrete-instance honesty. The conditional capstones (MLP, MNIST-CNN, CIFAR,
MobileNetV2, ResNet-34) are instantiated to discharge their off-the-kink
hypotheses. MlpConcrete, Micro/Mini/Spatial (MNIST) and Tiny (CIFAR) are
live witnesses (non-constant forward, nonzero Jacobian). MobileNetV2Concrete,
CnnConcrete, and ResNet34Concrete are degenerate constant-output nets (zero
Jacobian) — they prove the hypothesis bundle is satisfiable but say nothing about a
realistic gradient; live witnesses for the deep/BN/ReLU6 nets are follow-up.
```
Lean NetSpec (~15 lines)
    │
    │  MlirCodegen.generateTrainStep
    ▼
StableHLO MLIR (500 KB - 2 MB of text, forward+loss+backward+Adam fused)
    │
    │  iree-compile (~10-15 min for ROCm gfx1100)
    ▼
VMFB flatbuffer (1.8-3 MB)
    │
    │  IREE runtime via libiree_ffi.so
    ▼
GPU execution (HIP/ROCm or CUDA)
```
The same Lean → MLIR pipeline handles every architecture. Adding a new
architecture means extending LeanMlir/MlirCodegen.lean with:
- forward emission for the new layer types
- VJP / backward emission
- FwdRec recording for backward intermediates
The training executable, FFI, and IREE runtime are unchanged.
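The three pieces above (a forward rule per layer, a VJP rule per layer, and a recorded forward that the backward reads) can be mimicked by a toy Float-valued evaluator. Everything here (`ToyLayer`, `forwardRec`, `backLayer`) is hypothetical; MlirCodegen.lean emits MLIR text rather than evaluating Floats, but the data flow is the same idea:

```lean
/-- Hypothetical toy layer set; the real `Layer` type lives in Types.lean. -/
inductive ToyLayer where
  | scale (w : Float)   -- y = w * x, elementwise
  | relu

/-- Forward rule per layer. -/
def applyLayer : ToyLayer → Array Float → Array Float
  | .scale w, x => x.map (w * ·)
  | .relu, x => x.map fun v => if v > 0.0 then v else 0.0

/-- Forward pass that records each layer's input (the role a FwdRec plays). -/
def forwardRec (net : List ToyLayer) (x : Array Float) : Array Float × List (Array Float) :=
  net.foldl (fun (y, tape) l => (applyLayer l y, tape ++ [y])) (x, [])

/-- VJP rule per layer, reading the recorded input `xin`. -/
def backLayer : ToyLayer → Array Float → Array Float → Array Float
  | .scale w, _, g => g.map (w * ·)
  | .relu, xin, g => (xin.zip g).map fun (v, gi) => if v > 0.0 then gi else 0.0

/-- Walk the layers last-to-first, pairing each with its recorded input. -/
def backward (net : List ToyLayer) (tape : List (Array Float)) (g : Array Float) : Array Float :=
  (net.zip tape).foldr (fun (l, xin) acc => backLayer l xin acc) g

def toyNet : List ToyLayer := [.scale 2.0, .relu, .scale (-3.0)]
#eval backward toyNet (forwardRec toyNet #[1.0, -1.0]).2 #[1.0, 1.0]  -- input cotangent: -6, 0
```

Adding a layer to this toy means one new constructor plus one case in each of `applyLayer` and `backLayer`; the driver functions stay unchanged, which mirrors the README's point about the trainer, FFI, and runtime.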
Phase 2 and Phase 3 share the same Lean NetSpec ADT but compile through
completely independent stacks (JAX/XLA vs IREE). Differential testing
confirms both stacks produce the same training dynamics on the same input,
for both MLP (670K params, 12 epochs) and CNN (1.7M params with conv+BN,
15 epochs):
| diff | MLP step 1 Δ | CNN step 1 Δ |
|---|---|---|
| phase 2 (JAX) vs phase 3 (IREE) | ~2e-7 | ~1e-5 to 1e-4 |
| phase 3 ROCm vs phase 3 CUDA | 0 | 0 |
| phase 2 CPU vs phase 2 CUDA | ~4e-6 | ~1e-4 |
MLP hits the float32 ULP floor because it's dense-only. CNN's noise
floor is looser by ~100× because each conv-BN layer does two
reductions over ~100k-element tensors and XLA's reduction trees differ
from IREE's — both pipelines do correct math, just with different
summation orders. Phase 3 ROCm ≡ Phase 3 CUDA is bit-identical at
step 1 on both networks. Reproducible in 5 minutes via
traces/CROSS_BACKEND_RESULTS.md.
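The mechanism behind "correct math, different summation orders" fits in three terms. In IEEE binary64 (Lean's `Float`) the spacing between doubles near 1e16 is 2, so the association order decides whether a 1.0 survives; reductions over ~100k float32 elements amplify the same effect:

```lean
-- Same three terms, two association orders.
def absorbFirst : Float := (1e16 + 1.0) - 1e16   -- 0: the 1.0 is rounded away first
def cancelFirst : Float := (1e16 - 1e16) + 1.0   -- 1: the big terms cancel first
#eval (absorbFirst, cancelFirst)
#eval absorbFirst == cancelFirst                 -- false
```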
A separate per-axiom differential test in
tests/vjp_oracle/ uses JAX's value_and_grad as
a correctness oracle for every hand-derived backward pass in
LeanMlir/Proofs/. Each test case is a minimal NetSpec exercising one
axiom in isolation; the oracle compares the step-2 loss (the first step
whose value depends on the backward pass) against the same step in the
phase-2 run, whose gradients come from JAX autodiff.
Nine cases, all green on mars (ROCm + CPU) and ares (CUDA):
| case | axiom | step 2 Δ |
|---|---|---|
| dense | dense_has_vjp + softmaxCE_grad | 2.7e-07 |
| dense-relu | relu_has_vjp + vjp_comp | 4.8e-07 |
| conv | conv2d_has_vjp + flatten_has_vjp | 2.2e-07 |
| convbn | convBn_has_vjp (BN-mode) | 2.2e-06 |
| conv-pool | maxPool_has_vjp (argmax tiebreaks) | 1.2e-04 |
| residual | biPath_has_vjp (additive fan-in) | 3.1e-07 |
| depthwise | depthwise-conv VJP via .invertedResidual | 1.1e-05 |
| mbconv | elemwiseProduct_has_vjp (SE gate) + Swish | 1.6e-06 |
| attention | patchEmbed + transformerBlock_has_vjp_mat + classifier | 1.8e-07 |
Run with tests/vjp_oracle/run.sh. Adding a new axiom means dropping
a minimal Lean spec under tests/vjp_oracle/phase{2,3}/ plus one
line in the lakefiles — see
tests/vjp_oracle/README.md.
The oracle also surfaced a real heInitParams bug (shape-peek
heuristic misfiring at patchEmbed + transformer-block boundaries) and
a JAX-ROCm crash on gfx1100 (filed as
ROCm/MIOpen#3955; repro
lives at upstream-issues/2026-04-rocm-miopen-conv-segv/).
Trained from scratch on a single AMD 7900 XTX (gfx1100), Adam, batch 32, cosine LR + 3-epoch warmup, label smoothing 0.1, weight decay 1e-4, random crop (256→224) + horizontal flip, running BN stats for eval.
| Model | Params | Val accuracy |
|---|---|---|
| ResNet-34 | 21.3M | 90.29% |
| ResNet-50 | 23.5M | 89.40% |
| EfficientNetV2-S | 38.2M | 88.50% |
| EfficientNet-B0 | 7.2M | 87.58% |
| MobileNetV2 | 2.2M | 87.09% |
| MobileNetV3-Large | 3.0M | 86.48% |
| MobileNetV4-Medium | 4.1M | 84.58% |
| ViT-Tiny | 5.5M | 71.70% |
Per-epoch eval histories and ablation tables in RESULTS.md.
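The schedule and target smoothing named in the recipe above follow standard formulas; a minimal Lean sketch of both, with hypothetical names. Whether the trainers count warmup per step or per epoch, or use a final-LR floor, is an assumption here (a 3-epoch warmup corresponds to `warmup` = 3 × steps per epoch):

```lean
def piF : Float := 3.141592653589793

/-- Linear warmup for `warmup` steps, then cosine decay from `peak` to 0 at `total`. -/
def cosineWithWarmup (peak : Float) (warmup total step : Nat) : Float :=
  if step < warmup then
    peak * (step + 1).toFloat / warmup.toFloat
  else
    let span := max 1 (total - warmup)      -- avoid 0/0 when total = warmup
    let t := min (step - warmup) span       -- hold at 0 after `total`
    0.5 * peak * (1.0 + Float.cos (piF * t.toFloat / span.toFloat))

/-- Label-smoothed target: (1 − ε)·onehot + ε/K on every class; still sums to 1. -/
def smoothTarget (numClasses label : Nat) (eps : Float) : Array Float :=
  (List.range numClasses).toArray.map fun i =>
    (if i == label then 1.0 - eps else 0.0) + eps / numClasses.toFloat

#eval (List.range 6).map (cosineWithWarmup 0.001 2 5)
#eval smoothTarget 10 3 0.1   -- ≈ 0.91 at index 3, ≈ 0.01 elsewhere
```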
```bash
curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh
```

You need the IREE runtime built for your GPU (CUDA or ROCm). The FFI shim
in ffi/ links against libiree_runtime_unified.a from the IREE build tree.
See IREE_BUILD.md for build instructions.
```bash
./download_mnist.sh        # MNIST (Ch 3-4 trainers)
./download_cifar.sh        # CIFAR-10 (Ch 5 trainers)
./download_imagenette.sh   # Imagenette 320px → preprocessed binary (Ch 6+)
```

```bash
lake build resnet34-train
```

This compiles the Lean trainer (which generates MLIR, drives IREE, and runs
the training loop). Other targets, in roughly book order:
mnist-mlp-train, mnist-cnn-train, cifar-cnn-train, cifar-bn-train,
resnet50-train, mobilenet-v2-train, mobilenet-v3-train,
mobilenet-v4-train, efficientnet-train, efficientnet-v2-train,
vgg-train, vit-tiny-train.
The first invocation generates and compiles the vmfbs (slow — IREE
compilation takes 10-15 min for ResNet-sized models). Subsequent
runs reuse the cached vmfbs unless you clear .lake/build/.
```bash
HIP_VISIBLE_DEVICES=0 IREE_BACKEND=rocm .lake/build/bin/resnet34-train
# Or via the included shell wrapper that sets the env vars correctly
bash run.sh resnet34                 # GPU 0, ROCm (defaults)
bash run.sh efficientnet-v2 1 cuda   # GPU 1, CUDA
```

For CUDA when invoking the binary directly, set IREE_BACKEND=cuda (the default) and use CUDA_VISIBLE_DEVICES instead of HIP_VISIBLE_DEVICES.
The same NetSpec type is used by all three phases. A spec is a list of
Layer values:
```lean
def resnet34 : NetSpec where
  name := "ResNet-34"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 64 7 2 .same,
    .maxPool 2 2,
    .residualBlock 64 64 3 1,
    .residualBlock 64 128 4 2,
    .residualBlock 128 256 6 2,
    .residualBlock 256 512 3 2,
    .globalAvgPool,
    .dense 512 10 .identity
  ]

def vitTiny : NetSpec where
  name := "ViT-Tiny"
  imageH := 224
  imageW := 224
  layers := [
    .patchEmbed 3 192 16 196,           -- (224/16)^2 = 196 patches
    .transformerEncoder 192 3 768 12,   -- 12 blocks, 3 heads, MLP dim 768
    .dense 192 10 .identity
  ]

def mobilenetV4Medium : NetSpec where
  name := "MobileNet V4-Medium"
  imageH := 224
  imageW := 224
  layers := [
    .convBn 3 32 3 2 .same,
    .fusedMbConv 32 48 4 3 2 1 false,
    .uib 48 80 4 2 3 5,      -- ExtraDW
    .uib 80 160 6 2 0 3,     -- IB (= MBConv)
    .uib 160 160 4 1 5 0,    -- ConvNeXt
    .uib 160 160 4 1 0 0,    -- FFN
    -- ... 11 more UIB blocks
    .convBn 256 1280 1 1 .same,
    .globalAvgPool,
    .dense 1280 10 .identity
  ]
```

```
lean4-mlir/
├── README.md -- this file (phase 3)
├── RESULTS.md -- per-architecture eval histories + ablations
├── IREE_BUILD.md -- how to build libiree_ffi.so from scratch
├── ROCM.md -- ROCm setup notes
├── BENCHMARK.md -- ROCm vs CUDA performance comparison
├── lakefile.lean -- Lake build config (libraries + ~30 execs)
│
├── LeanMlir.lean -- umbrella module
├── LeanMlir/
│ ├── MlirCodegen.lean -- ~7500 lines, NetSpec → StableHLO MLIR
│ ├── IreeRuntime.lean -- Lean ↔ libiree_ffi.so bindings
│ ├── F32Array.lean -- ByteArray-backed float32 helpers
│ ├── Spec.lean -- NetSpec / Layer / param-counting
│ ├── Types.lean -- core types (Layer, Activation, Padding, ...)
│ ├── MnistData.lean -- IDX file loader (older training paths)
│ └── Proofs/ -- VJP correctness proofs (~36,700 lines)
│ ├── Tensor.lean
│ ├── MLP.lean
│ ├── CNN.lean
│ ├── Residual.lean
│ ├── BatchNorm.lean
│ ├── Depthwise.lean
│ ├── SE.lean
│ ├── LayerNorm.lean
│ ├── Attention.lean
│ ├── FloatBridge.lean -- ℝ→Float32: rounding budgets (Tier 1)
│ ├── SgdDescent.lean -- inexact-gradient descent over ℝ
│ ├── SgdDescentLinear.lean -- Lipschitz constants, linear loss
│ ├── SgdDescentMlp.lean -- ...through the MLP's ReLU kinks (margins)
│ └── SgdDescentCnn.lean -- ...toward the CNN: pool selection margins
│
├── Main*Train.lean -- phase 3 trainers (one per architecture)
│ ├── MainResnetTrain.lean
│ ├── MainResnet50Train.lean
│ ├── MainMobilenetV2Train.lean
│ ├── MainMobilenetV3Train.lean
│ ├── MainMobilenetV4Train.lean
│ ├── MainEfficientNetTrain.lean
│ ├── MainEfficientNetV2Train.lean
│ ├── MainVitTrain.lean
│ ├── MainVggTrain.lean
│ ├── MainMnistMlpTrain.lean
│ ├── MainMnistCnnTrain.lean
│ ├── MainCifarCnnBnTrain.lean
│ ├── MainCifarCnnTrain.lean
│ └── MainAblation.lean
│
├── tests/ -- unit tests + smoke tests + differential tests
│ ├── Test*.lean -- runtime / FFI / codegen sanity tests
│ ├── BenchResnet.lean
│ ├── diff_traces.py -- JSONL trace diff helper
│ ├── cross_backend_mnist_mlp.sh
│ └── vjp_oracle/ -- JAX-autodiff oracle for hand-derived VJPs
│ ├── README.md
│ ├── run.sh
│ ├── diff_step.py
│ ├── phase3/ -- Lean→IREE test trainers
│ └── phase2/ -- (mirrored at jax/tests/vjp_oracle/phase2/)
│
├── upstream-issues/ -- isolated reproducers + backtraces for bugs
│ └── 2026-04-rocm-miopen-conv-segv/ -- ROCm/MIOpen#3955
│
├── ffi/
│ ├── iree_ffi.{c,h} -- IREE runtime wrapper
│ ├── iree_lean_ffi.c -- Lean FFI bindings
│ ├── f32_helpers.c -- data loading, He init, EMA, augmentation
│ └── libiree_ffi.so -- compiled shared library
│
├── jax/ -- phase 2 (Lean → JAX Python)
│ ├── README.md
│ ├── Jax.lean
│ ├── Jax/{Codegen,Runner}.lean
│ ├── Main*.lean -- 14 JAX-driven architecture specs
│ └── tests/vjp_oracle/phase2/ -- phase-2 mirror of oracle specs
│
├── mnist-lean4/ -- phase 1 (pure Lean 4 + C BLAS)
│
├── traces/ -- committed cross-backend training traces
│ ├── CROSS_BACKEND_RESULTS.md
│ ├── TRACE_FORMAT.md
│ └── mnist_{mlp,cnn}.*.jsonl
│
├── data/ -- downloaded + preprocessed datasets
└── run_*.sh -- shell wrappers for tmux env propagation
```
| Layer | Description |
|---|---|
| dense | Fully connected (with optional activation) |
| conv2d | Standard convolution |
| convBn | Conv + batch norm + ReLU/ReLU6/Swish/h-swish |
| residualBlock | BasicBlock (ResNet-18/34) |
| bottleneckBlock | Bottleneck (ResNet-50/101/152) |
| invertedResidual | Expand → depthwise → project + skip (MobileNetV2) |
| mbConv | + Squeeze-Excitation, Swish (EfficientNet) |
| mbConvV3 | + h-swish + h-sigmoid SE (MobileNetV3, exact math) |
| fusedMbConv | k×k regular conv replaces (1×1 expand + depthwise) (EfficientNetV2) |
| uib | Universal Inverted Bottleneck: pre-DW? + expand + post-DW? + project (MobileNetV4) |
| patchEmbed | Conv patch projection + CLS token + positional embedding (ViT) |
| transformerEncoder | LN → MHSA → + → LN → MLP → +, with exact tanh-form GELU |
| maxPool, globalAvgPool, flatten | Structural |
Activations supported with exact backward: ReLU, ReLU6, Swish, h-swish, h-sigmoid, GELU (tanh form). Layer-norm and batch-norm both have proper VJPs and (for BN) running statistics for eval.
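For reference, the closed-form derivatives behind three of these activations, as a hypothetical Lean `Float` sketch. The value chosen at the h-sigmoid / h-swish kinks x = ±3 is this sketch's own convention and is not taken from the codegen:

```lean
/-- h-sigmoid(x) = ReLU6(x + 3) / 6. -/
def hsigmoid (x : Float) : Float :=
  let r := x + 3.0
  (if r < 0.0 then 0.0 else if r > 6.0 then 6.0 else r) / 6.0

def hswish (x : Float) : Float := x * hsigmoid x

/-- 1/6 strictly inside (−3, 3), else 0 (kink value 0 is an assumed convention). -/
def hsigmoidDeriv (x : Float) : Float :=
  if -3.0 < x ∧ x < 3.0 then 1.0 / 6.0 else 0.0

/-- Product rule: (x·h(x))' = h(x) + x·h'(x), i.e. (2x + 3)/6 on (−3, 3). -/
def hswishDeriv (x : Float) : Float := hsigmoid x + x * hsigmoidDeriv x

def geluK : Float := Float.sqrt (2.0 / 3.141592653589793)

/-- Tanh-form GELU: 0.5·x·(1 + tanh(k(x + 0.044715·x³))). -/
def gelu (x : Float) : Float :=
  0.5 * x * (1.0 + Float.tanh (geluK * (x + 0.044715 * x * x * x)))

/-- Exact derivative: 0.5(1 + t) + 0.5·x·(1 − t²)·k·(1 + 3·0.044715·x²). -/
def geluDeriv (x : Float) : Float :=
  let t := Float.tanh (geluK * (x + 0.044715 * x * x * x))
  0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * geluK * (1.0 + 3.0 * 0.044715 * x * x)

#eval geluDeriv 0.0     -- 0.5
#eval hswishDeriv 1.0   -- ≈ 5/6
```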
Tested with Lean 4.30.0 / Lake 5.0.0, IREE built from source against ROCm 7.2.0 / gfx1100.
```bibtex
@software{koonce2026leanmlir,
  author  = {Brett Koonce},
  title   = {Verified Deep Learning with Lean 4: Formal Backpropagation from MLP to Attention, via MLIR},
  url     = {https://github.com/brettkoonce/lean4-mlir},
  doi     = {10.5281/zenodo.20402133},
  version = {0.6.0},
  year    = {2026},
}
```