From 7a32e3256f4fb3fe9f748ccf3b70db891f71e3ee Mon Sep 17 00:00:00 2001 From: Mike Jensen Date: Wed, 25 Mar 2026 16:14:59 -0600 Subject: [PATCH] fix: Accurate session count to avoid constant upward drift The register handler incremented the session counter before `SetIDPublicKey` could fail, leaking +1 on duplicate IDs or bad keys. Now only incremented after successful registration. The deregister handler decremented the counter before validating the request, leaking -1 on malformed or unauthorized requests. Removed the explicit decrement entirely (handled in logic below). Cache eviction and TTL expiry silently removed sessions without decrementing the counter, causing monotonic growth. Unified all session decrements into a single cache removal callback that fires on deregistration, eviction, and cache close, filtering to only count entries with a SecretKey (true client sessions). session_total was added as a metric so that even short lived sessions can be viewed in the metrics. --- cmd/interactsh-server/main.go | 8 +++- pkg/server/http_server.go | 6 +-- pkg/server/http_server_test.go | 81 ++++++++++++++++++++++++++++++++++ pkg/server/metrics.go | 23 +++++----- pkg/storage/option.go | 3 ++ pkg/storage/roundtrip_test.go | 71 ++++++++++++++++++++++++++++- pkg/storage/storagedb.go | 19 +++++--- 7 files changed, 187 insertions(+), 24 deletions(-) diff --git a/cmd/interactsh-server/main.go b/cmd/interactsh-server/main.go index e7344f55..eff168df 100644 --- a/cmd/interactsh-server/main.go +++ b/cmd/interactsh-server/main.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" _ "net/http/pprof" @@ -261,6 +262,11 @@ func main() { } } + serverOptions.Stats = &server.Metrics{} + storeOptions.OnRemoval = func() { + atomic.AddInt64(&serverOptions.Stats.Sessions, -1) + } + var err error store, err = storage.New(&storeOptions) if err != nil { @@ -273,8 +279,6 @@ func main() { _ = serverOptions.Storage.SetID(serverOptions.Token) } - serverOptions.Stats = &server.Metrics{} - // If root-tld is enabled create a singleton unencrypted record in the store if serverOptions.RootTLD { for _, domain := range serverOptions.Domains { diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index 28757b62..1acc5b23 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -388,13 +388,13 @@ func (h *HTTPServer) registerHandler(w http.ResponseWriter, req *http.Request) { return } - atomic.AddInt64(&h.options.Stats.Sessions, 1) - if err := h.options.Storage.SetIDPublicKey(r.CorrelationID, r.SecretKey, r.PublicKey); err != nil { gologger.Warning().Msgf("Could not set id and public key for %s: %s\n", r.CorrelationID, err) jsonError(w, fmt.Sprintf("could not set id and public key: %s", err), http.StatusBadRequest) return } + atomic.AddInt64(&h.options.Stats.Sessions, 1) + atomic.AddInt64(&h.options.Stats.SessionsTotal, 1) jsonMsg(w, "registration successful", http.StatusOK) gologger.Debug().Msgf("Registered correlationID %s for key\n", r.CorrelationID) } @@ -409,8 +409,6 @@ type DeregisterRequest struct { // deregisterHandler is a handler for client deregister requests func (h *HTTPServer) deregisterHandler(w http.ResponseWriter, req *http.Request) { - atomic.AddInt64(&h.options.Stats.Sessions, -1) - r := &DeregisterRequest{} if err := jsoniter.NewDecoder(req.Body).Decode(r); err != nil { gologger.Warning().Msgf("Could not decode json body: %s\n", err) diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go index d01749f1..f8311f98 100644 --- a/pkg/server/http_server_test.go +++ b/pkg/server/http_server_test.go @@ -1,13 +1,25 @@ package server import ( + "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/pem" "io" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" "time" + "github.com/google/uuid" + jsoniter "github.com/json-iterator/go" + "github.com/projectdiscovery/interactsh/pkg/storage" + "github.com/rs/xid" "github.com/stretchr/testify/require" ) @@ -70,3 +82,72 @@ func TestWriteResponseFromDynamicRequest(t *testing.T) { require.Equal(t, resp.Header.Get("Test"), "Another", "could not get correct result") }) } + +func TestSessionTotalMetric(t *testing.T) { + stats := &Metrics{} + removed := make(chan struct{}) + closeOnce := sync.Once{} + + store, err := storage.New(&storage.Options{ + EvictionTTL: 5 * time.Minute, + OnRemoval: func() { + atomic.AddInt64(&stats.Sessions, -1) + closeOnce.Do(func() { close(removed) }) + }, + }) + require.NoError(t, err) + defer store.Close() + + h := &HTTPServer{ + options: &Options{ + Storage: store, + Stats: stats, + }, + } + + // Generate a client key pair and registration payload. + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + pubPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes}) + pubB64 := base64.StdEncoding.EncodeToString(pubPem) + + correlationID := xid.New().String() + secretKey := uuid.New().String() + + // --- Register --- + regBody, err := jsoniter.Marshal(&RegisterRequest{ + PublicKey: pubB64, + SecretKey: secretKey, + CorrelationID: correlationID, + }) + require.NoError(t, err) + req := httptest.NewRequest("POST", "/register", bytes.NewReader(regBody)) + w := httptest.NewRecorder() + h.registerHandler(w, req) + require.Equal(t, http.StatusOK, w.Code) + + require.Equal(t, int64(1), atomic.LoadInt64(&stats.Sessions), "sessions should be 1 after register") + require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should be 1 after register") + + // --- Deregister --- + deregBody, err := jsoniter.Marshal(&DeregisterRequest{ + SecretKey: secretKey, + CorrelationID: correlationID, + }) + require.NoError(t, err) + req = httptest.NewRequest("POST", "/deregister", bytes.NewReader(deregBody)) + w = httptest.NewRecorder() + h.deregisterHandler(w, req) + require.Equal(t, http.StatusOK, w.Code) + + select { + case <-removed: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for OnRemoval callback") + } + + require.Equal(t, int64(0), atomic.LoadInt64(&stats.Sessions), "sessions should be 0 after deregister") + require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should remain 1 after deregister") +} diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index b51d13a1..64e20f8d 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -9,17 +9,18 @@ import ( ) type Metrics struct { - Dns uint64 `json:"dns"` - Ftp uint64 `json:"ftp"` - Http uint64 `json:"http"` - Ldap uint64 `json:"ldap"` - Smb uint64 `json:"smb"` - Smtp uint64 `json:"smtp"` - Sessions int64 `json:"sessions"` - Cache *storage.CacheMetrics `json:"cache"` - Memory *MemoryMetrics `json:"memory"` - Cpu *CpuStats `json:"cpu"` - Network *NetworkStats `json:"network"` + Dns uint64 `json:"dns"` + Ftp uint64 `json:"ftp"` + Http uint64 `json:"http"` + Ldap uint64 `json:"ldap"` + Smb uint64 `json:"smb"` + Smtp uint64 `json:"smtp"` + Sessions int64 `json:"sessions"` + SessionsTotal int64 `json:"sessions_total"` + Cache *storage.CacheMetrics `json:"cache"` + Memory *MemoryMetrics `json:"memory"` + Cpu *CpuStats `json:"cpu"` + Network *NetworkStats `json:"network"` } func GetCacheMetrics(options *Options) *storage.CacheMetrics { diff --git a/pkg/storage/option.go b/pkg/storage/option.go index 5d32a5ff..1d91fb49 100644 --- a/pkg/storage/option.go +++ b/pkg/storage/option.go @@ -15,6 +15,9 @@ type Options struct { MaxSize int MaxSharedInteractions int EvictionStrategy EvictionStrategy + // OnRemoval is called for each client session removed from cache + // (deregistration, TTL expiry, size eviction, or cache close). + OnRemoval func() } func (options *Options) UseDisk() bool { diff --git a/pkg/storage/roundtrip_test.go b/pkg/storage/roundtrip_test.go index 7324c4cb..70554847 100644 --- a/pkg/storage/roundtrip_test.go +++ b/pkg/storage/roundtrip_test.go @@ -403,7 +403,7 @@ func TestStaleDataCleanupOnReRegistration(t *testing.T) { _ = priv1 } -// TestCacheEvictionCleansLevelDB verifies the OnCacheRemovalCallback properly +// TestCacheEvictionCleansLevelDB verifies the onCacheRemoval callback properly // deletes LevelDB entries when cache entries are evicted. func TestCacheEvictionCleansLevelDB(t *testing.T) { tmpDir, err := os.MkdirTemp("", "interactsh-eviction-*") @@ -451,7 +451,74 @@ func TestCacheEvictionCleansLevelDB(t *testing.T) { // Small delay for async eviction callback time.Sleep(50 * time.Millisecond) - // LevelDB entry should be cleaned up by OnCacheRemovalCallback + // LevelDB entry should be cleaned up by onCacheRemoval _, err = db.db.Get([]byte(correlationID), nil) require.Error(t, err, "LevelDB entry should be deleted after cache eviction") } + +// TestOnRemovalSessionTracking verifies that the OnRemoval callback fires +// exactly once per client session on deregister and TTL eviction, and does +// not fire for non-session entries created via SetID. +func TestOnRemovalSessionTracking(t *testing.T) { + removed := make(chan struct{}, 10) + onRemoval := func() { removed <- struct{}{} } + + db, err := New(&Options{ + EvictionTTL: 50 * time.Millisecond, + EvictionStrategy: EvictionStrategyFixed, + OnRemoval: onRemoval, + }) + require.NoError(t, err) + defer db.Close() + + waitRemoval := func(msg string) { + t.Helper() + select { + case <-removed: + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for OnRemoval: %s", msg) + } + } + + // --- Non-session entries (SetID) must not trigger OnRemoval --- + // Invalidate a SetID entry, then register+deregister a real session as a + // FIFO barrier: the cache event channel is ordered, so receiving the + // session's callback proves the SetID invalidation was already processed. + _ = db.SetID("token-entry") + db.cache.Invalidate("token-entry") + + secret := uuid.New().String() + cid := xid.New().String() + _, pubKey := generateRSAKeyPair(t) + require.NoError(t, db.SetIDPublicKey(cid, secret, pubKey)) + require.NoError(t, db.RemoveID(cid, secret)) + waitRemoval("deregister barrier") + select { + case <-removed: + t.Fatal("SetID entry should not trigger OnRemoval") + default: + } + + // --- TTL eviction must trigger OnRemoval --- + secret2 := uuid.New().String() + cid2 := xid.New().String() + _, pubKey2 := generateRSAKeyPair(t) + require.NoError(t, db.SetIDPublicKey(cid2, secret2, pubKey2)) + + // Periodically access the cache to trigger lazy eviction. + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + db.cache.GetIfPresent(cid2) + } + } + }() + waitRemoval("TTL eviction") + close(stop) +} diff --git a/pkg/storage/storagedb.go b/pkg/storage/storagedb.go index 0aa9c8fd..0bc65185 100644 --- a/pkg/storage/storagedb.go +++ b/pkg/storage/storagedb.go @@ -50,9 +50,7 @@ func New(options *Options) (*StorageDB, error) { cacheOptions = append(cacheOptions, cache.WithExpireAfterAccess(options.EvictionTTL)) } } - if options.UseDisk() { - cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.OnCacheRemovalCallback)) - } + cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.onCacheRemoval)) cacheDb := cache.New(cacheOptions...) storageDB.cache = cacheDb @@ -77,10 +75,21 @@ func New(options *Options) (*StorageDB, error) { return storageDB, nil } -func (s *StorageDB) OnCacheRemovalCallback(key cache.Key, value cache.Value) { - if k, ok := key.(string); ok { +func (s *StorageDB) onCacheRemoval(key cache.Key, value cache.Value) { + k, ok := key.(string) + if !ok { + return + } + if s.Options.UseDisk() && s.db != nil { _ = s.db.Delete([]byte(k), &opt.WriteOptions{}) } + // Only fire for client sessions (entries with a SecretKey), + // not for token/domain entries created via SetID. + if s.Options.OnRemoval != nil { + if cd, ok := value.(*CorrelationData); ok && cd.SecretKey != "" { + s.Options.OnRemoval() + } + } } func (s *StorageDB) GetCacheMetrics() (*CacheMetrics, error) {