Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 157 additions & 59 deletions dotnet/src/webdriver/BiDi/Broker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@
using System.Collections.Concurrent;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading.Channels;
using OpenQA.Selenium.BiDi.Session;
using OpenQA.Selenium.Internal.Logging;

namespace OpenQA.Selenium.BiDi;

internal sealed class Broker : IAsyncDisposable
{
// Limits how many received messages can be buffered before backpressure is applied to the transport.
private const int ReceivedMessageQueueCapacity = 16;

// How long to wait for a command response before cancelling.
private static readonly TimeSpan DefaultCommandTimeout = TimeSpan.FromSeconds(30);

private readonly ILogger _logger = Internal.Logging.Log.GetLogger<Broker>();

private readonly ITransport _transport;
Expand All @@ -38,7 +45,16 @@ internal sealed class Broker : IAsyncDisposable

private long _currentCommandId;

private readonly Task _receivingMessageTask;
private readonly Channel<PooledBufferWriter> _receivedMessages = Channel.CreateBounded<PooledBufferWriter>(
new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = true, SingleWriter = true, FullMode = BoundedChannelFullMode.Wait });

private readonly Channel<PooledBufferWriter> _bufferPool = Channel.CreateBounded<PooledBufferWriter>(
new BoundedChannelOptions(ReceivedMessageQueueCapacity) { SingleReader = false, SingleWriter = false });

private volatile Exception? _terminalReceiveException;

private readonly Task _receivingTask;
private readonly Task _processingTask;
private readonly CancellationTokenSource _receiveMessagesCancellationTokenSource;

public Broker(ITransport transport, IBiDi bidi, Func<ISessionModule> sessionProvider)
Expand All @@ -48,7 +64,8 @@ public Broker(ITransport transport, IBiDi bidi, Func<ISessionModule> sessionProv
_eventDispatcher = new EventDispatcher(sessionProvider);

_receiveMessagesCancellationTokenSource = new CancellationTokenSource();
_receivingMessageTask = Task.Run(() => ReceiveMessagesLoopAsync(_receiveMessagesCancellationTokenSource.Token));
_receivingTask = Task.Run(() => ReceiveMessagesAsync(_receiveMessagesCancellationTokenSource.Token));
_processingTask = Task.Run(ProcessMessagesAsync);
}

public Task<Subscription> SubscribeAsync<TEventArgs>(string eventName, EventHandler eventHandler, SubscriptionOptions? options, JsonTypeInfo<TEventArgs> jsonTypeInfo, CancellationToken cancellationToken)
Expand All @@ -61,6 +78,11 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
where TCommand : Command
where TResult : EmptyResult
{
if (_terminalReceiveException is { } terminalException)
{
throw new BiDiException("The broker is no longer processing messages due to a transport error.", terminalException);
}

command.Id = Interlocked.Increment(ref _currentCommandId);

var tcs = new TaskCompletionSource<EmptyResult>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -69,42 +91,49 @@ public async Task<TResult> ExecuteCommandAsync<TCommand, TResult>(TCommand comma
? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)
: new CancellationTokenSource();

var timeout = options?.Timeout ?? TimeSpan.FromSeconds(30);
var timeout = options?.Timeout ?? DefaultCommandTimeout;
cts.CancelAfter(timeout);

using var sendBuffer = new PooledBufferWriter();
var sendBuffer = RentBuffer();

using (var writer = new Utf8JsonWriter(sendBuffer))
try
{
JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo);
}
using (var writer = new Utf8JsonWriter(sendBuffer))
{
JsonSerializer.Serialize(writer, command, jsonCommandTypeInfo);
}

var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo);
_pendingCommands[command.Id] = commandInfo;
var commandInfo = new CommandInfo(tcs, jsonResultTypeInfo);
_pendingCommands[command.Id] = commandInfo;

using var ctsRegistration = cts.Token.Register(() =>
{
tcs.TrySetCanceled(cts.Token);
_pendingCommands.TryRemove(command.Id, out _);
});
using var ctsRegistration = cts.Token.Register(() =>
{
tcs.TrySetCanceled(cts.Token);
_pendingCommands.TryRemove(command.Id, out _);
});

try
{
if (_logger.IsEnabled(LogEventLevel.Trace))
try
{
if (_logger.IsEnabled(LogEventLevel.Trace))
{
#if NET8_0_OR_GREATER
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}");
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.Span)}");
#else
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}");
_logger.Trace($"BiDi SND --> {System.Text.Encoding.UTF8.GetString(sendBuffer.WrittenMemory.ToArray())}");
#endif
}
}

await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false);
await _transport.SendAsync(sendBuffer.WrittenMemory, cts.Token).ConfigureAwait(false);
}
catch
{
_pendingCommands.TryRemove(command.Id, out _);
throw;
}
}
catch
finally
{
_pendingCommands.TryRemove(command.Id, out _);
throw;
ReturnBuffer(sendBuffer);
}

return (TResult)await tcs.Task.ConfigureAwait(false);
Expand All @@ -114,22 +143,32 @@ public async ValueTask DisposeAsync()
{
_receiveMessagesCancellationTokenSource.Cancel();

await _eventDispatcher.DisposeAsync().ConfigureAwait(false);

try
{
await _receivingMessageTask.ConfigureAwait(false);
}
catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested)
{
// Expected when cancellation is requested, ignore.
}
try
{
await _receivingTask.ConfigureAwait(false);
}
catch (OperationCanceledException) when (_receiveMessagesCancellationTokenSource.IsCancellationRequested)
{
// Expected when cancellation is requested, ignore.
}

_receiveMessagesCancellationTokenSource.Dispose();
await _transport.DisposeAsync().ConfigureAwait(false);

