feat(sae-steering): SAE full reconstruction runtime (phase 4 merge)#165

Open
RhizoNymph wants to merge 9 commits into feat/sae-steering from feat/sae-steering-phase-4-runtime

@RhizoNymph

Owner

Summary

Brings the previously-shipped feat/sae-steering-phase-4-stage-{2,3,3b,4,5} branches into the delta-path mainline so the full-reconstruction (Anthropic Scaling Monosemanticity / Golden Gate Claude) variant runs end-to-end inside the worker. Stages were authored independently and never landed because the delta-path plumbing on feat/sae-steering had moved underneath them; this PR resolves the conflicts against the current delta runtime.

  • Stage 2 — per-(layer, hook) full encoder / decoder buffers, per-row clamp tables, shared sae_recon_index, and apply_layer_sae_full_reconstruction shim. Registered as torch.ops.vllm.apply_sae_full_reconstruction so torch.compile treats it as an opaque splitting point.
  • Stage 3 / 3b — SAEFullReconstructionSpec request type with hash plumbing folded into hash_steering_config (distinct domain separator from the delta block), SAEFullReconstructionManager parallel to SAEClampManager, worker-mixin lifecycle (admission, prefill→decode transition, release, atomic register-and-attach with rollback).
  • Stage 4 — compaction-based CUDA path that runs the encoder / clamp / decoder math only on tokens whose recon_mask is True. cuBLAS-backed matmul on the dense subset; warmup wired into buffer attach so the first-call JIT happens outside any captured forward.
  • Stage 5 — Gemma Scope real-weights e2e harness exercising both pure-reconstruction-error-only and clamped-feature paths against a single SAE site of Gemma 2-2B (CUDA + HF token gated).
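
For reviewers unfamiliar with the Stage 4 compaction idea, here is a minimal plain-Python sketch of it, not the CUDA kernel in this PR. It gathers the tokens whose recon_mask is True into a dense subset, runs encoder / clamp / decoder only on that subset, and scatters the results back. The function names, the ReLU encoder, the weight layouts, and the `clamps` dict (feature index to forced value) are all assumptions made for illustration.

```python
from typing import Dict, List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(mat: Matrix, vec: Sequence[float]) -> Vector:
    """Return mat @ vec for a row-major matrix."""
    return [sum(w * v for w, v in zip(row, vec)) for row in mat]


def sae_reconstruct(x: Vector, w_enc: Matrix, b_enc: Vector,
                    w_dec: Matrix, b_dec: Vector,
                    clamps: Dict[int, float]) -> Vector:
    """Encode one token, overwrite clamped features, then decode."""
    # Assumption: a plain ReLU encoder; the real SAE may use another activation.
    feats = [max(0.0, p + b) for p, b in zip(matvec(w_enc, x), b_enc)]
    # Hypothetical clamp table: feature index -> value forced onto that feature.
    for feature_idx, value in clamps.items():
        feats[feature_idx] = value
    return [o + b for o, b in zip(matvec(w_dec, feats), b_dec)]


def apply_full_reconstruction_sketch(hidden: Matrix, recon_mask: List[bool],
                                     w_enc: Matrix, b_enc: Vector,
                                     w_dec: Matrix, b_dec: Vector,
                                     clamps: Dict[int, float]) -> Matrix:
    """Run the SAE only on masked tokens; leave every other token untouched."""
    # Compaction step: indices of the tokens that opted into reconstruction.
    selected = [i for i, keep in enumerate(recon_mask) if keep]
    out = [row[:] for row in hidden]
    if not selected:
        return out  # nothing to do, so no SAE math runs at all
    # Dense subset: the only rows that go through encoder / clamp / decoder.
    dense = [hidden[i] for i in selected]
    recon = [sae_reconstruct(x, w_enc, b_enc, w_dec, b_dec, clamps)
             for x in dense]
    # Scatter the reconstructed rows back to their original positions.
    for i, row in zip(selected, recon):
        out[i] = row
    return out
```

The point of the gather / scatter split is that the cost of the two matmuls scales with the number of masked tokens rather than with the batch size, which is presumably why the real path hands the dense subset to cuBLAS.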

apply_layer_steering now composes additive → SAE delta → SAE full reconstruction in that order, each behind an independent static hasattr check so the disabled-mode forward still emits zero steering kernels.
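
A rough sketch of that composition order, assuming hypothetical attribute names (`additive_vector`, `sae_delta`, `sae_full_recon_fn`) that stand in for whatever buffers the real layer carries. Each stage is guarded by its own hasattr check, so a layer with none of the attributes falls straight through without doing any work.

```python
from types import SimpleNamespace
from typing import List


def apply_layer_steering_sketch(layer: object, hidden: List[float],
                                trace: List[str]) -> List[float]:
    """Compose additive -> SAE delta -> SAE full reconstruction, in order."""
    # Each guard is independent: enabling one stage never requires another.
    if hasattr(layer, "additive_vector"):  # hypothetical attribute name
        hidden = [h + a for h, a in zip(hidden, layer.additive_vector)]
        trace.append("additive")
    if hasattr(layer, "sae_delta"):  # hypothetical attribute name
        hidden = [h + d for h, d in zip(hidden, layer.sae_delta)]
        trace.append("sae_delta")
    if hasattr(layer, "sae_full_recon_fn"):  # hypothetical attribute name
        hidden = layer.sae_full_recon_fn(hidden)
        trace.append("sae_full_reconstruction")
    return hidden


# Disabled mode: no steering attributes, so no stage runs.
bare_trace: List[str] = []
bare_out = apply_layer_steering_sketch(SimpleNamespace(), [1.0, 2.0], bare_trace)

# All three stages attached: they run in the documented order.
full_layer = SimpleNamespace(
    additive_vector=[1.0, 1.0],
    sae_delta=[0.5, 0.5],
    sae_full_recon_fn=lambda h: [2.0 * v for v in h],
)
full_trace: List[str] = []
full_out = apply_layer_steering_sketch(full_layer, [1.0, 2.0], full_trace)
```

In the real runtime the checks are described as static, so the disabled branches should cost nothing in the compiled forward; the sketch only mirrors the control flow, not the compilation behaviour.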

Sibling PR

A follow-up adds a global SAE clamp tier as an overlay on the delta path's row-0 sentinel. That PR is independent of this one but expects this PR to land first for the docs cross-references to make sense.

Files

20 files changed, 5666 insertions, 149 deletions. New runtime modules:

  • vllm/model_executor/layers/sae_full_reconstruction.py — public op, custom-op registration, layer-hook shim, per-(layer, hook) buffer attach/detach.
  • vllm/model_executor/layers/sae_full_reconstruction_kernel.py — compaction-based CUDA path + warmup helper.
  • vllm/v1/worker/sae_full_reconstruction_manager.py — per-request row allocator.
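
To illustrate what a per-request row allocator with atomic register-and-attach could look like, here is a small sketch. It is not the SAEFullReconstructionManager code: the class name, method names, the `attach` callback, and the choice to reserve row 0 as a "no reconstruction" sentinel (borrowed from the delta path's row-0 sentinel) are assumptions.

```python
from typing import Callable, Dict, List


class RowAllocatorSketch:
    """Hands out one table row per request and rolls back on attach failure."""

    def __init__(self, num_rows: int) -> None:
        # Assumption: row 0 is a sentinel and is never handed out.
        self._free: List[int] = list(range(num_rows - 1, 0, -1))
        self._rows: Dict[str, int] = {}

    def register_and_attach(self, request_id: str,
                            attach: Callable[[int], None]) -> int:
        """Allocate a row and attach buffers; undo the allocation on failure."""
        if request_id in self._rows:
            raise ValueError(f"request {request_id!r} is already registered")
        if not self._free:
            raise RuntimeError("no free reconstruction rows")
        row = self._free.pop()
        try:
            attach(row)  # hypothetical hook that writes the per-row buffers
        except Exception:
            # Rollback: the row returns to the pool, the request is not recorded.
            self._free.append(row)
            raise
        self._rows[request_id] = row
        return row

    def release(self, request_id: str) -> None:
        """Return the request's row to the free pool."""
        self._free.append(self._rows.pop(request_id))

    def row_of(self, request_id: str) -> int:
        """Row for a request, or the row-0 sentinel if it has none."""
        return self._rows.get(request_id, 0)
```

Recording the request only after `attach` succeeds means a failed admission leaves both the free pool and the request map exactly as they were, which is the property the "atomic register-and-attach with rollback" wording suggests.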

Test plan

  • pytest tests/v1/test_sae_full_reconstruction_types.py
  • pytest tests/v1/worker/test_sae_full_reconstruction_manager.py
  • pytest tests/v1/worker/test_sae_full_recon_mixin_integration.py
  • pytest tests/model_executor/layers/test_sae_full_reconstruction_*
  • pytest tests/models/language/generation/test_sae_full_reconstruction_real_weights.py (CUDA + HF gated)
  • Full SAE suite (delta + full-recon + populator + manager + mixin) — 458 passing.

…ntime

Brings the previously-shipped phase-4-stage-{2,3,3b,4,5} branches into
the delta-path mainline so the full-reconstruction (Anthropic Scaling
Monosemanticity / Golden Gate Claude) variant runs end-to-end inside the
worker.  Stages were authored independently and never landed because
the delta-path plumbing on feat/sae-steering had moved underneath them;
this commit resolves the conflicts against the current delta runtime.

Phase 4 stage 2: per-(layer, hook) full encoder/decoder buffers + per-row
clamp tables + shared sae_recon_index + apply_layer_sae_full_reconstruction
shim; registered as torch.ops.vllm.apply_sae_full_reconstruction so
torch.compile treats it as an opaque splitting point.

Phase 4 stage 3 / 3b: SAEFullReconstructionSpec request type with hash
plumbing folded into hash_steering_config (distinct domain separator
from the delta block), SAEFullReconstructionManager parallel to
SAEClampManager, worker-mixin lifecycle (admission, prefill->decode
transition, release, atomic register-and-attach with rollback).

Phase 4 stage 4: compaction-based CUDA path that runs the encoder /
clamp / decoder math only on tokens whose recon_mask is True; cuBLAS-
backed matmul on the dense subset, warmup wired into buffer attach.

Phase 4 stage 5: Gemma Scope real-weights e2e harness exercising both
pure-reconstruction-error-only and clamped-feature paths against a
single SAE site of Gemma 2-2B (CUDA + HF token gated).

apply_layer_steering now composes additive -> SAE delta -> SAE full
reconstruction in that order, each behind an independent static hasattr
check so the disabled-mode forward still emits zero steering kernels.

458 SAE tests pass.

Co-authored-by: Claude