Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions internal/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,24 @@ func New(cfg *config.Config, cfgPath string, logger *zap.Logger) (*Runtime, erro
phaseMachine: newPhaseMachine(PhaseInitializing),
}

// Register flow expiry callback now that rt is available (Spec 027)
if rt.flowService != nil {
rt.flowService.SetExpiryCallback(func(summary *flow.FlowSummary) {
rt.EmitActivityFlowSummary(
summary.SessionID,
summary.CoverageMode,
summary.DurationMinutes,
summary.TotalOrigins,
summary.TotalFlows,
summary.FlowTypeDistribution,
summary.RiskLevelDistribution,
summary.LinkedMCPSessions,
summary.ToolsUsed,
summary.HasSensitiveFlows,
)
})
}

return rt, nil
}

Expand Down Expand Up @@ -551,6 +569,14 @@ func (r *Runtime) Close() error {
}
}

// Stop flow service to halt session expiry and correlation goroutines
if r.flowService != nil {
r.flowService.Stop()
if r.logger != nil {
r.logger.Info("Flow service stopped")
}
}

// Phase 6: Stop Supervisor first to stop reconciliation
if r.supervisor != nil {
r.supervisor.Stop()
Expand Down
86 changes: 86 additions & 0 deletions internal/security/flow/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,89 @@ func TestCorrelation_StaleCorrelation_Ignored(t *testing.T) {
hookID := svc.correlator.MatchAndConsume(argsHash)
assert.Empty(t, hookID, "stale correlation should not match")
}

// === Additional Service Tests (Stop and Expiry) ===

// TestFlowService_Stop_Idempotent tests that calling Stop() multiple times is safe.
// TestFlowService_Stop_Idempotent verifies that Stop() may be invoked any
// number of times without panicking (e.g. no double-close of stop channels).
func TestFlowService_Stop_Idempotent(t *testing.T) {
	cfg := &TrackerConfig{
		SessionTimeoutMin:    30,
		MaxOriginsPerSession: 10000,
		HashMinLength:        20,
		MaxResponseHashBytes: 65536,
	}
	svc := NewFlowService(
		NewClassifier(nil),
		NewFlowTracker(cfg),
		NewPolicyEvaluator(&PolicyConfig{
			InternalToExternal:    PolicyAsk,
			SensitiveDataExternal: PolicyDeny,
		}),
		nil,
		NewCorrelator(5*time.Second),
	)

	// First and second Stop() must both return without panicking.
	svc.Stop()
	svc.Stop()

	// Channel closure cannot be observed directly; instead assert that yet
	// another Stop() remains safe.
	assert.NotPanics(t, func() {
		svc.Stop()
	}, "multiple Stop() calls should be safe")
}

// TestFlowService_ExpiryCallback_EmitsSummary tests that session expiry triggers the callback
// with a correct FlowSummary.
func TestFlowService_ExpiryCallback_EmitsSummary(t *testing.T) {
classifier := NewClassifier(nil)
trackerCfg := &TrackerConfig{
SessionTimeoutMin: 0, // Will use manual expiry
MaxOriginsPerSession: 10000,
HashMinLength: 20,
MaxResponseHashBytes: 65536,
}
tracker := NewFlowTracker(trackerCfg)
policyCfg := &PolicyConfig{
InternalToExternal: PolicyAsk,
SensitiveDataExternal: PolicyDeny,
}
policy := NewPolicyEvaluator(policyCfg)
correlator := NewCorrelator(5 * time.Second)

svc := NewFlowService(classifier, tracker, policy, &mockDetector{}, correlator)
defer svc.Stop()

// Channel to receive the callback
summaryReceived := make(chan *FlowSummary, 1)

// Register expiry callback
svc.SetExpiryCallback(func(summary *FlowSummary) {
summaryReceived <- summary
})

// Create a session and record an origin
sessionID := "expiry-test-session"
svc.RecordOriginProxy(sessionID, "postgres", "query", "test data for expiry callback verification")

// Force session to look expired by setting LastActivity in the past
session := tracker.GetSession(sessionID)
require.NotNil(t, session)
session.mu.Lock()
session.LastActivity = time.Now().Add(-2 * time.Hour)
session.mu.Unlock()

// Manually trigger expiry
tracker.expireSessions()

// Wait for callback with timeout
select {
case summary := <-summaryReceived:
assert.Equal(t, sessionID, summary.SessionID, "summary should have correct session ID")
assert.Greater(t, summary.TotalOrigins, 0, "summary should have origins recorded")
assert.NotEmpty(t, summary.CoverageMode, "summary should have coverage mode set")
case <-time.After(1 * time.Second):
t.Fatal("expiry callback was not called within timeout")
}
}
4 changes: 2 additions & 2 deletions internal/security/flow/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func (ft *FlowTracker) CheckFlow(sessionID string, toolName, serverName string,
return nil, nil
}

session.mu.RLock()
defer session.mu.RUnlock()
session.mu.Lock()
defer session.mu.Unlock()

var edges []*FlowEdge
matched := make(map[string]bool) // Avoid duplicate edges for same content hash
Expand Down
118 changes: 118 additions & 0 deletions internal/security/flow/tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,121 @@ func TestFlowTracker_ToolsUsedTracking(t *testing.T) {
require.NotNil(t, session)
assert.True(t, session.ToolsUsed["Read"], "Read should be tracked in ToolsUsed")
}

// TestFlowTracker_CheckFlow_ConcurrentRace is a regression test for the race condition
// fix where CheckFlow used RLock but modified session fields (LastActivity, ToolsUsed, Flows).
// The fix changed line 80 from RLock to Lock to properly synchronize write access.
// TestFlowTracker_CheckFlow_ConcurrentRace is a regression test for the race condition
// fix where CheckFlow used RLock but modified session fields (LastActivity, ToolsUsed, Flows).
// The fix changed CheckFlow to take the session write lock (Lock instead of RLock) to
// properly synchronize write access. Run with -race to exercise the regression.
func TestFlowTracker_CheckFlow_ConcurrentRace(t *testing.T) {
	tracker := NewFlowTracker(newTestTrackerConfig())
	defer tracker.Stop()

	sessionID := "concurrent-race-session"

	// Record multiple origins with different content to create potential flow edges.
	origins := []string{
		"database record one with sufficient length for hashing",
		"database record two with sufficient length for hashing",
		"database record three with sufficient length for hashing",
		"database record four with sufficient length for hashing",
		"database record five with sufficient length for hashing",
	}

	for i, data := range origins {
		origin := &DataOrigin{
			ContentHash:    HashContent(data),
			ToolName:       fmt.Sprintf("Read-%d", i),
			Classification: ClassInternal,
			Timestamp:      time.Now(),
		}
		tracker.RecordOrigin(sessionID, origin)
	}

	// Test cases exercised concurrently; shouldHit marks whether the payload
	// matches a previously recorded origin and must therefore yield a flow edge.
	testCases := []struct {
		toolName  string
		data      string
		shouldHit bool
	}{
		{"WebFetch-1", origins[0], true},
		{"WebFetch-2", origins[1], true},
		{"WebFetch-3", origins[2], true},
		{"WebFetch-4", "unrelated data that does not match", false},
		{"WebFetch-5", origins[3], true},
	}

	var wg sync.WaitGroup
	errChan := make(chan error, len(testCases)*3+5)

	// Spawn concurrent CheckFlow goroutines (3 iterations per test case).
	for _, tc := range testCases {
		for iteration := 0; iteration < 3; iteration++ {
			wg.Add(1)
			go func(toolName, data string, shouldHit bool) {
				defer wg.Done()
				argsJSON := fmt.Sprintf(`{"payload": %q}`, data)
				edges, err := tracker.CheckFlow(sessionID, toolName, "", ClassExternal, argsJSON)
				if err != nil {
					errChan <- fmt.Errorf("CheckFlow error for %s: %w", toolName, err)
					return
				}
				if shouldHit && len(edges) == 0 {
					errChan <- fmt.Errorf("expected flow edge for %s but got none", toolName)
					return
				}
				if !shouldHit && len(edges) > 0 {
					errChan <- fmt.Errorf("unexpected flow edge for %s", toolName)
					return
				}
			}(tc.toolName, tc.data, tc.shouldHit)
		}
	}

	// Also spawn concurrent RecordOrigin goroutines to increase contention.
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			newData := fmt.Sprintf("concurrent origin data number %d with sufficient length", idx)
			origin := &DataOrigin{
				ContentHash:    HashContent(newData),
				ToolName:       fmt.Sprintf("ConcurrentTool-%d", idx),
				Classification: ClassInternal,
				Timestamp:      time.Now(),
			}
			tracker.RecordOrigin(sessionID, origin)
		}(i)
	}

	// Wait for all goroutines to complete.
	wg.Wait()
	close(errChan)

	// Collect goroutine failures. Named errs (not "errors") to avoid shadowing
	// the standard-library errors package.
	var errs []error
	for err := range errChan {
		errs = append(errs, err)
	}
	require.Empty(t, errs, "concurrent operations should not produce errors: %v", errs)

	// Verify session state is consistent.
	session := tracker.GetSession(sessionID)
	require.NotNil(t, session, "session should exist after concurrent operations")

	// Read session fields under the session lock: CheckFlow mutates them under
	// session.mu, and the tracker's background goroutines may still be running
	// until Stop — unlocked reads here would themselves be a data race.
	session.mu.RLock()
	defer session.mu.RUnlock()

	// Verify origins are recorded (initial 5 + concurrent 5).
	assert.GreaterOrEqual(t, len(session.Origins), 5, "should have at least initial origins")

	// Verify flows were detected.
	assert.Greater(t, len(session.Flows), 0, "should have detected some flows")

	// Verify ToolsUsed is populated (this field is modified by CheckFlow).
	assert.NotEmpty(t, session.ToolsUsed, "ToolsUsed should be populated")
	for _, tc := range testCases {
		if tc.shouldHit {
			assert.True(t, session.ToolsUsed[tc.toolName], "%s should be in ToolsUsed", tc.toolName)
		}
	}

	// Verify LastActivity was updated (this field is modified by CheckFlow).
	assert.False(t, session.LastActivity.IsZero(), "LastActivity should be set")
	assert.True(t, time.Since(session.LastActivity) < 2*time.Second, "LastActivity should be recent")
}