Self-Pruning Neural Network — Dynamic Sparsification in PyTorch

Python PyTorch FastAPI Docker Tests

A from-scratch PyTorch implementation of continuous network sparsification via learnable sigmoid gates — the network autonomously decides which of its own connections to prune during training, with no manual intervention. Served via a production FastAPI inference endpoint and containerized with Docker.

Input (3072)
    │
    ▼
PrunableLinear(3072 → 512)  ← each weight has a learnable gate ∈ [0, 1]
    │  gate × weight
    ▼
PrunableLinear(512 → 256)   ← gates near 0 are hard-masked to exactly 0 during eval
    │  gate × weight
    ▼
PrunableLinear(256 → 10)
    │
    ▼
CIFAR-10 class prediction

Result at λ=0.0001: 53.66% test accuracy with 89.36% of all connections permanently pruned.


How It Works

The Gate Mechanism

Every weight in every layer has a corresponding learnable gate score gᵢ. The forward pass becomes:

output = (σ(gᵢ) × weight) @ input + bias

The total training loss is:

loss = CrossEntropy(logits, labels) + λ × Σᵢ σ(gᵢ)

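The L1 penalty on σ(gᵢ) adds a gradient of λ·σ(gᵢ)(1 − σ(gᵢ)) with respect to each gate score. That gradient is always positive, so gradient descent keeps pushing gᵢ toward −∞ unless the cross-entropy term pushes back. Once gᵢ is sufficiently negative, σ(gᵢ) ≈ 0 and that weight is effectively dead. The network learns which connections matter and eliminates the rest.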
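The loss above can be sketched in a few lines. This is a minimal illustration, not the repository's train.py: the function names `sparsity_penalty` and `total_loss` and the argument `gate_score_tensors` (a list with one raw gate-score tensor per layer) are assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def sparsity_penalty(gate_score_tensors):
    """Sum of sigmoid(gate score) over every gate in every layer."""
    # Sigmoid outputs are positive, so this sum is the L1 norm of the gates.
    return sum(torch.sigmoid(s).sum() for s in gate_score_tensors)


def total_loss(logits, labels, gate_score_tensors, lam):
    """Cross-entropy plus lambda times the gate penalty (names are hypothetical)."""
    return F.cross_entropy(logits, labels) + lam * sparsity_penalty(gate_score_tensors)
```

Because the penalty is a plain sum rather than a mean, its size grows with the number of gates, which may be one reason small values of λ are already enough to prune aggressively.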

During model.eval(), a hard threshold enforces strict binary masking: any gate below 1e-2 is clamped to exactly 0.0.
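A gated layer of this kind might look roughly like the sketch below. The class name `PrunableLinear` and the 1e-2 threshold come from this README; the attribute names, the initial gate score of 2.0, and the weight initialization are assumptions for illustration and may differ from src/model.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrunableLinear(nn.Module):
    """Linear layer whose weights are scaled element-wise by sigmoid gates."""

    def __init__(self, in_features, out_features, threshold=1e-2, init_score=2.0):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        # One learnable gate score per weight; init_score is an assumed value
        # that starts every gate mostly open (sigmoid(2.0) is about 0.88).
        self.gate_scores = nn.Parameter(
            torch.full((out_features, in_features), init_score)
        )
        self.threshold = threshold
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def gates(self):
        g = torch.sigmoid(self.gate_scores)
        if not self.training:
            # Eval mode: gates below the threshold are set to exactly 0.
            g = torch.where(g < self.threshold, torch.zeros_like(g), g)
        return g

    def forward(self, x):
        # Soft gates during training, hard-masked gates during eval.
        return F.linear(x, self.gates() * self.weight, self.bias)
```

Keeping the mask inside `gates()` means the forward pass and any sparsity metric that reads `gates()` see the same values, so a gate that is counted as pruned cannot still contribute to the output.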

Why L1 on Sigmoid Gates (not on weights directly)?

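Standard L1 weight regularization pushes weights toward zero, but under plain gradient descent they tend to hover around zero rather than land on it exactly, and the penalty shrinks the weights that are kept as well as the ones that should go. Gating through a sigmoid separates the two roles: the weight stores the magnitude while the gate decides whether the connection exists. The penalty gradient λ·σ(gᵢ)(1 − σ(gᵢ)) does fade as a gate saturates, so σ(gᵢ) approaches 0 without ever reaching it; the hard threshold at evaluation time is what turns near-zero gates into genuine structural sparsity rather than just weight shrinkage.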
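To see how the penalty acts on a single gate score, its gradient can be evaluated with autograd. The helper `penalty_grad` is a hypothetical name used only for this sketch; it differentiates λ·σ(g) with respect to g, which analytically equals λ·σ(g)(1 − σ(g)).

```python
import torch


def penalty_grad(score, lam=1.0):
    """Gradient of lam * sigmoid(g) with respect to the gate score g."""
    g = torch.tensor(float(score), requires_grad=True)
    (lam * torch.sigmoid(g)).backward()
    return g.grad.item()


# The gradient is positive everywhere and largest at g = 0 (value 0.25 * lam);
# it shrinks as the gate score moves further from zero in either direction.
for s in (2.0, 0.0, -2.0, -6.0):
    print(s, penalty_grad(s))
```

The sign never flips, so the penalty always votes for a smaller gate, but the push weakens for strongly negative scores. That is consistent with gates settling at small non-zero values until a hard mask is applied.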


Experimental Results

| Lambda (λ)    | Test Accuracy | Sparsity |
|---------------|---------------|----------|
| 0.0001 (Low)  | 53.66%        | 89.36%   |
| 0.01 (Medium) | 10.00%        | 99.97%   |
| 0.5 (High)    | 10.00%        | 100.00%  |

The sweet spot at λ=0.0001 demonstrates the core claim: the network sheds ~89% of its connections while retaining meaningful classification ability. Higher λ collapses the network entirely — the experiment honestly reports this boundary rather than cherry-picking.
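The Sparsity column counts gates that fall below the pruning threshold. A minimal sketch of such a metric follows, assuming the same 1e-2 threshold used for the eval-time mask; the function name `sparsity` and its list-of-tensors argument are illustrative, not necessarily what src/train.py uses.

```python
import torch


def sparsity(gate_score_tensors, threshold=1e-2):
    """Fraction of gates whose sigmoid value is below the pruning threshold."""
    pruned = sum(
        (torch.sigmoid(s) < threshold).sum().item() for s in gate_score_tensors
    )
    total = sum(s.numel() for s in gate_score_tensors)
    return pruned / total
```

A metric like this only matches the model's real behavior when the forward pass zeroes the same gates it counts.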

Gate Distribution at λ=0.5

Gate value distribution at high λ — bimodal polarization toward 0 and 1 shows the network making hard binary decisions about each connection.


Key Engineering Discovery: The Soft-Pruning Illusion

During testing, λ=0.01 initially reported 51% accuracy at 99.98% sparsity — an implausible result for an MLP with only ~340 active connections out of 1.7M weights.
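The connection counts quoted above follow directly from the three layer shapes in the architecture diagram. The arithmetic, counting gated weights only and ignoring biases:

```python
# (in_features, out_features) for each PrunableLinear layer
layer_shapes = [(3072, 512), (512, 256), (256, 10)]

total_weights = sum(i * o for i, o in layer_shapes)

# Connections left active at the reported 99.98% sparsity
active = round(total_weights * (1 - 0.9998))

print(total_weights, active)  # 1706496 341
```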

The bug: The sparsity metric flagged gates < 0.01 as "pruned," but the forward pass still multiplied each weight by its soft gate, for example 0.009 × weight. Those residual products were small but not zero, and together they carried enough signal for 51% accuracy, creating a phantom result.

The fix: A hard mask enforced strictly inside model.eval() — any gate below threshold is set to exactly 0.0 before matrix multiplication. This eliminated the leak and revealed the honest result: at true 99.97% sparsity the network collapses to random-guess accuracy (10%), validating λ=0.0001 as the genuine operating point.
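The leak is easy to reproduce on a toy layer. The sketch below uses random stand-in weights and sets every gate to 0.009, so all of them count as pruned; the tensor shapes mirror a 3072-input, 10-output layer but are otherwise arbitrary, and nothing here is taken from the repository's code.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(10, 3072)          # stand-in weights, not trained values
x = torch.randn(3072)                   # stand-in input
gates = torch.full_like(weight, 0.009)  # every gate is below the 0.01 threshold

# Soft gating: the "pruned" connections still contribute to the output.
soft_out = (gates * weight) @ x

# Hard mask: gates below the threshold are replaced by exactly 0 first.
hard_gates = torch.where(gates < 0.01, torch.zeros_like(gates), gates)
hard_out = (hard_gates * weight) @ x

print(soft_out.abs().max().item())            # non-zero: signal leaks through
print(torch.count_nonzero(hard_out).item())   # 0: nothing leaks
```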


Stack

| Layer            | Tool                                                              |
|------------------|-------------------------------------------------------------------|
| Model            | PyTorch 2.x, custom PrunableLinear layer                          |
| Training         | CIFAR-10, CrossEntropy + L1 sparsity loss, automated lambda sweep |
| Inference API    | FastAPI, async endpoints, Pydantic schema validation              |
| Testing          | pytest: shape stability + backward gradient flow                  |
| Containerization | Docker                                                            |
| Workflow         | make install / test / train / serve                               |

Quickstart

git clone https://github.com/arrya5/Self-Pruning-Neural-Network-Dynamic-Sparsification-PyTorch-FastAPI
cd Self-Pruning-Neural-Network-Dynamic-Sparsification-PyTorch-FastAPI

make install   # install dependencies
make test      # run pytest suite
make train     # full lambda sweep → prints accuracy/sparsity table
make serve     # start FastAPI inference server at localhost:8000

Inference API

# Health check
GET /health
→ {"status": "up", "framework": "PyTorch + FastAPI"}

# Predict
POST /api/v1/predict
{"image_tensor": [<3072 floats representing a 32×32 RGB image>]}
→ {"prediction_class": 3, "confidence_scores": [0.02, 0.01, ...]}
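A client call can be sketched with the standard library alone. The endpoint path and the `image_tensor` field come from the example above; the helper name `build_predict_request`, the default base URL, and the client-side length check are assumptions, and the README does not say how pixels should be ordered or normalized.

```python
import json
import urllib.request


def build_predict_request(pixels, base_url="http://localhost:8000"):
    """Build a POST request for /api/v1/predict from 3072 pixel values."""
    if len(pixels) != 3072:
        raise ValueError("expected 3072 floats (32 * 32 * 3)")
    body = json.dumps({"image_tensor": [float(p) for p in pixels]}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/api/v1/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# With the server running (make serve), the request could be sent like this:
# with urllib.request.urlopen(build_predict_request([0.0] * 3072)) as resp:
#     print(json.load(resp))
```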

Docker

docker build -t self-pruning-nn .
docker run -p 8000:8000 self-pruning-nn

Project Structure

├── src/
│   ├── model.py      # PrunableLinear layer + SelfPruningNet architecture
│   ├── train.py      # training loop, lambda sweep, sparsity metrics
│   ├── config.py     # hyperparameters (lambdas, hidden dims, threshold)
│   └── utils.py      # gate extraction, visualization
├── tests/
│   └── test_model.py # shape stability + gradient flow tests
├── main.py           # FastAPI inference server
├── Dockerfile
└── Makefile
