diff --git a/dotnet/src/webdriver/BiDi/Broker.cs b/dotnet/src/webdriver/BiDi/Broker.cs index 2cfa324838317..6419709d315d9 100644 --- a/dotnet/src/webdriver/BiDi/Broker.cs +++ b/dotnet/src/webdriver/BiDi/Broker.cs @@ -21,6 +21,7 @@ using System.Collections.Concurrent; using System.Text.Json; using System.Text.Json.Serialization.Metadata; +using System.Threading.Channels; using OpenQA.Selenium.BiDi.Session; using OpenQA.Selenium.Internal.Logging; @@ -28,6 +29,12 @@ namespace OpenQA.Selenium.BiDi; internal sealed class Broker : IAsyncDisposable { + // Limits how many received messages can be buffered before backpressure is applied to the transport. + private const int ReceivedMessageQueueCapacity = 16; + + // How long to wait for a command response before cancelling. + private static readonly TimeSpan DefaultCommandTimeout = TimeSpan.FromSeconds(30); + private readonly ILogger _logger = Internal.Logging.Log.GetLogger(); private readonly ITransport _transport; @@ -38,7 +45,16 @@ internal sealed class Broker : IAsyncDisposable private long _currentCommandId; - private readonly Task _receivingMessageTask; + private readonly Channel _receivedMessages = Channel.CreateBounded( + new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = true, SingleWriter = true, FullMode = BoundedChannelFullMode.Wait }); + + private readonly Channel _bufferPool = Channel.CreateBounded( + new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = false, SingleWriter = false }); + + private volatile Exception? _terminalReceiveException; + + private readonly Task _receivingTask; + private readonly Task _processingTask; private readonly CancellationTokenSource _receiveMessagesCancellationTokenSource; public Broker(ITransport transport, IBiDi bidi, Func sessionProvider) @@ -48,7 +64,8 @@ public Broker(ITransport transport, IBiDi bidi, Func sessionProv _eventDispatcher = new EventDispatcher(sessionProvider); _receiveMessagesCancellationTokenSource = new CancellationTokenSource(); - _receivingMessageTask = Task.Run(() => ReceiveMessagesLoopAsync(_receiveMessagesCancellationTokenSource.Token)); + _receivingTask = Task.Run(() => ReceiveMessagesAsync(_receiveMessagesCancellationTokenSource.Token)); + _processingTask = Task.Run(ProcessMessagesAsync); } public Task SubscribeAsync(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo jsonTypeInfo, CancellationToken cancellationToken) @@ -61,6 +78,11 @@ public async Task ExecuteCommandAsync(TCommand comma where TCommand : Command where TResult : EmptyResult { + if (_terminalReceiveException is { } terminalException) + { + throw new BiDiException("The broker is no longer processing messages due to a transport error.", terminalException); + } + command.Id = Interlocked.Increment(ref _currentCommandId); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -69,42 +91,49 @@ public async Task ExecuteCommandAsync(TCommand comma ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : new CancellationTokenSource(); - var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30); + var timeout = options?.Timeout ?? DefaultCommandTimeout; cts.CancelAfter(timeout); - using var sendBuffer = new PooledBufferWriter(); + var sendBuffer = RentBuffer(); - using (var writer = new Utf8JsonWriter(sendBuffer)) + try { - JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo); - } + using (var writer = new Utf8JsonWriter(sendBuffer)) + { + JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo); + } - var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo); - _pendingCommands[command.Id] = commandInfo; + var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo); + _pendingCommands[command.Id] = commandInfo; - using var ctsRegistration = cts.Token.Register(() => - { - tcs.TrySetCanceled(cts.Token); - _pendingCommands.TryRemove(command.Id, out _); - }); + using var ctsRegistration = cts.Token.Register(() => + { + tcs.TrySetCanceled(cts.Token); + _pendingCommands.TryRemove(command.Id, out _); + }); - try - { - if (_logger.IsEnabled(LogEventLevel.Trace)) + try { + if (_logger.IsEnabled(LogEventLevel.Trace)) + { #if NET8_0_OR_GREATER - _logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}"); + _logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}"); #else - _logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}"); + _logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}"); #endif - } + } - await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false); + await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false); + } + catch + { + _pendingCommands.TryRemove(command.Id, out _); + throw; + } } - catch + finally { - _pendingCommands.TryRemove(command.Id, out _); - throw; + ReturnBuffer(sendBuffer); } return (TResult)await tcs.Task.ConfigureAwait(false); @@ -114,22 +143,32 @@ public async ValueTask DisposeAsync() { _receiveMessagesCancellationTokenSource.Cancel(); - await _eventDispatcher.DisposeAsync().ConfigureAwait(false); - try { - await _receivingMessageTask.ConfigureAwait(false); - } - catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested) - { - // Expected when cancellation is requested, ignore. - } + try + { + await _receivingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested) + { + // Expected when cancellation is requested, ignore. + } - _receiveMessagesCancellationTokenSource.Dispose(); + await _transport.DisposeAsync().ConfigureAwait(false); - await _transport.DisposeAsync().ConfigureAwait(false); + await _processingTask.ConfigureAwait(false); + + await _eventDispatcher.DisposeAsync().ConfigureAwait(false); + } + finally + { + _receiveMessagesCancellationTokenSource.Dispose(); - GC.SuppressFinalize(this); + while (_bufferPool.Reader.TryRead(out var buffer)) + { + buffer.Dispose(); + } + } } private void ProcessReceivedMessage(ReadOnlySpan data) @@ -281,30 +320,63 @@ private void ProcessReceivedMessage(ReadOnlySpan data) } } - private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken) + private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { - using var receiveBufferWriter = new PooledBufferWriter(); - try { while (!cancellationToken.IsCancellationRequested) { - receiveBufferWriter.Reset(); - - await _transport.ReceiveAsync(receiveBufferWriter, cancellationToken).ConfigureAwait(false); + var buffer = RentBuffer(); - if (_logger.IsEnabled(LogEventLevel.Trace)) + try { + await _transport.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false); + + if (_logger.IsEnabled(LogEventLevel.Trace)) + { #if NET8_0_OR_GREATER - _logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.Span)}"); + _logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.Span)}"); #else - _logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.ToArray())}"); + _logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.ToArray())}"); #endif + } + + await _receivedMessages.Writer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false); + } + catch + { + ReturnBuffer(buffer); + throw; } + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + if (_logger.IsEnabled(LogEventLevel.Error)) + { + _logger.Error($"Unhandled error occurred while receiving remote messages: {ex}"); + } + // Propagated via _terminalReceiveException; not rethrown to keep disposal orderly. + _terminalReceiveException = ex; + } + finally + { + _receivedMessages.Writer.TryComplete(); + } + } + + private async Task ProcessMessagesAsync() + { + var reader = _receivedMessages.Reader; + + while (await reader.WaitToReadAsync().ConfigureAwait(false)) + { + while (reader.TryRead(out var buffer)) + { try { - ProcessReceivedMessage(receiveBufferWriter.WrittenMemory.Span); + ProcessReceivedMessage(buffer.WrittenMemory.Span); } catch (Exception ex) { @@ -313,25 +385,43 @@ private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken) _logger.Error($"Unhandled error occurred while processing remote message: {ex}"); } } + finally + { + ReturnBuffer(buffer); + } } } - catch (Exception ex) when (ex is not OperationCanceledException) - { - if (_logger.IsEnabled(LogEventLevel.Error)) - { - _logger.Error($"Unhandled error occurred while receiving remote messages: {ex}"); - } - // Fail all pending commands, as the connection is likely broken if we failed to receive messages. - foreach (var id in _pendingCommands.Keys) + // Channel is fully drained. Fail any commands that didn't get a response: + // either with the transport error or cancellation for clean shutdown. + var terminalException = _terminalReceiveException; + foreach (var id in _pendingCommands.Keys) + { + if (_pendingCommands.TryRemove(id, out var pendingCommand)) { - if (_pendingCommands.TryRemove(id, out var pendingCommand)) + if (terminalException is not null) { - pendingCommand.TaskCompletionSource.TrySetException(ex); + pendingCommand.TaskCompletionSource.TrySetException(terminalException); + } + else + { + pendingCommand.TaskCompletionSource.TrySetCanceled(); } } + } + } + + private PooledBufferWriter RentBuffer() + { + return _bufferPool.Reader.TryRead(out var buffer) ? buffer : new PooledBufferWriter(); + } - throw; + private void ReturnBuffer(PooledBufferWriter buffer) + { + buffer.Reset(); + if (!_bufferPool.Writer.TryWrite(buffer)) + { + buffer.Dispose(); } } @@ -359,7 +449,13 @@ public void Reset() _written = 0; } - public void Advance(int count) => _written += count; + public void Advance(int count) + { + if (count < 0) throw new ArgumentOutOfRangeException(nameof(count)); + if (_written + count > (_buffer?.Length ?? 0)) throw new InvalidOperationException("Cannot advance past the end of the buffer."); + + _written += count; + } public Memory GetMemory(int sizeHint = 0) { @@ -377,8 +473,10 @@ private void EnsureCapacity(int sizeHint) { var buffer = _buffer ?? throw new ObjectDisposedException(nameof(PooledBufferWriter)); - if (sizeHint <= 0) sizeHint = buffer.Length - _written; - if (sizeHint <= 0) sizeHint = buffer.Length; + if (sizeHint <= 0) + { + sizeHint = Math.Max(1, buffer.Length - _written); + } if (_written + sizeHint > buffer.Length) { diff --git a/dotnet/test/webdriver/BiDi/BiDiFixture.cs b/dotnet/test/webdriver/BiDi/BiDiFixture.cs index 7c3d01674b740..057176d488c5c 100644 --- a/dotnet/test/webdriver/BiDi/BiDiFixture.cs +++ b/dotnet/test/webdriver/BiDi/BiDiFixture.cs @@ -57,7 +57,10 @@ public async Task BiDiTearDown() await bidi.DisposeAsync(); } - driver?.Dispose(); + if (driver is not null) + { + await driver.DisposeAsync(); + } } public class BiDiEnabledDriverOptions : DriverOptions diff --git a/dotnet/test/webdriver/BiDi/Session/SessionTests.cs b/dotnet/test/webdriver/BiDi/Session/SessionTests.cs index eda5184fb435d..57d3fac77426f 100644 --- a/dotnet/test/webdriver/BiDi/Session/SessionTests.cs +++ b/dotnet/test/webdriver/BiDi/Session/SessionTests.cs @@ -21,6 +21,13 @@ namespace OpenQA.Selenium.Tests.BiDi.Session; internal class SessionTests : BiDiTestFixture { + [Test] + public async Task ShouldHaveIdempotentDisposal() + { + await bidi.DisposeAsync(); + await bidi.DisposeAsync(); + } + [Test] public async Task CanGetStatus() {