diff --git a/cmd/interactsh-server/main.go b/cmd/interactsh-server/main.go index e7344f55..eff168df 100644 --- a/cmd/interactsh-server/main.go +++ b/cmd/interactsh-server/main.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" _ "net/http/pprof" @@ -261,6 +262,11 @@ func main() { } } + serverOptions.Stats = &server.Metrics{} + storeOptions.OnRemoval = func() { + atomic.AddInt64(&serverOptions.Stats.Sessions, -1) + } + var err error store, err = storage.New(&storeOptions) if err != nil { @@ -273,8 +279,6 @@ func main() { _ = serverOptions.Storage.SetID(serverOptions.Token) } - serverOptions.Stats = &server.Metrics{} - // If root-tld is enabled create a singleton unencrypted record in the store if serverOptions.RootTLD { for _, domain := range serverOptions.Domains { diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index 28757b62..1acc5b23 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -388,13 +388,13 @@ func (h *HTTPServer) registerHandler(w http.ResponseWriter, req *http.Request) { return } - atomic.AddInt64(&h.options.Stats.Sessions, 1) - if err := h.options.Storage.SetIDPublicKey(r.CorrelationID, r.SecretKey, r.PublicKey); err != nil { gologger.Warning().Msgf("Could not set id and public key for %s: %s\n", r.CorrelationID, err) jsonError(w, fmt.Sprintf("could not set id and public key: %s", err), http.StatusBadRequest) return } + atomic.AddInt64(&h.options.Stats.Sessions, 1) + atomic.AddInt64(&h.options.Stats.SessionsTotal, 1) jsonMsg(w, "registration successful", http.StatusOK) gologger.Debug().Msgf("Registered correlationID %s for key\n", r.CorrelationID) } @@ -409,8 +409,6 @@ type DeregisterRequest struct { // deregisterHandler is a handler for client deregister requests func (h *HTTPServer) deregisterHandler(w http.ResponseWriter, req *http.Request) { - atomic.AddInt64(&h.options.Stats.Sessions, -1) - r := &DeregisterRequest{} if err := jsoniter.NewDecoder(req.Body).Decode(r); err != nil { gologger.Warning().Msgf("Could not decode json body: %s\n", err) diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go index d01749f1..f8311f98 100644 --- a/pkg/server/http_server_test.go +++ b/pkg/server/http_server_test.go @@ -1,13 +1,25 @@ package server import ( + "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/pem" "io" "net/http" "net/http/httptest" + "sync" + "sync/atomic" "testing" "time" + "github.com/google/uuid" + jsoniter "github.com/json-iterator/go" + "github.com/projectdiscovery/interactsh/pkg/storage" + "github.com/rs/xid" "github.com/stretchr/testify/require" ) @@ -70,3 +82,72 @@ func TestWriteResponseFromDynamicRequest(t *testing.T) { require.Equal(t, resp.Header.Get("Test"), "Another", "could not get correct result") }) } + +func TestSessionTotalMetric(t *testing.T) { + stats := &Metrics{} + removed := make(chan struct{}) + closeOnce := sync.Once{} + + store, err := storage.New(&storage.Options{ + EvictionTTL: 5 * time.Minute, + OnRemoval: func() { + atomic.AddInt64(&stats.Sessions, -1) + closeOnce.Do(func() { close(removed) }) + }, + }) + require.NoError(t, err) + defer store.Close() + + h := &HTTPServer{ + options: &Options{ + Storage: store, + Stats: stats, + }, + } + + // Generate a client key pair and registration payload. + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) + require.NoError(t, err) + pubPem := pem.EncodeToMemory(&pem.Block{Type: "RSA PUBLIC KEY", Bytes: pubBytes}) + pubB64 := base64.StdEncoding.EncodeToString(pubPem) + + correlationID := xid.New().String() + secretKey := uuid.New().String() + + // --- Register --- + regBody, err := jsoniter.Marshal(&RegisterRequest{ + PublicKey: pubB64, + SecretKey: secretKey, + CorrelationID: correlationID, + }) + require.NoError(t, err) + req := httptest.NewRequest("POST", "/register", bytes.NewReader(regBody)) + w := httptest.NewRecorder() + h.registerHandler(w, req) + require.Equal(t, http.StatusOK, w.Code) + + require.Equal(t, int64(1), atomic.LoadInt64(&stats.Sessions), "sessions should be 1 after register") + require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should be 1 after register") + + // --- Deregister --- + deregBody, err := jsoniter.Marshal(&DeregisterRequest{ + SecretKey: secretKey, + CorrelationID: correlationID, + }) + require.NoError(t, err) + req = httptest.NewRequest("POST", "/deregister", bytes.NewReader(deregBody)) + w = httptest.NewRecorder() + h.deregisterHandler(w, req) + require.Equal(t, http.StatusOK, w.Code) + + select { + case <-removed: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for OnRemoval callback") + } + + require.Equal(t, int64(0), atomic.LoadInt64(&stats.Sessions), "sessions should be 0 after deregister") + require.Equal(t, int64(1), atomic.LoadInt64(&stats.SessionsTotal), "sessions_total should remain 1 after deregister") +} diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index b51d13a1..64e20f8d 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -9,17 +9,18 @@ import ( ) type Metrics struct { - Dns uint64 `json:"dns"` - Ftp uint64 `json:"ftp"` - Http uint64 `json:"http"` - Ldap uint64 `json:"ldap"` - Smb uint64 `json:"smb"` - Smtp uint64 `json:"smtp"` - Sessions int64 `json:"sessions"` - Cache *storage.CacheMetrics `json:"cache"` - Memory *MemoryMetrics `json:"memory"` - Cpu *CpuStats `json:"cpu"` - Network *NetworkStats `json:"network"` + Dns uint64 `json:"dns"` + Ftp uint64 `json:"ftp"` + Http uint64 `json:"http"` + Ldap uint64 `json:"ldap"` + Smb uint64 `json:"smb"` + Smtp uint64 `json:"smtp"` + Sessions int64 `json:"sessions"` + SessionsTotal int64 `json:"sessions_total"` + Cache *storage.CacheMetrics `json:"cache"` + Memory *MemoryMetrics `json:"memory"` + Cpu *CpuStats `json:"cpu"` + Network *NetworkStats `json:"network"` } func GetCacheMetrics(options *Options) *storage.CacheMetrics { diff --git a/pkg/storage/option.go b/pkg/storage/option.go index 5d32a5ff..1d91fb49 100644 --- a/pkg/storage/option.go +++ b/pkg/storage/option.go @@ -15,6 +15,9 @@ type Options struct { MaxSize int MaxSharedInteractions int EvictionStrategy EvictionStrategy + // OnRemoval is called for each client session removed from cache + // (deregistration, TTL expiry, size eviction, or cache close). + OnRemoval func() } func (options *Options) UseDisk() bool { diff --git a/pkg/storage/roundtrip_test.go b/pkg/storage/roundtrip_test.go index 7324c4cb..70554847 100644 --- a/pkg/storage/roundtrip_test.go +++ b/pkg/storage/roundtrip_test.go @@ -403,7 +403,7 @@ func TestStaleDataCleanupOnReRegistration(t *testing.T) { _ = priv1 } -// TestCacheEvictionCleansLevelDB verifies the OnCacheRemovalCallback properly +// TestCacheEvictionCleansLevelDB verifies the onCacheRemoval callback properly // deletes LevelDB entries when cache entries are evicted. func TestCacheEvictionCleansLevelDB(t *testing.T) { tmpDir, err := os.MkdirTemp("", "interactsh-eviction-*") @@ -451,7 +451,74 @@ func TestCacheEvictionCleansLevelDB(t *testing.T) { // Small delay for async eviction callback time.Sleep(50 * time.Millisecond) - // LevelDB entry should be cleaned up by OnCacheRemovalCallback + // LevelDB entry should be cleaned up by onCacheRemoval _, err = db.db.Get([]byte(correlationID), nil) require.Error(t, err, "LevelDB entry should be deleted after cache eviction") } + +// TestOnRemovalSessionTracking verifies that the OnRemoval callback fires +// exactly once per client session on deregister and TTL eviction, and does +// not fire for non-session entries created via SetID. +func TestOnRemovalSessionTracking(t *testing.T) { + removed := make(chan struct{}, 10) + onRemoval := func() { removed <- struct{}{} } + + db, err := New(&Options{ + EvictionTTL: 50 * time.Millisecond, + EvictionStrategy: EvictionStrategyFixed, + OnRemoval: onRemoval, + }) + require.NoError(t, err) + defer db.Close() + + waitRemoval := func(msg string) { + t.Helper() + select { + case <-removed: + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for OnRemoval: %s", msg) + } + } + + // --- Non-session entries (SetID) must not trigger OnRemoval --- + // Invalidate a SetID entry, then register+deregister a real session as a + // FIFO barrier: the cache event channel is ordered, so receiving the + // session's callback proves the SetID invalidation was already processed. + _ = db.SetID("token-entry") + db.cache.Invalidate("token-entry") + + secret := uuid.New().String() + cid := xid.New().String() + _, pubKey := generateRSAKeyPair(t) + require.NoError(t, db.SetIDPublicKey(cid, secret, pubKey)) + require.NoError(t, db.RemoveID(cid, secret)) + waitRemoval("deregister barrier") + select { + case <-removed: + t.Fatal("SetID entry should not trigger OnRemoval") + default: + } + + // --- TTL eviction must trigger OnRemoval --- + secret2 := uuid.New().String() + cid2 := xid.New().String() + _, pubKey2 := generateRSAKeyPair(t) + require.NoError(t, db.SetIDPublicKey(cid2, secret2, pubKey2)) + + // Periodically access the cache to trigger lazy eviction. + stop := make(chan struct{}) + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-stop: + return + case <-ticker.C: + db.cache.GetIfPresent(cid2) + } + } + }() + waitRemoval("TTL eviction") + close(stop) +} diff --git a/pkg/storage/storagedb.go b/pkg/storage/storagedb.go index 0aa9c8fd..0bc65185 100644 --- a/pkg/storage/storagedb.go +++ b/pkg/storage/storagedb.go @@ -50,9 +50,7 @@ func New(options *Options) (*StorageDB, error) { cacheOptions = append(cacheOptions, cache.WithExpireAfterAccess(options.EvictionTTL)) } } - if options.UseDisk() { - cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.OnCacheRemovalCallback)) - } + cacheOptions = append(cacheOptions, cache.WithRemovalListener(storageDB.onCacheRemoval)) cacheDb := cache.New(cacheOptions...) storageDB.cache = cacheDb @@ -77,10 +75,21 @@ func New(options *Options) (*StorageDB, error) { return storageDB, nil } -func (s *StorageDB) OnCacheRemovalCallback(key cache.Key, value cache.Value) { - if k, ok := key.(string); ok { +func (s *StorageDB) onCacheRemoval(key cache.Key, value cache.Value) { + k, ok := key.(string) + if !ok { + return + } + if s.Options.UseDisk() && s.db != nil { _ = s.db.Delete([]byte(k), &opt.WriteOptions{}) } + // Only fire for client sessions (entries with a SecretKey), + // not for token/domain entries created via SetID. + if s.Options.OnRemoval != nil { + if cd, ok := value.(*CorrelationData); ok && cd.SecretKey != "" { + s.Options.OnRemoval() + } + } } func (s *StorageDB) GetCacheMetrics() (*CacheMetrics, error) {