feat(sae-steering): SAE full reconstruction runtime (phase 4 merge)#165
Open
RhizoNymph wants to merge 9 commits into
Open
feat(sae-steering): SAE full reconstruction runtime (phase 4 merge)#165RhizoNymph wants to merge 9 commits into
RhizoNymph wants to merge 9 commits into
Conversation
…hared index validation
…cuit, buffer registration
…ion, named-resolve cache
…oc to runtime contract
…ntime
Brings the previously-shipped phase-4-stage-{2,3,3b,4,5} branches into
the delta-path mainline so the full-reconstruction (Anthropic Scaling
Monosemanticity / Golden Gate Claude) variant runs end-to-end inside the
worker. Stages were authored independently and never landed because
the delta-path plumbing on feat/sae-steering had moved underneath them;
this commit resolves the conflicts against the current delta runtime.
Phase 4 stage 2: per-(layer, hook) full encoder/decoder buffers + per-row
clamp tables + shared sae_recon_index + apply_layer_sae_full_reconstruction
shim; registered as torch.ops.vllm.apply_sae_full_reconstruction so
torch.compile treats it as an opaque splitting point.
Phase 4 stage 3 / 3b: SAEFullReconstructionSpec request type with hash
plumbing folded into hash_steering_config (distinct domain separator
from the delta block), SAEFullReconstructionManager parallel to
SAEClampManager, worker-mixin lifecycle (admission, prefill->decode
transition, release, atomic register-and-attach with rollback).
Phase 4 stage 4: compaction-based CUDA path that runs the encoder /
clamp / decoder math only on tokens whose recon_mask is True; cuBLAS-
backed matmul on the dense subset, warmup wired into buffer attach.
Phase 4 stage 5: Gemma Scope real-weights e2e harness exercising both
pure-reconstruction-error-only and clamped-feature paths against a
single SAE site of Gemma 2-2B (CUDA + HF token gated).
apply_layer_steering now composes additive -> SAE delta -> SAE full
reconstruction in that order, each behind an independent static hasattr
check so the disabled-mode forward still emits zero steering kernels.
458 SAE tests pass.
Co-authored-by: Claude
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Brings the previously-shipped
feat/sae-steering-phase-4-stage-{2,3,3b,4,5}branches into the delta-path mainline so the full-reconstruction (Anthropic Scaling Monosemanticity / Golden Gate Claude) variant runs end-to-end inside the worker. Stages were authored independently and never landed because the delta-path plumbing onfeat/sae-steeringhad moved underneath them; this PR resolves the conflicts against the current delta runtime.sae_recon_index, andapply_layer_sae_full_reconstructionshim. Registered astorch.ops.vllm.apply_sae_full_reconstructionsotorch.compiletreats it as an opaque splitting point.SAEFullReconstructionSpecrequest type with hash plumbing folded intohash_steering_config(distinct domain separator from the delta block),SAEFullReconstructionManagerparallel toSAEClampManager, worker-mixin lifecycle (admission, prefill→decode transition, release, atomic register-and-attach with rollback).recon_maskisTrue. cuBLAS-backed matmul on the dense subset; warmup wired into buffer attach so the first-call JIT happens outside any captured forward.apply_layer_steeringnow composes additive → SAE delta → SAE full reconstruction in that order, each behind an independent statichasattrcheck so the disabled-mode forward still emits zero steering kernels.Sibling PR
A follow-up adds a global SAE clamp tier as an overlay on the delta path's row-0 sentinel. That PR is independent of this one but expects this PR to land first for the docs cross-references to make sense.
Files
20 files changed, 5666 insertions, 149 deletions. New runtime modules:
vllm/model_executor/layers/sae_full_reconstruction.py— public op, custom-op registration, layer-hook shim, per-(layer, hook) buffer attach/detach.vllm/model_executor/layers/sae_full_reconstruction_kernel.py— compaction-based CUDA path + warmup helper.vllm/v1/worker/sae_full_reconstruction_manager.py— per-request row allocator.Test plan
pytest tests/v1/test_sae_full_reconstruction_types.pypytest tests/v1/worker/test_sae_full_reconstruction_manager.pypytest tests/v1/worker/test_sae_full_recon_mixin_integration.pypytest tests/model_executor/layers/test_sae_full_reconstruction_*pytest tests/models/language/generation/test_sae_full_reconstruction_real_weights.py(CUDA + HF gated)