await _transport.DisposeAsync().ConfigureAwait(false);
await _processingTask.ConfigureAwait(false);

await _eventDispatcher.DisposeAsync().ConfigureAwait(false);
}
finally
{
_receiveMessagesCancellationTokenSource.Dispose();

GC.SuppressFinalize(this);
while (_bufferPool.Reader.TryRead(out var buffer))
{
buffer.Dispose();
}
}
}

private void ProcessReceivedMessage(ReadOnlySpan<byte> data)
Expand Down Expand Up @@ -281,30 +320,63 @@ private void ProcessReceivedMessage(ReadOnlySpan<byte> data)
}
}

private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken)
private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
{
using var receiveBufferWriter = new PooledBufferWriter();

try
{
while (!cancellationToken.IsCancellationRequested)
{
receiveBufferWriter.Reset();

await _transport.ReceiveAsync(receiveBufferWriter, cancellationToken).ConfigureAwait(false);
var buffer = RentBuffer();

if (_logger.IsEnabled(LogEventLevel.Trace))
try
{
await _transport.ReceiveAsync(buffer, cancellationToken).ConfigureAwait(false);

if (_logger.IsEnabled(LogEventLevel.Trace))
{
#if NET8_0_OR_GREATER
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.Span)}");
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.Span)}");
#else
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(receiveBufferWriter.WrittenMemory.ToArray())}");
_logger.Trace($"BiDi RCV <-- {System.Text.Encoding.UTF8.GetString(buffer.WrittenMemory.ToArray())}");
#endif
}

await _receivedMessages.Writer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
}
catch
{
ReturnBuffer(buffer);
throw;
}
}
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
if (_logger.IsEnabled(LogEventLevel.Error))
{
_logger.Error($"Unhandled error occurred while receiving remote messages: {ex}");
}

// Propagated via _terminalReceiveException; not rethrown to keep disposal orderly.
_terminalReceiveException = ex;
}
finally
{
_receivedMessages.Writer.TryComplete();
}
}

private async Task ProcessMessagesAsync()
{
var reader = _receivedMessages.Reader;

while (await reader.WaitToReadAsync().ConfigureAwait(false))
{
while (reader.TryRead(out var buffer))
{
try
{
ProcessReceivedMessage(receiveBufferWriter.WrittenMemory.Span);
ProcessReceivedMessage(buffer.WrittenMemory.Span);
}
catch (Exception ex)
{
Expand All @@ -313,25 +385,43 @@ private async Task ReceiveMessagesLoopAsync(CancellationToken cancellationToken)
_logger.Error($"Unhandled error occurred while processing remote message: {ex}");
}
}
finally
{
ReturnBuffer(buffer);
}
}
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
if (_logger.IsEnabled(LogEventLevel.Error))
{
_logger.Error($"Unhandled error occurred while receiving remote messages: {ex}");
}

// Fail all pending commands, as the connection is likely broken if we failed to receive messages.
foreach (var id in _pendingCommands.Keys)
// Channel is fully drained. Fail any commands that didn't get a response:
// either with the transport error or cancellation for clean shutdown.
var terminalException = _terminalReceiveException;
foreach (var id in _pendingCommands.Keys)
{
if (_pendingCommands.TryRemove(id, out var pendingCommand))
{
if (_pendingCommands.TryRemove(id, out var pendingCommand))
if (terminalException is not null)
{
pendingCommand.TaskCompletionSource.TrySetException(ex);
pendingCommand.TaskCompletionSource.TrySetException(terminalException);
}
else
{
pendingCommand.TaskCompletionSource.TrySetCanceled();
}
}
}
}

private PooledBufferWriter RentBuffer()
{
return _bufferPool.Reader.TryRead(out var buffer) ? buffer : new PooledBufferWriter();
}

throw;
private void ReturnBuffer(PooledBufferWriter buffer)
{
buffer.Reset();
if (!_bufferPool.Writer.TryWrite(buffer))
{
buffer.Dispose();
}
}

Expand Down Expand Up @@ -359,7 +449,13 @@ public void Reset()
_written = 0;
}

public void Advance(int count) => _written += count;
public void Advance(int count)
{
if (count < 0) throw new ArgumentOutOfRangeException(nameof(count));
if (_written + count > (_buffer?.Length ?? 0)) throw new InvalidOperationException("Cannot advance past the end of the buffer.");

_written += count;
}

public Memory<byte> GetMemory(int sizeHint = 0)
{
Expand All @@ -377,8 +473,10 @@ private void EnsureCapacity(int sizeHint)
{
var buffer = _buffer ?? throw new ObjectDisposedException(nameof(PooledBufferWriter));

if (sizeHint <= 0) sizeHint = buffer.Length - _written;
if (sizeHint <= 0) sizeHint = buffer.Length;
if (sizeHint <= 0)
{
sizeHint = Math.Max(1, buffer.Length - _written);
}

if (_written + sizeHint > buffer.Length)
{
Expand Down
5 changes: 4 additions & 1 deletion dotnet/test/webdriver/BiDi/BiDiFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ public async Task BiDiTearDown()
await bidi.DisposeAsync();
}

driver?.Dispose();
if (driver is not null)
{
await driver.DisposeAsync();
}
}

public class BiDiEnabledDriverOptions : DriverOptions
Expand Down
7 changes: 7 additions & 0 deletions dotnet/test/webdriver/BiDi/Session/SessionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ namespace OpenQA.Selenium.Tests.BiDi.Session;

internal class SessionTests : BiDiTestFixture
{
[Test]
public async Task ShouldHaveIdempotentDisposal()
{
await bidi.DisposeAsync();
await bidi.DisposeAsync();
}

[Test]
public async Task CanGetStatus()
{
Expand Down
Loading