diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 5b7474a6..62c85017 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -407,7 +407,7 @@ public async Task CreateSessionAsync(SessionConfig config, Cance // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. - var session = new CopilotSession(sessionId, connection.Rpc); + var session = new CopilotSession(sessionId, connection.Rpc, _logger); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); if (config.OnUserInputRequest != null) @@ -511,7 +511,7 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes // Create and register the session before issuing the RPC so that // events emitted by the CLI (e.g. session.start) are not dropped. - var session = new CopilotSession(sessionId, connection.Rpc); + var session = new CopilotSession(sessionId, connection.Rpc, _logger); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); if (config.OnUserInputRequest != null) diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 324b3df6..c2b95808 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -2,12 +2,15 @@ * Copyright (c) Microsoft Corporation. All rights reserved. 
*--------------------------------------------------------------------------------------------*/ +using GitHub.Copilot.SDK.Rpc; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; using StreamJsonRpc; +using System.Collections.Immutable; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; -using GitHub.Copilot.SDK.Rpc; +using System.Threading.Channels; namespace GitHub.Copilot.SDK; @@ -52,22 +55,27 @@ namespace GitHub.Copilot.SDK; /// public sealed partial class CopilotSession : IAsyncDisposable { - /// - /// Multicast delegate used as a thread-safe, insertion-ordered handler list. - /// The compiler-generated add/remove accessors use a lock-free CAS loop over the backing field. - /// Dispatch reads the field once (inherent snapshot, no allocation). - /// Expected handler count is small (typically 1–3), so Delegate.Combine/Remove cost is negligible. - /// - private event SessionEventHandler? EventHandlers; private readonly Dictionary _toolHandlers = []; private readonly JsonRpc _rpc; + private readonly ILogger _logger; + private volatile PermissionRequestHandler? _permissionHandler; private volatile UserInputHandler? _userInputHandler; + private ImmutableArray _eventHandlers = ImmutableArray.Empty; + private SessionHooks? _hooks; private readonly SemaphoreSlim _hooksLock = new(1, 1); private SessionRpc? _sessionRpc; private int _isDisposed; + /// + /// Channel that serializes event dispatch. enqueues; + /// a single background consumer () dequeues and + /// invokes handlers one at a time, preserving arrival order. + /// + private readonly Channel _eventChannel = Channel.CreateUnbounded( + new() { SingleReader = true }); + /// /// Gets the unique identifier for this session. /// @@ -93,15 +101,20 @@ public sealed partial class CopilotSession : IAsyncDisposable /// /// The unique identifier for this session. /// The JSON-RPC connection to the Copilot CLI. + /// Logger for diagnostics. 
/// The workspace path if infinite sessions are enabled. /// /// This constructor is internal. Use to create sessions. /// - internal CopilotSession(string sessionId, JsonRpc rpc, string? workspacePath = null) + internal CopilotSession(string sessionId, JsonRpc rpc, ILogger logger, string? workspacePath = null) { SessionId = sessionId; _rpc = rpc; + _logger = logger; WorkspacePath = workspacePath; + + // Start the asynchronous processing loop. + _ = ProcessEventsAsync(); } private Task InvokeRpcAsync(string method, object?[]? args, CancellationToken cancellationToken) @@ -186,7 +199,7 @@ public async Task SendAsync(MessageOptions options, CancellationToken ca CancellationToken cancellationToken = default) { var effectiveTimeout = timeout ?? TimeSpan.FromSeconds(60); - var tcs = new TaskCompletionSource(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); AssistantMessageEvent? lastAssistantMessage = null; void Handler(SessionEvent evt) @@ -236,7 +249,9 @@ void Handler(SessionEvent evt) /// Multiple handlers can be registered and will all receive events. /// /// - /// Handler exceptions are allowed to propagate so they are not lost. + /// Handlers are invoked serially in event-arrival order on a background thread. + /// A handler will never be called concurrently with itself or with other handlers + /// on the same session. /// /// /// @@ -259,27 +274,53 @@ void Handler(SessionEvent evt) /// public IDisposable On(SessionEventHandler handler) { - EventHandlers += handler; - return new ActionDisposable(() => EventHandlers -= handler); + ImmutableInterlocked.Update(ref _eventHandlers, array => array.Add(handler)); + return new ActionDisposable(() => ImmutableInterlocked.Update(ref _eventHandlers, array => array.Remove(handler))); } /// - /// Dispatches an event to all registered handlers. + /// Enqueues an event for serial dispatch to all registered handlers. /// /// The session event to dispatch. 
/// - /// This method is internal. Handler exceptions are allowed to propagate so they are not lost. - /// Broadcast request events (external_tool.requested, permission.requested) are handled - /// internally before being forwarded to user handlers. + /// This method is non-blocking. Broadcast request events (external_tool.requested, + /// permission.requested) are fired concurrently so that a stalled handler does not + /// block event delivery. The event is then placed into an in-memory channel and + /// processed by a single background consumer (), + /// which guarantees user handlers see events one at a time, in order. /// internal void DispatchEvent(SessionEvent sessionEvent) { - // Handle broadcast request events (protocol v3) before dispatching to user handlers. - // Fire-and-forget: the response is sent asynchronously via RPC. - HandleBroadcastEventAsync(sessionEvent); + // Fire broadcast work concurrently (fire-and-forget with error logging). + // This is done outside the channel so broadcast handlers don't block the + // consumer loop — important when a secondary client's handler intentionally + // never completes (multi-client permission scenario). + _ = HandleBroadcastEventAsync(sessionEvent); + + // Queue the event for serial processing by user handlers. + _eventChannel.Writer.TryWrite(sessionEvent); + } - // Reading the field once gives us a snapshot; delegates are immutable. - EventHandlers?.Invoke(sessionEvent); + /// + /// Single-reader consumer loop that processes events from the channel. + /// Ensures user event handlers are invoked serially and in FIFO order. 
+ /// + private async Task ProcessEventsAsync() + { + await foreach (var sessionEvent in _eventChannel.Reader.ReadAllAsync().ConfigureAwait(false)) /* ConfigureAwait(false): never resume on a SynchronizationContext captured at construction, so handlers stay off the caller's (e.g. UI) thread */ + { + foreach (var handler in _eventHandlers) + { + try + { + handler(sessionEvent); + } + catch (Exception ex) + { + LogEventHandlerError(ex); + } + } + } + } /// @@ -355,37 +396,44 @@ internal async Task HandlePermissionRequestAsync(JsonEl /// Implements the protocol v3 broadcast model where tool calls and permission requests /// are broadcast as session events to all clients. /// - private async void HandleBroadcastEventAsync(SessionEvent sessionEvent) + private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) { - switch (sessionEvent) + try { - case ExternalToolRequestedEvent toolEvent: - { - var data = toolEvent.Data; - if (string.IsNullOrEmpty(data.RequestId) || string.IsNullOrEmpty(data.ToolName)) - return; - - var tool = GetTool(data.ToolName); - if (tool is null) - return; // This client doesn't handle this tool; another client will. - - await ExecuteToolAndRespondAsync(data.RequestId, data.ToolName, data.ToolCallId, data.Arguments, tool); - break; - } - - case PermissionRequestedEvent permEvent: - { - var data = permEvent.Data; - if (string.IsNullOrEmpty(data.RequestId) || data.PermissionRequest is null) - return; - - var handler = _permissionHandler; - if (handler is null) - return; // This client doesn't handle permissions; another client will. - - await ExecutePermissionAndRespondAsync(data.RequestId, data.PermissionRequest, handler); - break; - } + switch (sessionEvent) + { + case ExternalToolRequestedEvent toolEvent: + { + var data = toolEvent.Data; + if (string.IsNullOrEmpty(data.RequestId) || string.IsNullOrEmpty(data.ToolName)) + return; + + var tool = GetTool(data.ToolName); + if (tool is null) + return; // This client doesn't handle this tool; another client will. 
+ + await ExecuteToolAndRespondAsync(data.RequestId, data.ToolName, data.ToolCallId, data.Arguments, tool); + break; + } + + case PermissionRequestedEvent permEvent: + { + var data = permEvent.Data; + if (string.IsNullOrEmpty(data.RequestId) || data.PermissionRequest is null) + return; + + var handler = _permissionHandler; + if (handler is null) + return; // This client doesn't handle permissions; another client will. + + await ExecutePermissionAndRespondAsync(data.RequestId, data.PermissionRequest, handler); + break; + } + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + LogBroadcastHandlerError(ex); } } @@ -703,6 +751,11 @@ public async Task LogAsync(string message, SessionLogRequestLevel? level = null, /// A task representing the dispose operation. /// /// + /// The caller should ensure the session is idle (e.g., + /// has returned) before disposing. If the session is not idle, in-flight event handlers + /// or tool handlers may observe failures. + /// + /// /// Session state on disk (conversation history, planning state, artifacts) is /// preserved, so the conversation can be resumed later by calling /// with the session ID. 
To @@ -731,6 +784,8 @@ public async ValueTask DisposeAsync() return; } + _eventChannel.Writer.TryComplete(); + try { await InvokeRpcAsync( @@ -745,12 +800,19 @@ await InvokeRpcAsync( // Connection is broken or closed } - EventHandlers = null; + // Atomically swap in an empty handler array. The return value (the previous handlers) is deliberately discarded: assigning it back to the field would re-install the handlers being cleared. + ImmutableInterlocked.InterlockedExchange(ref _eventHandlers, ImmutableArray.Empty); _toolHandlers.Clear(); _permissionHandler = null; } + [LoggerMessage(Level = LogLevel.Error, Message = "Unhandled exception in broadcast event handler")] + private partial void LogBroadcastHandlerError(Exception exception); + + [LoggerMessage(Level = LogLevel.Error, Message = "Unhandled exception in session event handler")] + private partial void LogEventHandlerError(Exception exception); + internal record SendMessageRequest { public string SessionId { get; init; } = string.Empty; diff --git a/dotnet/test/Harness/TestHelper.cs b/dotnet/test/Harness/TestHelper.cs index 6dd919bc..a04e4365 100644 --- a/dotnet/test/Harness/TestHelper.cs +++ b/dotnet/test/Harness/TestHelper.cs @@ -10,7 +10,7 @@ public static class TestHelper CopilotSession session, TimeSpan? timeout = null) { - var tcs = new TaskCompletionSource(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(60)); AssistantMessageEvent? finalAssistantMessage = null; @@ -78,7 +78,7 @@ public static async Task GetNextEventOfTypeAsync( CopilotSession session, TimeSpan? timeout = null) where T : SessionEvent { - var tcs = new TaskCompletionSource(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var cts = new CancellationTokenSource(timeout ?? 
TimeSpan.FromSeconds(60)); using var subscription = session.On(evt => diff --git a/dotnet/test/MultiClientTests.cs b/dotnet/test/MultiClientTests.cs index 131fd31d..69ee39c9 100644 --- a/dotnet/test/MultiClientTests.cs +++ b/dotnet/test/MultiClientTests.cs @@ -109,10 +109,10 @@ public async Task Both_Clients_See_Tool_Request_And_Completion_Events() }); // Set up event waiters BEFORE sending the prompt to avoid race conditions - var client1Requested = new TaskCompletionSource(); - var client2Requested = new TaskCompletionSource(); - var client1Completed = new TaskCompletionSource(); - var client2Completed = new TaskCompletionSource(); + var client1Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Requested = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client1Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var client2Completed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var sub1 = session1.On(evt => { diff --git a/dotnet/test/SessionTests.cs b/dotnet/test/SessionTests.cs index 80043958..256921c3 100644 --- a/dotnet/test/SessionTests.cs +++ b/dotnet/test/SessionTests.cs @@ -249,18 +249,40 @@ public async Task Should_Receive_Session_Events() // session.start is emitted during the session.create RPC; if the session // weren't registered in the sessions map before the RPC, it would be dropped. var earlyEvents = new List(); + var sessionStartReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var session = await CreateSessionAsync(new SessionConfig { - OnEvent = evt => earlyEvents.Add(evt), + OnEvent = evt => + { + earlyEvents.Add(evt); + if (evt is SessionStartEvent) + sessionStartReceived.TrySetResult(true); + }, }); + // session.start is dispatched asynchronously via the event channel; + // wait briefly for the consumer to deliver it. 
+ var started = await Task.WhenAny(sessionStartReceived.Task, Task.Delay(TimeSpan.FromSeconds(5))); + Assert.Equal(sessionStartReceived.Task, started); Assert.Contains(earlyEvents, evt => evt is SessionStartEvent); var receivedEvents = new List(); - var idleReceived = new TaskCompletionSource(); + var idleReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var concurrentCount = 0; + var maxConcurrent = 0; session.On(evt => { + // Track concurrent handler invocations to verify serial dispatch. + var current = Interlocked.Increment(ref concurrentCount); + var seenMax = Volatile.Read(ref maxConcurrent); + if (current > seenMax) + Interlocked.CompareExchange(ref maxConcurrent, current, seenMax); + + Thread.Sleep(10); + + Interlocked.Decrement(ref concurrentCount); + receivedEvents.Add(evt); if (evt is SessionIdleEvent) { @@ -281,6 +303,9 @@ public async Task Should_Receive_Session_Events() Assert.Contains(receivedEvents, evt => evt is AssistantMessageEvent); Assert.Contains(receivedEvents, evt => evt is SessionIdleEvent); + // Events must be dispatched serially — never more than one handler invocation at a time. + Assert.Equal(1, maxConcurrent); + // Verify the assistant response contains the expected answer var assistantMessage = await TestHelper.GetFinalAssistantMessageAsync(session); Assert.NotNull(assistantMessage); @@ -452,6 +477,54 @@ await WaitForAsync(() => Assert.Equal("notification", ephemeralEvent.Data.InfoType); } + [Fact] + public async Task Handler_Exception_Does_Not_Halt_Event_Delivery() + { + var session = await CreateSessionAsync(); + var eventCount = 0; + var gotIdle = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + session.On(evt => + { + eventCount++; + + // Throw on the first event to verify the loop keeps going. 
+ if (eventCount == 1) + throw new InvalidOperationException("boom"); + + if (evt is SessionIdleEvent) + gotIdle.TrySetResult(); + }); + + await session.SendAsync(new MessageOptions { Prompt = "What is 1+1?" }); + + await gotIdle.Task.WaitAsync(TimeSpan.FromSeconds(30)); + + // Handler saw more than just the first (throwing) event. + Assert.True(eventCount > 1); + } + + [Fact] + public async Task DisposeAsync_From_Handler_Does_Not_Deadlock() + { + var session = await CreateSessionAsync(); + var disposed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + session.On(evt => + { + if (evt is UserMessageEvent) + { + // Call DisposeAsync from within a handler — must not deadlock. + session.DisposeAsync().AsTask().ContinueWith(_ => disposed.TrySetResult()); + } + }); + + await session.SendAsync(new MessageOptions { Prompt = "What is 1+1?" }); + + // If this times out, we deadlocked. + await disposed.Task.WaitAsync(TimeSpan.FromSeconds(10)); + } + private static async Task WaitForAsync(Func condition, TimeSpan timeout) { var deadline = DateTime.UtcNow + timeout; diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 40f62d4c..7c24d6f5 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -589,27 +589,27 @@ func TestSession(t *testing.T) { ctx.ConfigureForTest(t) // Use OnEvent to capture events dispatched during session creation. - // session.start is emitted during the session.create RPC; if the session - // weren't registered in the sessions map before the RPC, it would be dropped. - var earlyEvents []copilot.SessionEvent + // session.start is emitted during the session.create RPC; with channel-based + // dispatch it may not have been delivered by the time CreateSession returns. 
+ sessionStartCh := make(chan bool, 1) session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ OnPermissionRequest: copilot.PermissionHandler.ApproveAll, OnEvent: func(event copilot.SessionEvent) { - earlyEvents = append(earlyEvents, event) + if event.Type == "session.start" { + select { + case sessionStartCh <- true: + default: + } + } }, }) if err != nil { t.Fatalf("Failed to create session: %v", err) } - hasSessionStart := false - for _, evt := range earlyEvents { - if evt.Type == "session.start" { - hasSessionStart = true - break - } - } - if !hasSessionStart { + select { + case <-sessionStartCh: + case <-time.After(5 * time.Second): t.Error("Expected session.start event via OnEvent during creation") } diff --git a/go/session.go b/go/session.go index 74529c52..bcfc3899 100644 --- a/go/session.go +++ b/go/session.go @@ -65,6 +65,11 @@ type Session struct { hooks *SessionHooks hooksMux sync.RWMutex + // eventCh serializes user event handler dispatch. dispatchEvent enqueues; + // a single goroutine (processEvents) dequeues and invokes handlers in FIFO order. + eventCh chan SessionEvent + closeOnce sync.Once // guards eventCh close so Disconnect is safe to call more than once + // RPC provides typed session-scoped RPC methods. RPC *rpc.SessionRpc } @@ -78,14 +83,17 @@ func (s *Session) WorkspacePath() string { // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { - return &Session{ + s := &Session{ SessionID: sessionID, workspacePath: workspacePath, client: client, handlers: make([]sessionHandler, 0), toolHandlers: make(map[string]ToolHandler), + eventCh: make(chan SessionEvent, 128), RPC: rpc.NewSessionRpc(client, sessionID), } + go s.processEvents() + return s } // Send sends a message to this session and waits for the response. 
@@ -435,36 +443,59 @@ func (s *Session) handleHooksInvoke(hookType string, rawInput json.RawMessage) ( } } -// dispatchEvent dispatches an event to all registered handlers. -// This is an internal method; handlers are called synchronously and any panics -// are recovered to prevent crashing the event dispatcher. +// dispatchEvent enqueues an event for delivery to user handlers and fires +// broadcast handlers concurrently. +// +// Broadcast work (tool calls, permission requests) is fired in a separate +// goroutine so it does not block the JSON-RPC read loop. User event handlers +// are delivered by a single consumer goroutine (processEvents), guaranteeing +// serial, FIFO dispatch without blocking the read loop. func (s *Session) dispatchEvent(event SessionEvent) { - // Handle broadcast request events internally (fire-and-forget) - s.handleBroadcastEvent(event) + go s.handleBroadcastEvent(event) + + // Send to the event channel in a closure with a recover guard. + // Disconnect closes eventCh, and in Go sending on a closed channel + // panics — there is no non-panicking send primitive. We only want + // to suppress that specific panic; other panics are not expected here. + func() { + defer func() { recover() }() + s.eventCh <- event + }() +} - s.handlerMutex.RLock() - handlers := make([]SessionEventHandler, 0, len(s.handlers)) - for _, h := range s.handlers { - handlers = append(handlers, h.fn) - } - s.handlerMutex.RUnlock() - - for _, handler := range handlers { - // Call handler - don't let panics crash the dispatcher - func() { - defer func() { - if r := recover(); r != nil { - fmt.Printf("Error in session event handler: %v\n", r) - } +// processEvents is the single consumer goroutine for the event channel. +// It invokes user handlers serially, in arrival order. Panics in individual +// handlers are recovered so that one misbehaving handler does not prevent +// others from receiving the event. 
+func (s *Session) processEvents() { + for event := range s.eventCh { + s.handlerMutex.RLock() + handlers := make([]SessionEventHandler, 0, len(s.handlers)) + for _, h := range s.handlers { + handlers = append(handlers, h.fn) + } + s.handlerMutex.RUnlock() + + for _, handler := range handlers { + func() { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Error in session event handler: %v\n", r) + } + }() + handler(event) }() - handler(event) - }() + } } } // handleBroadcastEvent handles broadcast request events by executing local handlers // and responding via RPC. This implements the protocol v3 broadcast model where tool // calls and permission requests are broadcast as session events to all clients. +// +// Handlers are executed in their own goroutine (not the JSON-RPC read loop or the +// event consumer loop) so that a stalled handler does not block event delivery or +// cause RPC deadlocks. func (s *Session) handleBroadcastEvent(event SessionEvent) { switch event.Type { case ExternalToolRequested: @@ -481,7 +512,7 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { if event.Data.ToolCallID != nil { toolCallID = *event.Data.ToolCallID } - go s.executeToolAndRespond(*requestID, *toolName, toolCallID, event.Data.Arguments, handler) + s.executeToolAndRespond(*requestID, *toolName, toolCallID, event.Data.Arguments, handler) case PermissionRequested: requestID := event.Data.RequestID @@ -492,7 +523,7 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { if handler == nil { return } - go s.executePermissionAndRespond(*requestID, *event.Data.PermissionRequest, handler) + s.executePermissionAndRespond(*requestID, *event.Data.PermissionRequest, handler) } } @@ -610,6 +641,10 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // Disconnect closes this session and releases all in-memory resources (event // handlers, tool handlers, permission handlers). 
// +// The caller should ensure the session is idle (e.g., [Session.SendAndWait] has +// returned) before disconnecting. If the session is not idle, in-flight event +// handlers or tool handlers may observe failures. +// // Session state on disk (conversation history, planning state, artifacts) is // preserved, so the conversation can be resumed later by calling // [Client.ResumeSession] with the session ID. To permanently remove all @@ -631,6 +666,8 @@ func (s *Session) Disconnect() error { return fmt.Errorf("failed to disconnect session: %w", err) } + s.closeOnce.Do(func() { close(s.eventCh) }) + // Clear handlers s.handlerMutex.Lock() s.handlers = nil diff --git a/go/session_test.go b/go/session_test.go index 40874a65..664c06e5 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -2,21 +2,36 @@ package copilot import ( "sync" + "sync/atomic" "testing" + "time" ) +// newTestSession creates a session with an event channel and starts the consumer goroutine. +// Returns a cleanup function that closes the channel (stopping the consumer). 
+func newTestSession() (*Session, func()) { + s := &Session{ + handlers: make([]sessionHandler, 0), + eventCh: make(chan SessionEvent, 128), + } + go s.processEvents() + return s, func() { close(s.eventCh) } +} + func TestSession_On(t *testing.T) { t.Run("multiple handlers all receive events", func(t *testing.T) { - session := &Session{ - handlers: make([]sessionHandler, 0), - } + session, cleanup := newTestSession() + defer cleanup() + var wg sync.WaitGroup + wg.Add(3) var received1, received2, received3 bool - session.On(func(event SessionEvent) { received1 = true }) - session.On(func(event SessionEvent) { received2 = true }) - session.On(func(event SessionEvent) { received3 = true }) + session.On(func(event SessionEvent) { received1 = true; wg.Done() }) + session.On(func(event SessionEvent) { received2 = true; wg.Done() }) + session.On(func(event SessionEvent) { received3 = true; wg.Done() }) session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() if !received1 || !received2 || !received3 { t.Errorf("Expected all handlers to receive event, got received1=%v, received2=%v, received3=%v", @@ -25,68 +40,81 @@ func TestSession_On(t *testing.T) { }) t.Run("unsubscribing one handler does not affect others", func(t *testing.T) { - session := &Session{ - handlers: make([]sessionHandler, 0), - } + session, cleanup := newTestSession() + defer cleanup() + + var count1, count2, count3 atomic.Int32 + var wg sync.WaitGroup - var count1, count2, count3 int - session.On(func(event SessionEvent) { count1++ }) - unsub2 := session.On(func(event SessionEvent) { count2++ }) - session.On(func(event SessionEvent) { count3++ }) + wg.Add(3) + session.On(func(event SessionEvent) { count1.Add(1); wg.Done() }) + unsub2 := session.On(func(event SessionEvent) { count2.Add(1); wg.Done() }) + session.On(func(event SessionEvent) { count3.Add(1); wg.Done() }) // First event - all handlers receive it session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() // Unsubscribe handler 2 
unsub2() // Second event - only handlers 1 and 3 should receive it + wg.Add(2) session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() - if count1 != 2 { - t.Errorf("Expected handler 1 to receive 2 events, got %d", count1) + if count1.Load() != 2 { + t.Errorf("Expected handler 1 to receive 2 events, got %d", count1.Load()) } - if count2 != 1 { - t.Errorf("Expected handler 2 to receive 1 event (before unsubscribe), got %d", count2) + if count2.Load() != 1 { + t.Errorf("Expected handler 2 to receive 1 event (before unsubscribe), got %d", count2.Load()) } - if count3 != 2 { - t.Errorf("Expected handler 3 to receive 2 events, got %d", count3) + if count3.Load() != 2 { + t.Errorf("Expected handler 3 to receive 2 events, got %d", count3.Load()) } }) t.Run("calling unsubscribe multiple times is safe", func(t *testing.T) { - session := &Session{ - handlers: make([]sessionHandler, 0), - } + session, cleanup := newTestSession() + defer cleanup() + + var count atomic.Int32 + var wg sync.WaitGroup - var count int - unsub := session.On(func(event SessionEvent) { count++ }) + wg.Add(1) + unsub := session.On(func(event SessionEvent) { count.Add(1); wg.Done() }) session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() - // Call unsubscribe multiple times - should not panic unsub() unsub() unsub() + // Dispatch again and wait for it to be processed via a sentinel handler + wg.Add(1) + session.On(func(event SessionEvent) { wg.Done() }) session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() - if count != 1 { - t.Errorf("Expected handler to receive 1 event, got %d", count) + if count.Load() != 1 { + t.Errorf("Expected handler to receive 1 event, got %d", count.Load()) } }) t.Run("handlers are called in registration order", func(t *testing.T) { - session := &Session{ - handlers: make([]sessionHandler, 0), - } + session, cleanup := newTestSession() + defer cleanup() var order []int - session.On(func(event SessionEvent) { order = append(order, 1) }) - 
session.On(func(event SessionEvent) { order = append(order, 2) }) - session.On(func(event SessionEvent) { order = append(order, 3) }) + var wg sync.WaitGroup + wg.Add(3) + session.On(func(event SessionEvent) { order = append(order, 1); wg.Done() }) + session.On(func(event SessionEvent) { order = append(order, 2); wg.Done() }) + session.On(func(event SessionEvent) { order = append(order, 3); wg.Done() }) session.dispatchEvent(SessionEvent{Type: "test"}) + wg.Wait() if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 { t.Errorf("Expected handlers to be called in order [1,2,3], got %v", order) @@ -94,9 +122,8 @@ func TestSession_On(t *testing.T) { }) t.Run("concurrent subscribe and unsubscribe is safe", func(t *testing.T) { - session := &Session{ - handlers: make([]sessionHandler, 0), - } + session, cleanup := newTestSession() + defer cleanup() var wg sync.WaitGroup for i := 0; i < 100; i++ { @@ -109,7 +136,6 @@ func TestSession_On(t *testing.T) { } wg.Wait() - // Should not panic and handlers should be empty session.handlerMutex.RLock() count := len(session.handlers) session.handlerMutex.RUnlock() @@ -118,4 +144,63 @@ func TestSession_On(t *testing.T) { t.Errorf("Expected 0 handlers after all unsubscribes, got %d", count) } }) + + t.Run("events are dispatched serially", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + var concurrentCount atomic.Int32 + var maxConcurrent atomic.Int32 + var done sync.WaitGroup + const totalEvents = 5 + done.Add(totalEvents) + + session.On(func(event SessionEvent) { + current := concurrentCount.Add(1) + if current > maxConcurrent.Load() { + maxConcurrent.Store(current) + } + + time.Sleep(10 * time.Millisecond) + + concurrentCount.Add(-1) + done.Done() + }) + + for i := 0; i < totalEvents; i++ { + session.dispatchEvent(SessionEvent{Type: "test"}) + } + + done.Wait() + + if max := maxConcurrent.Load(); max != 1 { + t.Errorf("Expected max concurrent count of 1, got %d", max) + } + }) + + 
t.Run("handler panic does not halt delivery", func(t *testing.T) { + session, cleanup := newTestSession() + defer cleanup() + + var eventCount atomic.Int32 + var done sync.WaitGroup + done.Add(2) + + session.On(func(event SessionEvent) { + count := eventCount.Add(1) + defer done.Done() + if count == 1 { + panic("boom") + } + }) + + session.dispatchEvent(SessionEvent{Type: "test"}) + session.dispatchEvent(SessionEvent{Type: "test"}) + + done.Wait() + + if eventCount.Load() != 2 { + t.Errorf("Expected 2 events dispatched, got %d", eventCount.Load()) + } + }) } diff --git a/test/snapshots/session/disposeasync_from_handler_does_not_deadlock.yaml b/test/snapshots/session/disposeasync_from_handler_does_not_deadlock.yaml new file mode 100644 index 00000000..7c4d4699 --- /dev/null +++ b/test/snapshots/session/disposeasync_from_handler_does_not_deadlock.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1+1 = 2 diff --git a/test/snapshots/session/handler_exception_does_not_halt_event_delivery.yaml b/test/snapshots/session/handler_exception_does_not_halt_event_delivery.yaml new file mode 100644 index 00000000..7c4d4699 --- /dev/null +++ b/test/snapshots/session/handler_exception_does_not_halt_event_delivery.yaml @@ -0,0 +1,10 @@ +models: + - claude-sonnet-4.5 +conversations: + - messages: + - role: system + content: ${system} + - role: user + content: What is 1+1? + - role: assistant + content: 1+1 = 2