A from-scratch PyTorch implementation of continuous network sparsification via learnable sigmoid gates. During training the network decides which of its own connections to prune, with no manual intervention. The trained model is served through a production FastAPI inference endpoint and containerized with Docker.
Input (3072)
│
▼
PrunableLinear(3072 → 512) ← each weight has a learnable gate ∈ [0, 1]
│ gate × weight
▼
PrunableLinear(512 → 256) ← gates near 0 are hard-masked to exactly 0 during eval
│ gate × weight
▼
PrunableLinear(256 → 10)
│
▼
CIFAR-10 class prediction
Result at λ=0.0001: 53.66% test accuracy with 89.36% of all connections permanently pruned.
Every weight in every layer has a corresponding learnable gate score gᵢ. The forward pass becomes:
output = (σ(gᵢ) × weight) @ input + bias
The total training loss is:
loss = CrossEntropy(logits, labels) + λ × Σᵢ σ(gᵢ)
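As a rough illustration of the loss above, the sketch below sums σ(gᵢ) over every gate-score tensor and adds it to the cross-entropy term. The function names `sparsity_penalty` and `total_loss` and the toy shapes are assumptions made for this example, not names taken from the repository.

```python
import torch
import torch.nn.functional as F


def sparsity_penalty(gate_scores):
    """Sum of sigmoid(g) over every gate-score tensor (the L1 norm of the gates)."""
    return sum(torch.sigmoid(g).sum() for g in gate_scores)


def total_loss(logits, labels, gate_scores, lam):
    """Cross-entropy plus lam times the summed gate activations."""
    return F.cross_entropy(logits, labels) + lam * sparsity_penalty(gate_scores)


# Toy usage with made-up shapes: 4 samples, 10 classes, two gate tensors.
torch.manual_seed(0)
logits = torch.randn(4, 10)
labels = torch.tensor([0, 3, 7, 9])
gate_scores = [torch.zeros(8, 5), torch.zeros(5, 10)]  # sigmoid(0) = 0.5 per gate
loss = total_loss(logits, labels, gate_scores, lam=1e-4)
```

Because the penalty is a plain sum, its size grows with the number of gates, which is one reason a small λ can already have a large effect on a network with about 1.7M gated weights.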
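The L1 penalty on σ(gᵢ) adds a gradient of λ·σ(gᵢ)(1 − σ(gᵢ)) to every gate score. That term is always positive, so gradient descent keeps pushing gate scores toward −∞ unless the classification loss pushes back. Once gᵢ is sufficiently negative, σ(gᵢ) ≈ 0 and that weight is effectively dead. The network learns which connections matter and eliminates the rest.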
During model.eval(), a hard threshold enforces strict binary masking: any gate below 1e-2 is clamped to exactly 0.0.
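A minimal sketch of such a layer is shown below, assuming one gate score per weight and a threshold of 1e-2 as described. The repository's actual `PrunableLinear` may be organized differently; the `init_score` argument and the `gates()` helper are assumptions made for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrunableLinear(nn.Module):
    """Linear layer whose weights are each scaled by a learnable sigmoid gate."""

    def __init__(self, in_features, out_features, threshold=1e-2, init_score=2.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # One gate score per weight; init_score is an assumed starting value.
        self.gate_scores = nn.Parameter(
            torch.full((out_features, in_features), init_score)
        )
        self.threshold = threshold

    def gates(self):
        gates = torch.sigmoid(self.gate_scores)
        if not self.training:
            # Eval mode: gates below the threshold become exactly 0.0.
            gates = torch.where(gates < self.threshold, torch.zeros_like(gates), gates)
        return gates

    def forward(self, x):
        return F.linear(x, self.gates() * self.linear.weight, self.linear.bias)


# Usage: push one output row's gate scores far negative, then switch to eval.
layer = PrunableLinear(3072, 512)
with torch.no_grad():
    layer.gate_scores[0].fill_(-10.0)  # sigmoid(-10) is far below the threshold
layer.eval()
out = layer(torch.randn(2, 3072))
```

In training mode the same gates stay small but nonzero, so gradients can still flow through them; only evaluation applies the hard mask.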
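Standard L1 weight regularization pushes weights toward zero, but plain gradient descent rarely lands them on exactly zero: updates tend to overshoot and hover around it, so a separate pruning step is still needed. Gating through a sigmoid separates the keep-or-drop decision from the weight's value. The penalty gradient on each gate score, λ·σ(gᵢ)(1 − σ(gᵢ)), is always positive and fades only once the gate has saturated near 0, where the evaluation-time threshold turns it into an exact zero. This drives genuine sparsity, with pruned connections contributing exactly nothing, rather than just weight shrinkage.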
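How strongly the sparsity penalty pulls on a single gate score can be checked directly. The sketch below uses only the standard library, with λ = 1 for readability, and compares the analytic derivative of λ·σ(g) against a finite-difference estimate. The helper names are assumptions made for this example.

```python
import math


def sigmoid(g):
    return 1.0 / (1.0 + math.exp(-g))


def penalty_grad(g, lam=1.0):
    """Analytic derivative of lam * sigmoid(g) with respect to the gate score g."""
    s = sigmoid(g)
    return lam * s * (1.0 - s)


def finite_difference(g, lam=1.0, eps=1e-6):
    """Central-difference estimate of the same derivative."""
    return lam * (sigmoid(g + eps) - sigmoid(g - eps)) / (2 * eps)


for g in (2.0, 0.0, -2.0, -5.0, -10.0):
    print(f"g={g:6.1f}  gate={sigmoid(g):.5f}  d(penalty)/dg={penalty_grad(g):.5f}")
```

The pull is always positive, so gradient descent keeps lowering g, but it peaks at g = 0 and weakens as the gate saturates. Gates therefore approach 0 without reaching it, which is consistent with the hard threshold applied at evaluation time.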
| Lambda (λ) | Test Accuracy | Sparsity |
|---|---|---|
| 0.0001 (Low) | 53.66% | 89.36% |
| 0.01 (Medium) | 10.00% | 99.97% |
| 0.5 (High) | 10.00% | 100.00% |
The sweet spot at λ=0.0001 demonstrates the core claim: the network sheds ~89% of its connections while retaining meaningful classification ability. Higher λ collapses the network entirely — the experiment honestly reports this boundary rather than cherry-picking.
Gate value distribution at high λ — bimodal polarization toward 0 and 1 shows the network making hard binary decisions about each connection.
During testing, λ=0.01 initially reported 51% accuracy at 99.98% sparsity — an implausible result for an MLP with only ~340 active connections out of 1.7M weights.
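The connection counts quoted here follow from the layer sizes in the architecture diagram. The short calculation below counts gated weights only (biases are not gated) and is a back-of-the-envelope check, not code from the repository.

```python
# Layer shapes from the architecture: 3072 -> 512 -> 256 -> 10.
layer_shapes = [(3072, 512), (512, 256), (256, 10)]
total_weights = sum(n_in * n_out for n_in, n_out in layer_shapes)


def active_connections(sparsity_percent):
    """Number of weights still active at a given sparsity percentage."""
    return round(total_weights * (1 - sparsity_percent / 100))


print(f"total gated weights: {total_weights:,}")
for sparsity in (89.36, 99.97, 99.98):
    print(f"{sparsity:.2f}% sparsity -> {active_connections(sparsity):,} active")
```

At 99.98% sparsity this leaves roughly 340 active weights out of about 1.7M, which is what made the 51% accuracy figure suspicious.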
The bug: The sparsity metric flagged gates < 0.01 as "pruned," but the forward pass still computed 0.009 × weight. Those residual products were small but not zero, and collectively they carried enough signal for 51% accuracy, creating a phantom result.
The fix: A hard mask enforced strictly inside model.eval() — any gate below threshold is set to exactly 0.0 before matrix multiplication. This eliminated the leak and revealed the honest result: at true 99.97% sparsity the network collapses to random-guess accuracy (10%), validating λ=0.0001 as the genuine operating point.
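The leak and the fix can be reproduced on a toy tensor. In the sketch below every gate sits at 0.009, just under the threshold, so the metric reports full sparsity while the unmasked product still yields nonzero outputs. The function names and shapes are assumptions chosen for this illustration.

```python
import torch

THRESHOLD = 1e-2  # same cut-off the sparsity metric uses


def reported_sparsity(gates, threshold=THRESHOLD):
    """Fraction of gates the metric counts as pruned."""
    return (gates < threshold).float().mean().item()


def hard_mask(gates, threshold=THRESHOLD):
    """Set every gate below the threshold to exactly 0.0."""
    return torch.where(gates < threshold, torch.zeros_like(gates), gates)


torch.manual_seed(0)
weight = torch.randn(4, 1000)
gates = torch.full((4, 1000), 0.009)  # every gate just under the threshold
x = torch.ones(1000)

soft_out = (gates * weight) @ x             # leaky path: 0.009 * weight still contributes
hard_out = (hard_mask(gates) * weight) @ x  # masked path: contributes exactly zero

print("reported sparsity:", reported_sparsity(gates))
print("soft output:", soft_out)
print("hard output:", hard_out)
```

Applying the same threshold in both the metric and the forward pass keeps the reported sparsity and the computed output in agreement.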
| Layer | Tool |
|---|---|
| Model | PyTorch 2.x, custom PrunableLinear layer |
| Training | CIFAR-10, CrossEntropy + L1 sparsity loss, automated lambda sweep |
| Inference API | FastAPI, async endpoints, Pydantic schema validation |
| Testing | pytest — shape stability + backward gradient flow |
| Containerization | Docker |
| Workflow | make install / test / train / serve |
git clone https://github.com/arrya5/Self-Pruning-Neural-Network-Dynamic-Sparsification-PyTorch-FastAPI
cd Self-Pruning-Neural-Network-Dynamic-Sparsification-PyTorch-FastAPI
make install # install dependencies
make test # run pytest suite
make train # full lambda sweep → prints accuracy/sparsity table
make serve # start FastAPI inference server at localhost:8000

# Health check
GET /health
→ {"status": "up", "framework": "PyTorch + FastAPI"}
# Predict
POST /api/v1/predict
{"image_tensor": [<3072 floats representing a 32×32 RGB image>]}
→ {"prediction_class": 3, "confidence_scores": [0.02, 0.01, ...]}

docker build -t self-pruning-nn .
docker run -p 8000:8000 self-pruning-nn

├── src/
│ ├── model.py # PrunableLinear layer + SelfPruningNet architecture
│ ├── train.py # training loop, lambda sweep, sparsity metrics
│ ├── config.py # hyperparameters (lambdas, hidden dims, threshold)
│ └── utils.py # gate extraction, visualization
├── tests/
│ └── test_model.py # shape stability + gradient flow tests
├── main.py # FastAPI inference server
├── Dockerfile
└── Makefile
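
For reference, the predict endpoint documented above could be called from Python roughly as follows. This is a standard-library sketch that assumes the server is running at localhost:8000 and accepts the `image_tensor` field exactly as shown. The helper names `build_payload` and `predict` are hypothetical, and any input normalization the model expects is not specified here.

```python
import json
import urllib.request

IMAGE_SIZE = 3 * 32 * 32  # 3072 values for one 32x32 RGB image


def build_payload(pixels):
    """Validate a flat pixel list and encode it as the JSON request body."""
    if len(pixels) != IMAGE_SIZE:
        raise ValueError(f"expected {IMAGE_SIZE} values, got {len(pixels)}")
    return json.dumps({"image_tensor": [float(p) for p in pixels]}).encode("utf-8")


def predict(pixels, base_url="http://localhost:8000"):
    """POST the image to /api/v1/predict and return the decoded JSON response."""
    request = urllib.request.Request(
        f"{base_url}/api/v1/predict",
        data=build_payload(pixels),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server started via `make serve`, `predict([0.0] * 3072)` should return a dictionary containing `prediction_class` and `confidence_scores`.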
