diff --git a/README.md b/README.md index 1f59c5c..d05b530 100755 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ SageMaker transport for the [Deepgram Java SDK](https://github.com/deepgram/deep ```groovy dependencies { - implementation 'com.deepgram:deepgram-java-sdk:0.2.1' + implementation 'com.deepgram:deepgram-java-sdk:0.4.0' implementation 'com.deepgram:deepgram-sagemaker:0.1.2' // x-release-please-version } ``` @@ -29,7 +29,7 @@ dependencies { ## Requirements - Java 11+ -- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.2.1+ +- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.4.0+ (the `default ReconnectOptions reconnectOptions()` hook on `DeepgramTransportFactory` is required for storm absorption) - AWS credentials configured (environment variables, shared credentials file, or IAM role) - A Deepgram model deployed to an AWS SageMaker endpoint @@ -95,6 +95,15 @@ The transport is transparent — the SDK API is identical whether using Deepgram |-----------|----------|---------|-------------| | `endpointName` | Yes | — | SageMaker endpoint name | | `region` | No | `us-west-2` | AWS region | +| `connectionTimeout` | No | `30s` | Max time for the underlying TCP/TLS connect (AWS Netty default is 2 s — bumped here so cold-start endpoints under burst load have time to accept TLS handshakes). | +| `connectionAcquireTimeout` | No | `60s` | Max time to acquire a connection from the Netty pool (AWS Netty default is 10 s — bumped so a 200–500-stream burst doesn't drain the acquire pool). | +| `subscriptionTimeout` | No | `60s` | Max time the transport waits for the AWS SDK to subscribe to the bidi-stream input publisher before failing. A timeout here is treated as a transient connect failure and counts against `maxRetries` / `retryBudget`. | +| `maxConcurrency` | No | `500` | Max simultaneous in-flight HTTP/2 streams across the shared Netty pool. With `maxStreams=1` this is the cap on simultaneous bidirectional streams. 
| +| `maxRetries` | No | `5` | Max retries on transient AWS errors (throttling, pool-exhausted, transient connect/timeout). Set to `0` to disable internal retry. Terminal errors (auth, validation) bypass this. | +| `initialBackoff` | No | `100ms` | First backoff delay applied after the initial failure. | +| `maxBackoff` | No | `5s` | Cap on the per-attempt backoff delay regardless of multiplier. | +| `backoffMultiplier` | No | `2.0` | Exponential growth factor between retry attempts. Must be `>= 1.0`. | +| `retryBudget` | No | `30s` | Total wall-clock cap across all retry attempts before giving up and surfacing the error to listeners. | ```java SageMakerConfig config = SageMakerConfig.builder() @@ -103,6 +112,88 @@ SageMakerConfig config = SageMakerConfig.builder() .build(); ``` +#### High-concurrency notes + +The transport's defaults are tuned for high-burst workloads (large numbers of +streams opened in a tight loop against an endpoint that may need to scale up). +If you're opening 200–500 streams simultaneously against a cold endpoint, +the AWS Netty defaults (2 s connect / 10 s acquire) will fire before +the load balancer has accepted all of the inbound TLS handshakes — you'll +see a wave of `connection acquire` and `connect timed out` errors that look +like server-side problems but are really client-side fail-fast tripping early. + +This transport ships with more lenient defaults (30 s / 60 s) so the +common high-concurrency path works out of the box. 
Tighten them if you need +fail-fast behavior in low-latency pipelines: + +```java +SageMakerConfig config = SageMakerConfig.builder() + .endpointName("my-deepgram-endpoint") + .region("us-east-2") + .connectionTimeout(Duration.ofSeconds(5)) + .connectionAcquireTimeout(Duration.ofSeconds(15)) + .build(); +``` + +#### Retry & storm absorption + +Transient AWS-side failures (`ThrottlingException`, connection-pool exhaustion, transient +connect/timeout failures) are absorbed by the transport itself: classified as retryable, retried +with exponential backoff up to `maxRetries` and `retryBudget`, with messages enqueued during the +reset window persisted across the reconnect so audio isn't dropped. Only **terminal** errors (auth, +validation) and budget-exhausted retryable errors propagate to `transport.onError(...)` and reach +the application's error handler. + +This means the SDK's wrapper-level reconnect (`ReconnectingWebSocketListener`) would compound the +plugin's internal retries into a Throttling-on-Throttling storm under burst load, so the plugin +declares `ReconnectOptions.builder().maxRetries(0).build()` via the +`DeepgramTransportFactory.reconnectOptions()` hook. The SDK applies it automatically when it sees +a `transportFactory` in use; no user wiring required. + +To tune retry behavior: + +```java +SageMakerConfig config = SageMakerConfig.builder() + .endpointName("my-deepgram-endpoint") + .maxRetries(10) + .initialBackoff(Duration.ofMillis(200)) + .maxBackoff(Duration.ofSeconds(10)) + .retryBudget(Duration.ofMinutes(1)) + .build(); +``` + +Set `maxRetries(0)` to disable internal retry entirely (every transient AWS error then surfaces +immediately to the application). 
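The way `maxRetries`, `initialBackoff`, `maxBackoff`, `backoffMultiplier`, and `retryBudget` interact is easier to see with numbers. The following is a minimal sketch, not the transport's actual code: `BackoffSchedule`, `delayFor`, and `schedule` are hypothetical names. It assumes plain capped exponential backoff with no jitter and a budget that counts only sleep time. The plugin's own backoff computation is not shown in this part of the diff and may differ, for example by adding jitter or by counting connect time against `retryBudget`.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only; not part of the plugin's API.
final class BackoffSchedule {

    /** Delay before retry `attempt` (0-based): initial * multiplier^attempt, capped at max. */
    static Duration delayFor(int attempt, Duration initial, Duration max, double multiplier) {
        double rawMillis = initial.toMillis() * Math.pow(multiplier, attempt);
        long cappedMillis = (long) Math.min(rawMillis, (double) max.toMillis());
        return Duration.ofMillis(cappedMillis);
    }

    /**
     * Delays that would be slept before giving up, stopping at maxRetries or
     * when the next sleep would push the cumulative sleep time past the budget.
     * Assumption: only sleep time is counted against the budget.
     */
    static List<Duration> schedule(int maxRetries, Duration initial, Duration max,
                                   double multiplier, Duration budget) {
        List<Duration> delays = new ArrayList<>();
        Duration total = Duration.ZERO;
        for (int attempt = 0; attempt < maxRetries; attempt++) {
            Duration delay = delayFor(attempt, initial, max, multiplier);
            if (total.plus(delay).compareTo(budget) > 0) {
                break; // budget would be exceeded by this sleep
            }
            delays.add(delay);
            total = total.plus(delay);
        }
        return delays;
    }

    public static void main(String[] args) {
        // Values taken from the README defaults table.
        List<Duration> delays = schedule(
                5, Duration.ofMillis(100), Duration.ofSeconds(5), 2.0, Duration.ofSeconds(30));
        for (Duration d : delays) {
            System.out.println(d.toMillis() + " ms");
        }
    }
}
```

Under these assumptions the documented defaults give sleeps of 100, 200, 400, 800, and 1600 ms (3.1 s in total), so `maxRetries` is the binding limit and `retryBudget` only matters when individual connect attempts are slow.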
+ +#### Connection-pool sharing + +The default `new SageMakerTransportFactory(config)` constructor backs every factory instance with +a **process-wide shared** `SageMakerRuntimeHttp2AsyncClient`, keyed by the parts of +`SageMakerConfig` that affect the underlying Netty HTTP/2 client (region, max concurrency, +connect/acquire timeouts). Multiple factories built with the same config fingerprint reuse one +Netty event loop group and one connection pool — so naive code that constructs a fresh factory +per stream still gets a single, well-behaved client underneath. + +Without sharing, every factory instantiates its own Netty pool, and a burst of N factories +triggers N simultaneous TLS handshakes from N distinct Netty clients against the same SageMaker +endpoint. Under high concurrency (100+ streams) the SageMaker HTTP/2 frontline silently drops a +large fraction of those streams before they ever reach the model container — verified +end-to-end with CloudWatch logs from a 400-stream burst test against a 1× ml.g6.2xlarge endpoint: +without sharing, ~65% of streams never appeared in the Deepgram container's listen log; with +sharing, the burst behaves the same as the canonical Python load-test harness. + +Lifecycle: + +| Constructor | Client backing | `factory.shutdown()` | +|---|---|---| +| `SageMakerTransportFactory(config)` | shared (lazy-init, keyed by config fingerprint) | no-op — call `SageMakerTransportFactory.shutdownAllSharedClients()` once at app shutdown to release Netty resources | +| `SageMakerTransportFactory(config, smClient)` | caller-provided (BYO, used for testing or custom credential providers) | no-op — caller owns the client lifecycle | + +```java +// At app shutdown — releases all shared Netty pools the plugin lazily created. 
+Runtime.getRuntime().addShutdownHook(new Thread(SageMakerTransportFactory::shutdownAllSharedClients)); +``` + ### Custom AWS Client For custom credential providers, proxy configuration, or testing: diff --git a/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java b/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java index 89f3d55..ee2dacf 100644 --- a/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java +++ b/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java @@ -8,6 +8,7 @@ import com.deepgram.resources.listen.v2.websocket.V2WebSocketClient; import com.deepgram.sagemaker.SageMakerConfig; import com.deepgram.sagemaker.SageMakerTransportFactory; +import com.deepgram.types.ListenV2Model; import java.io.RandomAccessFile; import java.nio.ByteBuffer; @@ -85,10 +86,10 @@ public static void main(String[] args) throws Exception { done.countDown(); }); - // Connect — V2 uses model name as string via additionalProperty + // Connect using the typed Flux model constant from the SDK. CompletableFuture connectFuture = wsClient.connect( V2ConnectOptions.builder() - .model("flux-general-en") + .model(ListenV2Model.FLUX_GENERAL_EN) .build()); connectFuture.get(30, TimeUnit.SECONDS); System.out.println("Connected. 
Streaming audio...\n"); diff --git a/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java b/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java index b4e7d36..8c05588 100644 --- a/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java +++ b/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java @@ -9,6 +9,7 @@ import com.deepgram.sagemaker.SageMakerConfig; import com.deepgram.sagemaker.SageMakerTransportFactory; import com.deepgram.types.ListenV2Encoding; +import com.deepgram.types.ListenV2Model; import com.deepgram.types.ListenV2SampleRate; import javax.sound.sampled.AudioFormat; @@ -100,7 +101,7 @@ public static void main(String[] args) throws Exception { CompletableFuture connectFuture = wsClient.connect( V2ConnectOptions.builder() - .model("flux-general-en") + .model(ListenV2Model.FLUX_GENERAL_EN) .encoding(ListenV2Encoding.LINEAR16) .sampleRate(ListenV2SampleRate.of(16000)) .build()); diff --git a/pom.xml b/pom.xml index cfbadb2..1b63052 100644 --- a/pom.xml +++ b/pom.xml @@ -69,7 +69,7 @@ com.deepgram deepgram-java-sdk - 0.2.1 + 0.4.0 software.amazon.awssdk diff --git a/sagemaker-transport/build.gradle b/sagemaker-transport/build.gradle index e1e0a0a..f7a8536 100755 --- a/sagemaker-transport/build.gradle +++ b/sagemaker-transport/build.gradle @@ -1,6 +1,6 @@ dependencies { // Deepgram Java SDK — provides DeepgramTransport / DeepgramTransportFactory interfaces - api 'com.deepgram:deepgram-java-sdk:0.2.1' + api 'com.deepgram:deepgram-java-sdk:0.4.0' // AWS SDK v2 — SageMaker Runtime HTTP/2 bidirectional streaming api platform('software.amazon.awssdk:bom:2.42.0') diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java index f4996d7..42329f5 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java +++ 
b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java @@ -1,28 +1,156 @@ package com.deepgram.sagemaker; +import java.time.Duration; + import software.amazon.awssdk.regions.Region; /** * Configuration for connecting to a Deepgram model hosted on SageMaker. + * + *

Defaults are tuned for high-burst workloads (large numbers of streams + * opened in a tight loop against an endpoint that may need to scale up). They + * are intentionally more lenient than the AWS SDK Netty defaults so that + * 200–500-stream bursts don't trip connect-acquire / connect-handshake + * timeouts before the endpoint has had a chance to accept the inbound TLS + * handshakes. Tighten them if you want fail-fast behavior in low-latency + * pipelines. */ public class SageMakerConfig { + /** AWS Netty default is 2 s. Set to 30 s so cold endpoints under burst load can accept TLS. */ + public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(30); + + /** AWS Netty default is 10 s. Set to 60 s so a 400-stream burst doesn't drain the acquire pool. */ + public static final Duration DEFAULT_CONNECTION_ACQUIRE_TIMEOUT = Duration.ofSeconds(60); + + /** Time to wait for the AWS SDK to subscribe to the bidi-stream input publisher before failing. */ + public static final Duration DEFAULT_SUBSCRIPTION_TIMEOUT = Duration.ofSeconds(60); + + /** + * Max simultaneous in-flight HTTP/2 streams across the shared Netty connection + * pool. Combined with {@code maxStreams=1} (set by {@link SageMakerTransportFactory}), + * this is the cap on simultaneous bidirectional streams. + */ + public static final int DEFAULT_MAX_CONCURRENCY = 500; + + /** Max retries on transient AWS errors (throttling, pool-exhausted, transient connect) per stream. */ + public static final int DEFAULT_MAX_RETRIES = 5; + + /** First backoff delay after the initial failure. */ + public static final Duration DEFAULT_INITIAL_BACKOFF = Duration.ofMillis(100); + + /** Cap on the per-attempt backoff delay regardless of multiplier. */ + public static final Duration DEFAULT_MAX_BACKOFF = Duration.ofSeconds(5); + + /** Exponential growth factor applied between retry attempts. 
*/ + public static final double DEFAULT_BACKOFF_MULTIPLIER = 2.0; + + /** Total wall-clock budget across all retry attempts before giving up and surfacing the error. */ + public static final Duration DEFAULT_RETRY_BUDGET = Duration.ofSeconds(30); + + /** + * Cap on the in-memory replay buffer that holds sent-but-unacked stream events for the + * current bidi stream attempt. If the SDK has to retry (throttling, post-subscription + * stream reset), this buffer is drained onto the new stream so AWS sees a continuous + * audio sequence instead of the gap created by the discarded events. + * + *

The buffer is trimmed when {@code handlePayloadPart} fires (the model just produced + * a transcript, so prior audio is acked from our perspective), so under steady-state + * operation it stays small (≤ a few hundred KB). It only grows during a throttle/reset + * window where no payload parts come back. 8 MiB ≈ 256 s of 16 kHz mono + * 16-bit PCM, which covers the longest throttle storms we've seen in practice with + * margin to spare. Lower the cap for tight memory budgets; raise it if you expect + * longer retry windows per stream. + */ + public static final long DEFAULT_MAX_REPLAY_BUFFER_BYTES = 8L * 1024 * 1024; + + /** + * Max concurrent HTTP/2 streams multiplexed onto a single underlying TCP connection. Defaults + * to 1, which gives each bidi stream its own dedicated TCP connection — preventing slow-stream + * starvation but creating one HTTP/2 keep-alive ping cycle per logical stream. Under heavy + * concurrent load (hundreds of simultaneous streams from one process), the resulting flood of + * pings can saturate the Netty event-loop pool and trigger spurious {@code PingFailedException} + * connection teardowns. Raise this (e.g. 50–200) to multiplex many streams onto fewer + * connections and slash the ping load. + */ + public static final long DEFAULT_MAX_STREAMS_PER_CONNECTION = 1L; + + /** + * Number of Netty event-loop worker threads handling HTTP/2 frames for the shared client. + * {@code null} (the default) lets the AWS SDK Netty client pick — currently {@code 2 * NCPU}. + * Override when running large numbers of concurrent streams on hardware where the default + * leaves event loops saturated by inbound transcript frames + ping ACK bookkeeping. + */ + public static final Integer DEFAULT_NETTY_EVENT_LOOP_THREADS = null; + + /** + * Period between HTTP/2 keep-alive PING frames sent to the server, and the timeout for the + * PING ACK. {@code null} (the default) leaves the AWS SDK Netty client default in place + * (currently 5 s). 
Under heavy single-process load (hundreds of concurrent streams + * sharing a small event-loop pool), 5 s is too tight: an event loop briefly busy processing + * inbound transcript frames can fail to read the PING ACK in time, causing + * {@code PingFailedException} → connection-death → cascading retries even though the + * underlying connection is healthy. Bump to 30 s+ to tolerate moderate event-loop + * stalls; pass {@code Duration.ZERO} to disable PING frames entirely (the SDK's own retry + * + stream-error path will still detect genuinely-dead connections). + */ + public static final Duration DEFAULT_HEALTH_CHECK_PING_PERIOD = null; + private final String endpointName; private final Region region; private final String contentType; private final String acceptType; + private final Duration connectionTimeout; + private final Duration connectionAcquireTimeout; + private final Duration subscriptionTimeout; + private final int maxConcurrency; + private final int maxRetries; + private final Duration initialBackoff; + private final Duration maxBackoff; + private final double backoffMultiplier; + private final Duration retryBudget; + private final long maxReplayBufferBytes; + private final long maxStreamsPerConnection; + private final Integer nettyEventLoopThreads; + private final Duration healthCheckPingPeriod; private SageMakerConfig(Builder builder) { this.endpointName = builder.endpointName; this.region = builder.region; this.contentType = builder.contentType; this.acceptType = builder.acceptType; + this.connectionTimeout = builder.connectionTimeout; + this.connectionAcquireTimeout = builder.connectionAcquireTimeout; + this.subscriptionTimeout = builder.subscriptionTimeout; + this.maxConcurrency = builder.maxConcurrency; + this.maxRetries = builder.maxRetries; + this.initialBackoff = builder.initialBackoff; + this.maxBackoff = builder.maxBackoff; + this.backoffMultiplier = builder.backoffMultiplier; + this.retryBudget = builder.retryBudget; + 
this.maxReplayBufferBytes = builder.maxReplayBufferBytes; + this.maxStreamsPerConnection = builder.maxStreamsPerConnection; + this.nettyEventLoopThreads = builder.nettyEventLoopThreads; + this.healthCheckPingPeriod = builder.healthCheckPingPeriod; } public String endpointName() { return endpointName; } public Region region() { return region; } public String contentType() { return contentType; } public String acceptType() { return acceptType; } + public Duration connectionTimeout() { return connectionTimeout; } + public Duration connectionAcquireTimeout() { return connectionAcquireTimeout; } + public Duration subscriptionTimeout() { return subscriptionTimeout; } + public int maxConcurrency() { return maxConcurrency; } + public int maxRetries() { return maxRetries; } + public Duration initialBackoff() { return initialBackoff; } + public Duration maxBackoff() { return maxBackoff; } + public double backoffMultiplier() { return backoffMultiplier; } + public Duration retryBudget() { return retryBudget; } + public long maxReplayBufferBytes() { return maxReplayBufferBytes; } + public long maxStreamsPerConnection() { return maxStreamsPerConnection; } + public Integer nettyEventLoopThreads() { return nettyEventLoopThreads; } + public Duration healthCheckPingPeriod() { return healthCheckPingPeriod; } public static Builder builder() { return new Builder(); @@ -33,6 +161,19 @@ public static class Builder { private Region region = Region.US_WEST_2; private String contentType = "application/octet-stream"; private String acceptType = "application/json"; + private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; + private Duration connectionAcquireTimeout = DEFAULT_CONNECTION_ACQUIRE_TIMEOUT; + private Duration subscriptionTimeout = DEFAULT_SUBSCRIPTION_TIMEOUT; + private int maxConcurrency = DEFAULT_MAX_CONCURRENCY; + private int maxRetries = DEFAULT_MAX_RETRIES; + private Duration initialBackoff = DEFAULT_INITIAL_BACKOFF; + private Duration maxBackoff = 
DEFAULT_MAX_BACKOFF; + private double backoffMultiplier = DEFAULT_BACKOFF_MULTIPLIER; + private Duration retryBudget = DEFAULT_RETRY_BUDGET; + private long maxReplayBufferBytes = DEFAULT_MAX_REPLAY_BUFFER_BYTES; + private long maxStreamsPerConnection = DEFAULT_MAX_STREAMS_PER_CONNECTION; + private Integer nettyEventLoopThreads = DEFAULT_NETTY_EVENT_LOOP_THREADS; + private Duration healthCheckPingPeriod = DEFAULT_HEALTH_CHECK_PING_PERIOD; public Builder endpointName(String endpointName) { this.endpointName = endpointName; @@ -59,10 +200,170 @@ public Builder acceptType(String acceptType) { return this; } + /** + * Max time to wait for the underlying TCP/TLS connect to complete. + * Forwards to {@code NettyNioAsyncHttpClient.Builder.connectionTimeout}. + */ + public Builder connectionTimeout(Duration connectionTimeout) { + if (connectionTimeout == null || connectionTimeout.isNegative() || connectionTimeout.isZero()) { + throw new IllegalArgumentException("connectionTimeout must be positive"); + } + this.connectionTimeout = connectionTimeout; + return this; + } + + /** + * Max time to wait when acquiring a connection from the Netty pool. + * Forwards to {@code NettyNioAsyncHttpClient.Builder.connectionAcquisitionTimeout}. + */ + public Builder connectionAcquireTimeout(Duration connectionAcquireTimeout) { + if (connectionAcquireTimeout == null + || connectionAcquireTimeout.isNegative() + || connectionAcquireTimeout.isZero()) { + throw new IllegalArgumentException("connectionAcquireTimeout must be positive"); + } + this.connectionAcquireTimeout = connectionAcquireTimeout; + return this; + } + + /** + * Max time the transport waits for the AWS SDK to subscribe to the + * bidirectional input publisher before failing the first send. 
+ */ + public Builder subscriptionTimeout(Duration subscriptionTimeout) { + if (subscriptionTimeout == null + || subscriptionTimeout.isNegative() + || subscriptionTimeout.isZero()) { + throw new IllegalArgumentException("subscriptionTimeout must be positive"); + } + this.subscriptionTimeout = subscriptionTimeout; + return this; + } + + /** + * Max simultaneous in-flight HTTP/2 streams across the shared Netty pool. + * With {@code maxStreams=1} this equals the maximum number of concurrent + * bidirectional streams the factory can support. + */ + public Builder maxConcurrency(int maxConcurrency) { + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be positive"); + } + this.maxConcurrency = maxConcurrency; + return this; + } + + /** + * Max retries on transient AWS errors per stream invocation. Set to {@code 0} to disable + * internal retry. Transient errors include throttling, connection-pool exhaustion, and + * transient connect/timeout failures; terminal errors (auth, validation) bypass this and + * surface to the application immediately. + */ + public Builder maxRetries(int maxRetries) { + if (maxRetries < 0) { + throw new IllegalArgumentException("maxRetries must be non-negative"); + } + this.maxRetries = maxRetries; + return this; + } + + /** First backoff delay applied after the initial failure. */ + public Builder initialBackoff(Duration initialBackoff) { + if (initialBackoff == null || initialBackoff.isNegative() || initialBackoff.isZero()) { + throw new IllegalArgumentException("initialBackoff must be positive"); + } + this.initialBackoff = initialBackoff; + return this; + } + + /** Cap on the per-attempt backoff delay regardless of multiplier. 
*/ + public Builder maxBackoff(Duration maxBackoff) { + if (maxBackoff == null || maxBackoff.isNegative() || maxBackoff.isZero()) { + throw new IllegalArgumentException("maxBackoff must be positive"); + } + this.maxBackoff = maxBackoff; + return this; + } + + /** Exponential growth factor applied between retry attempts. Must be {@code >= 1.0}. */ + public Builder backoffMultiplier(double backoffMultiplier) { + if (backoffMultiplier < 1.0) { + throw new IllegalArgumentException("backoffMultiplier must be >= 1.0"); + } + this.backoffMultiplier = backoffMultiplier; + return this; + } + + /** Total wall-clock budget across all retry attempts before giving up. */ + public Builder retryBudget(Duration retryBudget) { + if (retryBudget == null || retryBudget.isNegative() || retryBudget.isZero()) { + throw new IllegalArgumentException("retryBudget must be positive"); + } + this.retryBudget = retryBudget; + return this; + } + + /** + * Cap on the in-memory replay buffer that holds sent-but-unacked stream events. Set to + * {@code 0} to disable replay (sent events are dropped on internal reset, matching the + * pre-replay-buffer behavior). See {@link #DEFAULT_MAX_REPLAY_BUFFER_BYTES}. + */ + public Builder maxReplayBufferBytes(long maxReplayBufferBytes) { + if (maxReplayBufferBytes < 0) { + throw new IllegalArgumentException("maxReplayBufferBytes must be non-negative"); + } + this.maxReplayBufferBytes = maxReplayBufferBytes; + return this; + } + + /** + * Max concurrent HTTP/2 streams per underlying TCP connection. See + * {@link #DEFAULT_MAX_STREAMS_PER_CONNECTION}. Raise above 1 to multiplex many bidi + * streams onto fewer connections (slashes ping load); leave at 1 for one-stream-per-TCP + * isolation. 
+ */ + public Builder maxStreamsPerConnection(long maxStreamsPerConnection) { + if (maxStreamsPerConnection <= 0) { + throw new IllegalArgumentException("maxStreamsPerConnection must be positive"); + } + this.maxStreamsPerConnection = maxStreamsPerConnection; + return this; + } + + /** + * Number of Netty event-loop worker threads. {@code null} (default) uses the AWS SDK's + * default ({@code 2 * NCPU}). Override for high-burst single-process workloads where + * the default leaves event loops saturated by inbound frame processing. + */ + public Builder nettyEventLoopThreads(Integer nettyEventLoopThreads) { + if (nettyEventLoopThreads != null && nettyEventLoopThreads <= 0) { + throw new IllegalArgumentException("nettyEventLoopThreads must be positive or null"); + } + this.nettyEventLoopThreads = nettyEventLoopThreads; + return this; + } + + /** + * Period between HTTP/2 keep-alive PING frames (and ACK timeout). {@code null} (default) + * uses the AWS SDK Netty client default (5 s). Pass {@code Duration.ZERO} to disable + * PING frames entirely. See {@link #DEFAULT_HEALTH_CHECK_PING_PERIOD}. 
+ */ + public Builder healthCheckPingPeriod(Duration healthCheckPingPeriod) { + if (healthCheckPingPeriod != null && healthCheckPingPeriod.isNegative()) { + throw new IllegalArgumentException("healthCheckPingPeriod must be non-negative or null"); + } + this.healthCheckPingPeriod = healthCheckPingPeriod; + return this; + } + public SageMakerConfig build() { if (endpointName == null || endpointName.isBlank()) { throw new IllegalArgumentException("endpointName is required"); } + if (initialBackoff.compareTo(maxBackoff) > 0) { + throw new IllegalArgumentException("initialBackoff (" + initialBackoff + + ") must not exceed maxBackoff (" + maxBackoff + ")"); + } return new SageMakerConfig(this); } } diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index 4a3bf10..efcbbdf 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -2,6 +2,7 @@ import com.deepgram.core.transport.DeepgramTransport; +import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.sagemakerruntimehttp2.SageMakerRuntimeHttp2AsyncClient; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.InvokeEndpointWithBidirectionalStreamRequest; @@ -13,15 +14,22 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import 
java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * {@link DeepgramTransport} implementation that routes Deepgram API calls through @@ -41,6 +49,9 @@ */ public class SageMakerTransport implements DeepgramTransport { + private static final Logger log = LoggerFactory.getLogger(SageMakerTransport.class); + private final String transportId = String.format("%08x", System.identityHashCode(this)); + private final SageMakerRuntimeHttp2AsyncClient smClient; private final SageMakerConfig config; private final String invocationPath; @@ -59,6 +70,29 @@ public class SageMakerTransport implements DeepgramTransport { private volatile CompletableFuture streamFuture; private final Object connectLock = new Object(); + // Hoisted out of StreamPublisher so messages queued during an internal reset survive into the + // next stream attempt instead of being dropped with the discarded publisher. + private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); + + // Replay buffer: events sent on the current stream that AWS hasn't acked yet (no payload part + // received since they were sent). On internal reset (handleStreamError → RETRYABLE), the next + // attemptConnect drains this buffer onto the new stream so audio sent on the rejected stream + // isn't lost. Trimmed in handlePayloadPart (a transcript proves AWS has consumed prior audio) + // and capped at config.maxReplayBufferBytes() with FIFO eviction so unbounded throttle storms + // don't OOM the JVM. + private final ConcurrentLinkedDeque replayBuffer = new ConcurrentLinkedDeque<>(); + private final AtomicLong replayBufferBytes = new AtomicLong(0L); + + // Retry budget tracking. Reset to 0 once real downstream data flows back to the application + // (handlePayloadPart) — NOT on subscription success. 
Subscription succeeds in TLS+HTTP/2 + // setup terms even when the bidi-stream request will be throttled milliseconds later. + private final AtomicInteger retryAttempt = new AtomicInteger(0); + private volatile long retryWindowStart = 0L; + // Earliest wall-clock at which the next attemptConnect is allowed to proceed. Set by + // handleStreamError so post-subscription throttles (which never reach the ensureConnected + // catch-block backoff) still pace the next attempt. + private volatile long retryNotBeforeMs = 0L; + SageMakerTransport( SageMakerRuntimeHttp2AsyncClient smClient, SageMakerConfig config, @@ -71,63 +105,333 @@ public class SageMakerTransport implements DeepgramTransport { } /** - * Establish the bidirectional stream if not already connected. - * Blocks until the SDK has subscribed to the event publisher. + * Establish the bidirectional stream if not already connected. Blocks until the AWS SDK + * subscribes to the event publisher. + * + *

Internally retries with exponential backoff on transient AWS errors (throttling, + * connection-pool exhaustion, transient connect/timeout failures) bounded by + * {@link SageMakerConfig#maxRetries()} and {@link SageMakerConfig#retryBudget()}. Terminal + * errors (auth, validation) and budget exhaustion bubble out and surface to {@code errorListeners} + * via the caller's {@code send*} path. */ private void ensureConnected() { - if (connected.get()) return; + // Fast path requires BOTH connected=true AND a live publisher. If either is false we + // re-enter the synchronized reconnect loop. Without the publisher liveness check, a + // race between handleStreamError (which sets connected=false then completes the + // publisher) and ensureConnected's tail (connected.set(true)) can leave us with + // connected=true pointing at a dead publisher — sends silently NOOP forever and the + // conn never recovers (root cause of the 44 silent conns observed in the + // 2026-05-03 replay-buffer-only test). + StreamPublisher pub = inputPublisher; + if (connected.get() && pub != null && !pub.isCompleted()) return; synchronized (connectLock) { - if (connected.get()) return; + pub = inputPublisher; + if (connected.get() && pub != null && !pub.isCompleted()) return; + // If we got here because the publisher died but connected was still true (the race + // above), tell the loop to start a fresh attempt rather than treat ourselves as + // already-connected. + connected.set(false); + + if (retryWindowStart == 0L) { + retryWindowStart = System.currentTimeMillis(); + } - inputPublisher = new StreamPublisher(); + Throwable lastError = null; + while (true) { + // Honor backoff scheduled by handleStreamError so post-subscription throttles + // pace the next attempt (the catch-block backoff below never fires for those — + // attemptConnect succeeds, the throttle hits later). 
+ long now = System.currentTimeMillis(); + if (retryNotBeforeMs > now) { + long sleepMs = retryNotBeforeMs - now; + log.info("[{}] ensureConnected: honoring scheduled backoff ({}ms) before next attemptConnect", + transportId, sleepMs); + try { + Thread.sleep(sleepMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during scheduled backoff", lastError); + } + } + int attemptBefore = retryAttempt.get(); + log.info("[{}] ensureConnected: starting attemptConnect (attempt={}/{}, elapsed={}ms/{}ms)", + transportId, attemptBefore, config.maxRetries(), + System.currentTimeMillis() - retryWindowStart, config.retryBudget().toMillis()); + try { + attemptConnect(); + // Subscription succeeded — but DO NOT reset retryAttempt here. Subscription + // success only proves TLS+HTTP/2 setup; the actual bidi-stream request can + // still be throttled. We reset retryAttempt only when real downstream data + // arrives, in handlePayloadPart. 
+ log.info("[{}] ensureConnected: attemptConnect SUCCEEDED (attempt={} — counter NOT reset; " + + "waits for handlePayloadPart to confirm real data flow before resetting)", + transportId, attemptBefore); + connected.set(true); + return; + } catch (Throwable t) { + lastError = t; + Classification c = classify(t); + int attempt = retryAttempt.get(); + long elapsed = System.currentTimeMillis() - retryWindowStart; + boolean budgetLeft = attempt < config.maxRetries() + && elapsed < config.retryBudget().toMillis(); + log.info("[{}] ensureConnected: attemptConnect FAILED — class={} attempt={}/{} elapsed={}ms/{}ms budgetLeft={} err={}", + transportId, c, attempt, config.maxRetries(), elapsed, + config.retryBudget().toMillis(), budgetLeft, summarize(t)); + if (c == Classification.TERMINAL || !budgetLeft) { + log.warn("[{}] ensureConnected: SURFACING (class={} budgetLeft={}) — err={}", + transportId, c, budgetLeft, summarize(t)); + if (t instanceof RuntimeException) throw (RuntimeException) t; + throw new RuntimeException(t); + } + long backoff = computeBackoff(attempt); + log.info("[{}] ensureConnected: backoff={}ms before retry attempt {}", + transportId, backoff, attempt + 1); + retryAttempt.incrementAndGet(); + try { + Thread.sleep(backoff); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException( + "Interrupted during retry backoff after " + (attempt + 1) + " attempts", lastError); + } + } + } + } + } - InvokeEndpointWithBidirectionalStreamRequest.Builder requestBuilder = - InvokeEndpointWithBidirectionalStreamRequest.builder() - .endpointName(config.endpointName()) - .modelInvocationPath(invocationPath); - if (queryString != null && !queryString.isEmpty()) { - requestBuilder.modelQueryString(queryString); + /** Single connect attempt — invokes the bidi stream and waits for subscription. 
*/ + private void attemptConnect() throws TimeoutException, InterruptedException { + StreamPublisher publisher = new StreamPublisher(pending); + inputPublisher = publisher; + + // Replay any unacked events from a prior failed stream attempt. They go through the + // publisher's pre-subscription pending queue and are flushed onto the new AWS subscriber + // as soon as it arrives — so AWS sees the buffered audio first, in order, before any + // new chunks the caller sends after this attempt succeeds. + if (!replayBuffer.isEmpty()) { + int count = replayBuffer.size(); + long bytes = replayBufferBytes.get(); + log.info("[{}] attemptConnect: replaying {} buffered events ({} bytes) onto new stream", + transportId, count, bytes); + for (BufferedEvent be : replayBuffer) { + publisher.send(be.event); } - InvokeEndpointWithBidirectionalStreamRequest request = requestBuilder.build(); - - InvokeEndpointWithBidirectionalStreamResponseHandler handler = - InvokeEndpointWithBidirectionalStreamResponseHandler.builder() - .onResponse(response -> { }) - .subscriber(InvokeEndpointWithBidirectionalStreamResponseHandler - .Visitor.builder() - .onPayloadPart(this::handlePayloadPart) - .build()) - .onError(error -> { - if (closeSent.get()) { - // Model idle timeout after CloseStream — treat as normal close. 
- inputPublisher.complete(); - notifyClose(1000, "Normal"); - } else { - for (Consumer l : errorListeners) { - l.accept(error); - } - } - }) - .onComplete(() -> { - notifyClose(1000, "Normal"); - }) - .build(); - - streamFuture = smClient.invokeEndpointWithBidirectionalStream( - request, inputPublisher, handler); - - // Wait for the SDK to subscribe to our publisher before sending events + } + + InvokeEndpointWithBidirectionalStreamRequest.Builder requestBuilder = + InvokeEndpointWithBidirectionalStreamRequest.builder() + .endpointName(config.endpointName()) + .modelInvocationPath(invocationPath); + if (queryString != null && !queryString.isEmpty()) { + requestBuilder.modelQueryString(queryString); + } + InvokeEndpointWithBidirectionalStreamRequest request = requestBuilder.build(); + + InvokeEndpointWithBidirectionalStreamResponseHandler handler = + InvokeEndpointWithBidirectionalStreamResponseHandler.builder() + .onResponse(response -> { }) + .subscriber(InvokeEndpointWithBidirectionalStreamResponseHandler + .Visitor.builder() + .onPayloadPart(this::handlePayloadPart) + .build()) + .onError(this::handleStreamError) + .onComplete(() -> notifyClose(1000, "Normal")) + .build(); + + streamFuture = smClient.invokeEndpointWithBidirectionalStream(request, publisher, handler); + + if (!publisher.awaitSubscription(config.subscriptionTimeout().toMillis(), TimeUnit.MILLISECONDS)) { + // Subscription never landed — treat as a transient connect failure so the retry loop + // classifies and (if budget allows) tries again on a fresh stream. 
try { - inputPublisher.awaitSubscription(30, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted waiting for stream subscription", e); + streamFuture.cancel(true); + } catch (Throwable ignored) { + // best-effort + } + throw new TimeoutException( + "Timed out waiting for AWS SDK to subscribe to stream publisher after " + + config.subscriptionTimeout()); + } + + // Race guard: under high-concurrency burst-throttle, AWS sometimes accepts the + // subscription and IMMEDIATELY rejects the stream (ThrottlingException, ping-timeout). + // The handleStreamError callback fires on a Netty event-loop thread BEFORE attemptConnect + // returns from awaitSubscription, sets connected=false, and completes this publisher. + // If we then unconditionally let ensureConnected set connected=true, the audio thread + // takes the fast path and pushes chunks into a dead publisher — silently dropped, no + // transcripts ever come back. Detect the race here by re-checking the publisher's + // completion state and surfacing it as a transient failure so ensureConnected re-enters + // the retry loop on the next send. + if (publisher.isCompleted()) { + throw new IllegalStateException( + "Stream rejected immediately after subscription (handleStreamError already " + + "fired and completed the publisher); ensureConnected will re-attempt"); + } + } + + /** + * Async error gate: AWS SDK reports stream errors via the response handler's onError. Classify + * here so transient AWS errors trigger an internal reset (next send re-enters ensureConnected + * via the retry loop) instead of bubbling straight to {@code errorListeners}. + */ + private void handleStreamError(Throwable error) { + if (closeSent.get()) { + // Model idle timeout after CloseStream — treat as normal close. 
+ log.info("[{}] handleStreamError: closeSent=true → treating as normal close", transportId); + if (inputPublisher != null) inputPublisher.complete(); + notifyClose(1000, "Normal"); + return; + } + + Classification c = classify(error); + int attempt = retryAttempt.get(); + long windowStart = retryWindowStart; + long elapsed = windowStart == 0L ? 0L : System.currentTimeMillis() - windowStart; + boolean budgetLeft = attempt < config.maxRetries() + && elapsed < config.retryBudget().toMillis(); + + log.info("[{}] handleStreamError: class={} attempt={}/{} elapsed={}ms/{}ms budgetLeft={} err={}", + transportId, c, attempt, config.maxRetries(), elapsed, + config.retryBudget().toMillis(), budgetLeft, summarize(error)); + + if (c == Classification.RETRYABLE && budgetLeft) { + // Internal reset: drop current stream, mark disconnected, and SCHEDULE a backoff so + // the next ensureConnected pause-then-reconnect rather than immediately hammering + // the AWS frontline. We don't sleep here (this runs on a Netty event-loop thread); + // we just set retryNotBeforeMs and the next ensureConnected honors it. + // + // We also advance retryAttempt and start the budget window if not already started, + // so under repeated post-subscription throttles the retry budget actually gets + // consumed and eventually surfaces a terminal error to the application. 
+ if (retryWindowStart == 0L) { + retryWindowStart = System.currentTimeMillis(); } + int attemptForBackoff = retryAttempt.getAndIncrement(); + long backoff = computeBackoff(attemptForBackoff); + retryNotBeforeMs = System.currentTimeMillis() + backoff; + log.info("[{}] handleStreamError: RETRYABLE → internal reset, attempt {} → scheduled backoff {}ms", + transportId, attemptForBackoff + 1, backoff); + connected.set(false); + if (inputPublisher != null) inputPublisher.complete(); + if (streamFuture != null) { + try { + streamFuture.cancel(true); + } catch (Throwable ignored) { + // best-effort + } + } + return; + } - connected.set(true); + // Terminal or budget-exhausted: surface to listeners. + log.warn("[{}] handleStreamError: SURFACING (class={} budgetLeft={}) → invoking {} errorListener(s)", + transportId, c, budgetLeft, errorListeners.size()); + for (Consumer l : errorListeners) { + l.accept(error); } } + /** One-line summary of a Throwable for log lines. */ + private static String summarize(Throwable t) { + if (t == null) return "null"; + String msg = t.getMessage(); + if (msg == null) msg = ""; + if (msg.length() > 160) msg = msg.substring(0, 157) + "..."; + StringBuilder sb = new StringBuilder(t.getClass().getSimpleName()).append(": ").append(msg); + if (t instanceof AwsServiceException) { + AwsServiceException ase = (AwsServiceException) t; + sb.append(" [status=").append(ase.statusCode()); + if (ase.awsErrorDetails() != null && ase.awsErrorDetails().errorCode() != null) { + sb.append(" code=").append(ase.awsErrorDetails().errorCode()); + } + sb.append("]"); + } + return sb.toString(); + } + + private long computeBackoff(int attempt) { + // Lambda re-resolves ThreadLocalRandom.current() at each invocation, so each calling thread + // gets its own per-thread RNG state — safe even if the LongBinaryOperator is cached and + // invoked from a different thread later (which the current call path doesn't do, but this + // is defensive). 
Method-reference form `ThreadLocalRandom.current()::nextLong` would + // capture whichever thread called computeBackoff first and wedge it as the RNG source. + return computeBackoff( + config.initialBackoff().toMillis(), + config.maxBackoff().toMillis(), + config.backoffMultiplier(), + attempt, + (origin, bound) -> ThreadLocalRandom.current().nextLong(origin, bound)); + } + + /** + * Pure-function backoff calculator with full jitter. Package-private + static for testability. + * + *
<p>
Without jitter, N conns failing simultaneously all compute the same exponential backoff + * and retry in lockstep, hammering the endpoint in waves (worst case: every {@code maxMs} the + * entire fleet retries together). Full jitter — random uniform in {@code [initialMs, ceiling]} — + * spreads the retry load continuously over the backoff window. See AWS Architecture Blog + * "Exponential Backoff and Jitter". + * + * @param randomLong injected RNG: {@code (originInclusive, boundExclusive) -> long}, allowing + * deterministic tests via a stub. + */ + static long computeBackoff(long initialMs, long maxMs, double multiplier, int attempt, + java.util.function.LongBinaryOperator randomLong) { + double scaled = initialMs * Math.pow(multiplier, attempt); + long ceiling = (scaled > maxMs || Double.isInfinite(scaled)) ? maxMs : Math.max(initialMs, (long) scaled); + if (ceiling <= initialMs) { + return ceiling; + } + return randomLong.applyAsLong(initialMs, ceiling + 1); + } + + enum Classification { RETRYABLE, TERMINAL } + + /** + * Classify an AWS-side exception as transient (retry internally, don't surface) vs terminal + * (surface to {@code errorListeners}). + * + *
<p>
Default is RETRYABLE — the retry budget ({@link SageMakerConfig#retryBudget()}) is the + * safety net, not the classifier. Under high-burst load against SageMaker we've seen a + * long tail of transient failure modes (Netty WriteTimeout, HTTP/2 stream resets, SSO 429s, + * pool-acquire timeouts, AWS-frontline ThrottlingException) that the classifier kept missing + * one-by-one; flipping the default means a new transient hiccup retries instead of surfacing + * as a hard failure. The budget caps the worst case at the configured wall-clock anyway. + * + *
<p>
The narrow set we genuinely consider TERMINAL is caller-side rejections from AWS: + * an {@link AwsServiceException} with a 4xx status code (other than 429, 424, and + * "throttling"-coded errors). Validation errors, AccessDenied, ResourceNotFound, etc. are + * authoritative; retrying won't change the outcome. Everything else, including any cause + * chain we don't recognize, defaults to RETRYABLE. + * + *
<p>
424 (Failed Dependency) is treated as RETRYABLE because under burst load we observed + * SageMaker emitting {@code ModelErrorException: Received server error (424) from primary + * with message "Failed to establish WebSocket connection"} when the upstream model container + * couldn't accept a new stream at the moment of the request. The next attempt usually + * succeeds, so it doesn't belong in the caller-side-rejection bucket. + */ + static Classification classify(Throwable error) { + for (Throwable t = error; t != null; t = t.getCause()) { + if (t instanceof AwsServiceException) { + AwsServiceException ase = (AwsServiceException) t; + int status = ase.statusCode(); + String code = ase.awsErrorDetails() != null ? ase.awsErrorDetails().errorCode() : null; + // Throttling is sometimes coded as 4xx (e.g. SageMaker frontline returns 400 with + // errorCode=ThrottlingException); always retry these regardless of status. + if (code != null && code.toLowerCase().contains("throttl")) return Classification.RETRYABLE; + // 4xx caller-side rejections: validation, auth, notfound — won't fix on retry. + // Exclusions: 429 (rate-limit, retry-with-backoff) and 424 (Failed Dependency, + // SageMaker upstream "Failed to establish WebSocket connection" — transient). 
+ if (status >= 400 && status < 500 && status != 429 && status != 424) { + return Classification.TERMINAL; + } + } + if (t == t.getCause()) break; + } + return Classification.RETRYABLE; + } + @Override public CompletableFuture sendBinary(byte[] data) { if (!open.get()) { @@ -139,6 +443,7 @@ public CompletableFuture sendBinary(byte[] data) { .bytes(SdkBytes.fromByteArray(data)) .build(); inputPublisher.send(event); + bufferForReplay(event, data.length); return CompletableFuture.completedFuture(null); } @@ -149,11 +454,13 @@ public CompletableFuture sendText(String data) { new IllegalStateException("Transport is closed")); } ensureConnected(); + byte[] payload = data.getBytes(StandardCharsets.UTF_8); RequestStreamEvent event = RequestStreamEvent.payloadPartBuilder() - .bytes(SdkBytes.fromByteArray(data.getBytes(StandardCharsets.UTF_8))) + .bytes(SdkBytes.fromByteArray(payload)) .dataType("UTF8") .build(); inputPublisher.send(event); + bufferForReplay(event, payload.length); // Track that we've signaled end-of-audio so we can treat the model's // idle timeout as a normal close rather than an error. @@ -164,6 +471,44 @@ public CompletableFuture sendText(String data) { return CompletableFuture.completedFuture(null); } + /** + * Append {@code event} to the replay buffer for re-delivery on a future internal reset. + * Evicts oldest events FIFO once the configured byte cap is exceeded so an unbounded + * throttle storm can't OOM the JVM. Caller passes the payload byte length so we don't + * have to re-walk the SdkBytes wrapper (which would double-allocate on every send). 
+ */ + private void bufferForReplay(RequestStreamEvent event, int eventBytes) { + long cap = config.maxReplayBufferBytes(); + if (cap == 0L) return; // replay disabled + replayBuffer.addLast(new BufferedEvent(event, eventBytes)); + long total = replayBufferBytes.addAndGet(eventBytes); + while (total > cap) { + BufferedEvent dropped = replayBuffer.pollFirst(); + if (dropped == null) break; + total = replayBufferBytes.addAndGet(-dropped.bytes); + } + } + + /** Drain the replay buffer (e.g. after AWS acks via {@code handlePayloadPart}). */ + private void clearReplayBuffer() { + replayBuffer.clear(); + replayBufferBytes.set(0L); + } + + // Test-only accessors. Package-private so unit tests in the same package can verify buffer + // behavior without going through the full AWS reactive-streams stack. + int replayBufferSize() { return replayBuffer.size(); } + long replayBufferBytes() { return replayBufferBytes.get(); } + void bufferForReplayForTest(RequestStreamEvent event, int eventBytes) { + bufferForReplay(event, eventBytes); + } + void clearReplayBufferForTest() { clearReplayBuffer(); } + java.util.List drainReplayBufferForTest() { + java.util.List out = new java.util.ArrayList<>(replayBuffer.size()); + for (BufferedEvent be : replayBuffer) out.add(be.event); + return out; + } + private void notifyClose(int code, String reason) { if (closeNotified.compareAndSet(false, true)) { for (CloseListener l : closeListeners) { @@ -175,10 +520,58 @@ private void notifyClose(int code, String reason) { private void handlePayloadPart(ResponsePayloadPart part) { byte[] bytes = part.bytes().asByteArray(); - // JSON messages start with '{"' (0x7B 0x22). Checking two bytes avoids + // Decode once. JSON messages start with '{"' (0x7B 0x22) — checking two bytes avoids // false positives from binary audio chunks that happen to start with 0x7B. 
- if (bytes.length > 1 && bytes[0] == '{' && bytes[1] == '"') { - String text = new String(bytes, StandardCharsets.UTF_8); + boolean isJson = bytes.length > 1 && bytes[0] == '{' && bytes[1] == '"'; + String text = isJson ? new String(bytes, StandardCharsets.UTF_8) : null; + + // Decide whether this payload counts as "the model consumed input" — the signal that + // makes it safe to reset retry counters and trim the replay buffer. + // + // Rather than enumerate every downstream message type per Deepgram product (Listen STT + // emits `Results`, Flux emits `TurnInfo`, Speak/Aura emits binary audio chunks, ...), + // we invert the check: assume ANY downstream payload counts as a real ack EXCEPT for + // end-of-stream-only message types that the container can emit without having consumed + // input. Two types are documented to fire at close regardless of consumption: + // + // - `Metadata` — end-of-stream summary. Under burst-replay we observed the container + // emitting `Metadata` (sha256="incomplete", duration=0) when the model errored out + // before producing any transcript; the original code treated this as an ack and + // cleared the replay buffer, so the next retry had nothing to replay and the + // replayed audio was lost (root cause of conn 125-style front-loss in the + // 2026-05-03 RC run). + // - `Error` — model-side fatal error message, also typically sent at close. + // + // Anything else — Results, TurnInfo, EndOfTurn, Warning, SpeechStarted, UtteranceEnd, + // raw binary audio (TTS) — represents real model activity and is safe to trust. + boolean isCloseOnlyMessage = isJson + && (text.contains("\"type\":\"Metadata\"") || text.contains("\"type\":\"Error\"")); + boolean countsAsAck = !isCloseOnlyMessage; + + if (countsAsAck) { + // Real downstream content on this transport (or first since the last retry-loop + // reset) → the stream is genuinely working end-to-end. Reset the retry budget so a + // future transient failure gets a fresh budget. 
+ if (retryAttempt.get() != 0 || retryWindowStart != 0L || retryNotBeforeMs != 0L) { + log.info("[{}] handlePayloadPart: ack-class data received ({}B) → resetting retry counters " + + "(was attempt={}, windowStart={}, notBeforeMs={})", + transportId, bytes.length, + retryAttempt.get(), retryWindowStart, retryNotBeforeMs); + retryAttempt.set(0); + retryWindowStart = 0L; + retryNotBeforeMs = 0L; + } + + // The model produced output → it consumed input up to (at least) the moment this + // part's window started. Drop the replay buffer; if a retry happens later, we only + // need to replay events sent SINCE this ack. The model's internal latency means a + // small leading window (< 1 s, typically) of audio sent just before the failure + // might be re-transcribed on the new stream — net impact is a brief duplicate, far + // better than dropping 30+ s of audio outright. + clearReplayBuffer(); + } + + if (isJson) { for (Consumer l : messageListeners) { l.accept(text); } @@ -223,6 +616,10 @@ public boolean isOpen() { @Override public void close() { if (!open.compareAndSet(true, false)) return; + // Terminal close — drop any messages that were queued during a reset window, and free the + // replay buffer so we don't pin its memory after the transport is done. + pending.clear(); + clearReplayBuffer(); if (inputPublisher != null) { inputPublisher.complete(); } @@ -231,16 +628,38 @@ public void close() { } } + /** + * Replay-buffer entry — keeps the byte length alongside the event so the cap-and-evict path + * doesn't have to peek inside SdkBytes (which would incur an extra copy on every send). + */ + static final class BufferedEvent { + final RequestStreamEvent event; + final int bytes; + + BufferedEvent(RequestStreamEvent event, int bytes) { + this.event = event; + this.bytes = bytes; + } + } + /** * Reactive Streams publisher that buffers events until the SDK subscribes, * then delivers them in order. After subscription, events are forwarded immediately. + * + *
<p>
The {@code pending} queue is owned by the enclosing {@link SageMakerTransport} and + * shared across reconnect cycles, so events queued during an internal reset are drained + * onto whichever stream subscribes next. */ static class StreamPublisher implements Publisher { private volatile Subscriber subscriber; private final AtomicBoolean completed = new AtomicBoolean(false); - private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue pending; private final CountDownLatch subscribed = new CountDownLatch(1); + StreamPublisher(ConcurrentLinkedQueue sharedPending) { + this.pending = sharedPending; + } + @Override public void subscribe(Subscriber s) { this.subscriber = s; @@ -255,7 +674,8 @@ public void cancel() { completed.set(true); } }); - // Flush any events that were queued before subscription + // Flush any events that were queued before subscription (including events that + // survived a previous internal reset). RequestStreamEvent event; while ((event = pending.poll()) != null) { s.onNext(event); @@ -269,7 +689,7 @@ void send(RequestStreamEvent event) { if (s != null) { s.onNext(event); } else { - // Buffer until the SDK subscribes + // Buffer until the SDK subscribes (this stream or the next one after a reset) pending.add(event); } } @@ -283,8 +703,21 @@ void complete() { } } - void awaitSubscription(long timeout, TimeUnit unit) throws InterruptedException { - subscribed.await(timeout, unit); + /** + * @return {@code true} if subscription happened within the timeout, {@code false} on timeout. + */ + boolean awaitSubscription(long timeout, TimeUnit unit) throws InterruptedException { + return subscribed.await(timeout, unit); + } + + /** + * @return {@code true} if {@link #complete()} or the subscriber's cancel hook has fired. 
+ * Used by the {@code attemptConnect} race-guard to detect when handleStreamError + * has already torn down the publisher between subscription and the + * {@code connected.set(true)} fast-path. + */ + boolean isCompleted() { + return completed.get(); } } } diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index 710ff57..9474865 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -1,8 +1,10 @@ package com.deepgram.sagemaker; +import com.deepgram.core.ReconnectingWebSocketListener; import com.deepgram.core.transport.DeepgramTransport; import com.deepgram.core.transport.DeepgramTransportFactory; +import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; import software.amazon.awssdk.http.Protocol; import software.amazon.awssdk.http.nio.netty.Http2Configuration; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; @@ -10,6 +12,7 @@ import java.net.URI; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * Factory that creates SageMaker bidirectional streaming transports. @@ -33,46 +36,108 @@ * .transportFactory(factory) * .build(); * } + * + *
<h2>Connection-pool sharing</h2>
+ * + *
<p>
The default constructor backs the factory with a process-wide shared + * {@link SageMakerRuntimeHttp2AsyncClient} keyed by the parts of {@link SageMakerConfig} that + * affect the underlying Netty HTTP/2 client (region, max concurrency, connect/acquire timeouts). + * Multiple factories built with the same config fingerprint reuse one Netty event loop group and + * one connection pool — so naive code that constructs a fresh factory per stream still gets a + * single, well-behaved client underneath. + * + *
<p>
Without sharing, every factory instantiates its own Netty pool, and a burst of N factories + * triggers N simultaneous TLS handshakes from N distinct Netty clients against the same SageMaker + * endpoint — the SageMaker HTTP/2 frontline silently drops a large fraction of those streams + * before they ever reach the model container. Sharing matches the behavior of the canonical + * Python load-test harness, which has been verified to handle 400+ concurrent streams cleanly. + * + *
<p>
Lifecycle: + *
<ul> + *
<li>Default constructor → shared client; {@link #shutdown()} is a no-op. Call + * {@link #shutdownAllSharedClients()} once at app shutdown to release Netty resources. + *
<li>{@link #SageMakerTransportFactory(SageMakerConfig, SageMakerRuntimeHttp2AsyncClient)} + * (BYO client) → caller owns the client lifecycle; {@link #shutdown()} is a no-op. + *
</ul> +
*/ public class SageMakerTransportFactory implements DeepgramTransportFactory { - private final SageMakerConfig config; - private final SageMakerRuntimeHttp2AsyncClient smClient; - /** - * Default max concurrent HTTP/2 streams (in-flight requests) across the - * shared connection pool. With {@code maxStreams=1} each stream gets its - * own TCP connection, so this value equals the maximum number of - * simultaneous bidirectional streams the factory can support. + * Process-wide shared clients keyed by config fingerprint. Subsequent factories with the + * same fingerprint reuse the existing client, so one Netty event loop group + connection + * pool serves all of them. */ - private static final int DEFAULT_MAX_CONCURRENCY = 500; + private static final ConcurrentHashMap SHARED_CLIENTS = + new ConcurrentHashMap<>(); + + private final SageMakerConfig config; + private final SageMakerRuntimeHttp2AsyncClient smClient; public SageMakerTransportFactory(SageMakerConfig config) { this.config = config; - this.smClient = SageMakerRuntimeHttp2AsyncClient.builder() - .region(config.region()) - .httpClientBuilder( - NettyNioAsyncHttpClient.builder() - .protocol(Protocol.HTTP2) - .maxConcurrency(DEFAULT_MAX_CONCURRENCY) - .http2Configuration( - Http2Configuration.builder() - .maxStreams(1L) - .build() - ) - ) - .build(); + this.smClient = SHARED_CLIENTS.computeIfAbsent( + sharedClientKey(config), + k -> buildClient(config)); } /** - * Create with a pre-configured SageMaker HTTP/2 client (for testing or - * custom credential providers). + * Create with a pre-configured SageMaker HTTP/2 client (for testing or custom credential + * providers). The provided client is not closed by {@link #shutdown()}; + * the caller owns its lifecycle. */ public SageMakerTransportFactory(SageMakerConfig config, SageMakerRuntimeHttp2AsyncClient smClient) { this.config = config; this.smClient = smClient; } + /** + * Cache key for the shared-client pool. 
Includes only the fields that affect the underlying + * Netty client; per-stream config (endpointName, contentType, retry knobs) doesn't. + */ + private static String sharedClientKey(SageMakerConfig c) { + return c.region().id() + + "|" + c.maxConcurrency() + + "|" + c.connectionTimeout().toMillis() + + "|" + c.connectionAcquireTimeout().toMillis() + + "|" + c.maxStreamsPerConnection() + + "|" + (c.nettyEventLoopThreads() == null ? "default" : c.nettyEventLoopThreads()) + + "|" + (c.healthCheckPingPeriod() == null ? "default" : c.healthCheckPingPeriod().toMillis()); + } + + private static SageMakerRuntimeHttp2AsyncClient buildClient(SageMakerConfig config) { + Http2Configuration.Builder http2Builder = Http2Configuration.builder() + .maxStreams(config.maxStreamsPerConnection()); + if (config.healthCheckPingPeriod() != null) { + http2Builder.healthCheckPingPeriod(config.healthCheckPingPeriod()); + } + NettyNioAsyncHttpClient.Builder httpBuilder = NettyNioAsyncHttpClient.builder() + .protocol(Protocol.HTTP2) + .maxConcurrency(config.maxConcurrency()) + .connectionTimeout(config.connectionTimeout()) + .connectionAcquisitionTimeout(config.connectionAcquireTimeout()) + .http2Configuration(http2Builder.build()); + if (config.nettyEventLoopThreads() != null) { + httpBuilder.eventLoopGroupBuilder( + software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup.builder() + .numberOfThreads(config.nettyEventLoopThreads()) + ); + } + return SageMakerRuntimeHttp2AsyncClient.builder() + .region(config.region()) + // Disable the AWS SDK's internal retry strategy. SageMakerTransport owns the + // retry policy (handleStreamError + ensureConnected backoff). The AWS SDK's + // default 3-attempt strategy compounds on top: every "1 retry" in our schedule + // becomes ~4 hits on the SageMaker frontline (the original attempt + 3 SDK + // retries with their own ~25/100/400 ms backoffs). 
Under a per-LB-IP throttle + // ceiling measured in requests/sec, that amplification keeps the conn pinned + // in the throttle window. AwsRetryStrategy.doNotRetry() removes the SDK layer; + // transient TLS / connection-reset hiccups are still caught by our IOException + // → RETRYABLE classification and the same backoff path. + .overrideConfiguration(c -> c.retryStrategy(AwsRetryStrategy.doNotRetry())) + .httpClientBuilder(httpBuilder) + .build(); + } + @Override public DeepgramTransport create(String url, Map headers) { // Parse the WebSocket URL to extract invocation path and query string. @@ -86,8 +151,41 @@ public DeepgramTransport create(String url, Map headers) { return new SageMakerTransport(smClient, config, invocationPath, queryString); } - /** Shut down the underlying AWS SDK client. */ + /** + * Disable the SDK's wrapper-level reconnect loop. {@link SageMakerTransport} owns its own + * retry/backoff/classification (see {@link SageMakerConfig#maxRetries()}, + * {@link SageMakerConfig#retryBudget()}); wrapping it in another retry layer compounds + * transient AWS errors into Throttling-on-Throttling storms under burst load. + */ + @Override + public ReconnectingWebSocketListener.ReconnectOptions reconnectOptions() { + return ReconnectingWebSocketListener.ReconnectOptions.builder() + .maxRetries(0) + .build(); + } + + /** + * No-op for factories backed by the shared client pool or a caller-owned (BYO) client. + * Use {@link #shutdownAllSharedClients()} to close shared clients at app shutdown; close + * your own client directly if you provided one. + */ public void shutdown() { - smClient.close(); + // Intentionally no-op. See class-level Javadoc for lifecycle semantics. + } + + /** + * Close all process-wide shared {@link SageMakerRuntimeHttp2AsyncClient} instances. Call + * once at app shutdown if you want to release Netty resources cleanly before JVM exit. + * Subsequent default-constructor factories will lazily build new shared clients. 
+ */ + public static void shutdownAllSharedClients() { + for (SageMakerRuntimeHttp2AsyncClient client : SHARED_CLIENTS.values()) { + try { + client.close(); + } catch (Exception ignored) { + // best-effort + } + } + SHARED_CLIENTS.clear(); } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java index 2a4e10a..27047b6 100755 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java @@ -3,7 +3,9 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; +import com.deepgram.core.ReconnectingWebSocketListener; import com.deepgram.core.transport.DeepgramTransport; +import java.time.Duration; import java.util.Map; import software.amazon.awssdk.regions.Region; @@ -62,4 +64,305 @@ void factoryAcceptsCustomRegion() { assertEquals(Region.EU_WEST_1, config.region()); } + + @Test + void configDefaultsAreLenientForHighConcurrency() { + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + + assertEquals(Duration.ofSeconds(30), config.connectionTimeout()); + assertEquals(Duration.ofSeconds(60), config.connectionAcquireTimeout()); + assertEquals(Duration.ofSeconds(60), config.subscriptionTimeout()); + assertEquals(500, config.maxConcurrency()); + } + + @Test + void configAcceptsCustomTimeoutsAndConcurrency() { + SageMakerConfig config = + SageMakerConfig.builder() + .endpointName("my-endpoint") + .connectionTimeout(Duration.ofSeconds(5)) + .connectionAcquireTimeout(Duration.ofSeconds(15)) + .subscriptionTimeout(Duration.ofSeconds(45)) + .maxConcurrency(1000) + .build(); + + assertEquals(Duration.ofSeconds(5), config.connectionTimeout()); + assertEquals(Duration.ofSeconds(15), config.connectionAcquireTimeout()); + 
assertEquals(Duration.ofSeconds(45), config.subscriptionTimeout()); + assertEquals(1000, config.maxConcurrency()); + } + + @Test + void configRejectsNonPositiveTimeouts() { + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").connectionTimeout(Duration.ZERO).build()); + assertThrows( + IllegalArgumentException.class, + () -> + SageMakerConfig.builder() + .endpointName("e") + .connectionAcquireTimeout(Duration.ofSeconds(-1)) + .build()); + assertThrows( + IllegalArgumentException.class, + () -> + SageMakerConfig.builder() + .endpointName("e") + .subscriptionTimeout(null) + .build()); + } + + @Test + void configRejectsNonPositiveMaxConcurrency() { + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxConcurrency(0).build()); + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxConcurrency(-1).build()); + } + + @Test + void retryConfigDefaults() { + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + + assertEquals(5, config.maxRetries()); + assertEquals(Duration.ofMillis(100), config.initialBackoff()); + assertEquals(Duration.ofSeconds(5), config.maxBackoff()); + assertEquals(2.0, config.backoffMultiplier()); + assertEquals(Duration.ofSeconds(30), config.retryBudget()); + } + + @Test + void retryConfigAcceptsCustomValues() { + SageMakerConfig config = + SageMakerConfig.builder() + .endpointName("my-endpoint") + .maxRetries(10) + .initialBackoff(Duration.ofMillis(50)) + .maxBackoff(Duration.ofSeconds(20)) + .backoffMultiplier(3.0) + .retryBudget(Duration.ofMinutes(2)) + .build(); + + assertEquals(10, config.maxRetries()); + assertEquals(Duration.ofMillis(50), config.initialBackoff()); + assertEquals(Duration.ofSeconds(20), config.maxBackoff()); + assertEquals(3.0, config.backoffMultiplier()); + assertEquals(Duration.ofMinutes(2), config.retryBudget()); + } + + @Test + void 
retryConfigValidates() { + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxRetries(-1).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").initialBackoff(Duration.ZERO).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxBackoff(Duration.ofSeconds(-1)).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").backoffMultiplier(0.5).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").retryBudget(null).build()); + } + + @Test + void retryConfigRejectsInitialGreaterThanMax() { + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder() + .endpointName("e") + .initialBackoff(Duration.ofSeconds(10)) + .maxBackoff(Duration.ofSeconds(5)) + .build()); + } + + @Test + void factoryDeclaresMaxRetriesZeroForReconnectOptions() { + // The plugin declares maxRetries(0) so the SDK's wrapper-level reconnect doesn't compound + // SageMaker's internal retries into a storm. + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + SageMakerTransportFactory factory = new SageMakerTransportFactory(config, mockClient); + + ReconnectingWebSocketListener.ReconnectOptions opts = factory.reconnectOptions(); + assertNotNull(opts); + assertEquals(0, opts.maxRetries); + } + + // ------------------------------------------------------------------------- + // Shared-client pool tests. The default constructor backs the factory with a process-wide + // shared SageMakerRuntimeHttp2AsyncClient keyed by config fingerprint, so naive code that + // builds a fresh factory per stream still benefits from a single Netty pool underneath. 
+ // ------------------------------------------------------------------------- + + @Test + void defaultConstructorReusesSharedClientAcrossFactoriesWithSameConfig() { + // Reset the shared pool to isolate this test from other tests in the class. + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig configA = SageMakerConfig.builder() + .endpointName("endpoint-A") // endpoint name does NOT affect the shared client + .region("us-east-1") + .build(); + SageMakerConfig configB = SageMakerConfig.builder() + .endpointName("endpoint-B") // different endpoint, same Netty-relevant config + .region("us-east-1") + .build(); + + SageMakerTransportFactory f1 = new SageMakerTransportFactory(configA); + SageMakerTransportFactory f2 = new SageMakerTransportFactory(configB); + + // Use reflection sparingly — verify both factories point at the same underlying smClient. + assertSame(getSmClient(f1), getSmClient(f2), + "factories with same Netty-relevant config must share one smClient"); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void defaultConstructorBuildsDistinctSharedClientsForDifferentConfigs() { + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig configEast = SageMakerConfig.builder() + .endpointName("e").region("us-east-1").build(); + SageMakerConfig configWest = SageMakerConfig.builder() + .endpointName("e").region("us-west-2").build(); + SageMakerConfig configEastBigPool = SageMakerConfig.builder() + .endpointName("e").region("us-east-1").maxConcurrency(1000).build(); + + SageMakerTransportFactory east1 = new SageMakerTransportFactory(configEast); + SageMakerTransportFactory east2 = new SageMakerTransportFactory(configEast); + SageMakerTransportFactory west = new SageMakerTransportFactory(configWest); + SageMakerTransportFactory eastBig = new SageMakerTransportFactory(configEastBigPool); + + assertSame(getSmClient(east1), getSmClient(east2)); + 
assertNotSame(getSmClient(east1), getSmClient(west)); + assertNotSame(getSmClient(east1), getSmClient(eastBig)); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void byoClientConstructorIsNotPooled() { + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + + SageMakerTransportFactory byo = new SageMakerTransportFactory(config, mockClient); + SageMakerTransportFactory shared = new SageMakerTransportFactory(config); + + assertSame(mockClient, getSmClient(byo), "BYO factory must use the provided client"); + assertNotSame(mockClient, getSmClient(shared), + "Shared factory must build/lookup its own client, not steal the BYO mock"); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void shutdownIsNoopForBoth() { + // factory.shutdown() must not close shared or BYO clients — lifecycle belongs elsewhere. + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + + new SageMakerTransportFactory(config, mockClient).shutdown(); + verify(mockClient, never()).close(); + + SageMakerTransportFactory shared = new SageMakerTransportFactory(config); + shared.shutdown(); + // We can't easily verify .close() wasn't called on a real client without re-fetching; + // the assertion above on mockClient is the strong guarantee. The shared variant is + // documented as no-op via Javadoc. 
+ } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void shutdownAllSharedClientsClearsThePool() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + + SageMakerTransportFactory before = new SageMakerTransportFactory(config); + SageMakerRuntimeHttp2AsyncClient firstClient = getSmClient(before); + + SageMakerTransportFactory.shutdownAllSharedClients(); + + SageMakerTransportFactory after = new SageMakerTransportFactory(config); + SageMakerRuntimeHttp2AsyncClient secondClient = getSmClient(after); + + assertNotSame(firstClient, secondClient, + "shutdownAllSharedClients must clear the pool so the next factory builds a fresh client"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + @Test + void differentMaxStreamsForcesDistinctSharedClients() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig one = SageMakerConfig.builder() + .endpointName("e").maxStreamsPerConnection(1).build(); + SageMakerConfig hundred = SageMakerConfig.builder() + .endpointName("e").maxStreamsPerConnection(100).build(); + + SageMakerRuntimeHttp2AsyncClient c1 = getSmClient(new SageMakerTransportFactory(one)); + SageMakerRuntimeHttp2AsyncClient c100 = getSmClient(new SageMakerTransportFactory(hundred)); + + assertNotSame(c1, c100, + "different maxStreamsPerConnection must produce distinct shared clients (different cache keys)"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + @Test + void differentNettyEventLoopThreadsForcesDistinctSharedClients() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig dflt = SageMakerConfig.builder() + .endpointName("e").build(); + SageMakerConfig eighty = SageMakerConfig.builder() + .endpointName("e").nettyEventLoopThreads(80).build(); + + SageMakerRuntimeHttp2AsyncClient cDefault = getSmClient(new SageMakerTransportFactory(dflt)); + 
SageMakerRuntimeHttp2AsyncClient c80 = getSmClient(new SageMakerTransportFactory(eighty)); + + assertNotSame(cDefault, c80, + "different nettyEventLoopThreads must produce distinct shared clients (different cache keys)"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + @Test + void sameStreamAndEventLoopConfigSharesClient() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig configA = SageMakerConfig.builder() + .endpointName("e-a").maxStreamsPerConnection(50).nettyEventLoopThreads(40).build(); + SageMakerConfig configB = SageMakerConfig.builder() + .endpointName("e-b").maxStreamsPerConnection(50).nettyEventLoopThreads(40).build(); + + SageMakerRuntimeHttp2AsyncClient cA = getSmClient(new SageMakerTransportFactory(configA)); + SageMakerRuntimeHttp2AsyncClient cB = getSmClient(new SageMakerTransportFactory(configB)); + + // endpointName differs but it's not in the cache key — the underlying Netty client should be shared. + assertSame(cA, cB, + "configs with same Netty-affecting fields share the underlying client even with different endpointName"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + /** Reflection helper — read the package-private smClient field set by the constructor. 
*/ + private static SageMakerRuntimeHttp2AsyncClient getSmClient(SageMakerTransportFactory f) { + try { + java.lang.reflect.Field fld = SageMakerTransportFactory.class.getDeclaredField("smClient"); + fld.setAccessible(true); + return (SageMakerRuntimeHttp2AsyncClient) fld.get(f); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java new file mode 100644 index 0000000..e956394 --- /dev/null +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -0,0 +1,478 @@ +package com.deepgram.sagemaker; + +import static org.junit.jupiter.api.Assertions.*; + +import com.deepgram.sagemaker.SageMakerTransport.StreamPublisher; +import java.io.IOException; +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.services.sagemakerruntimehttp2.model.RequestStreamEvent; + +/** + * Unit tests for {@link SageMakerTransport}'s retry/classification logic and the hoisted pending-queue + * machinery. 
End-to-end retry against the AWS reactive-streams handler isn't covered here (the + * handler indirection makes it hard to deterministically stub); those paths are exercised by the + * burst test in the README. + */ +class SageMakerTransportRetryTest { + + // Helpers live on the outer class because @Nested inner classes can't have static members on Java 11. + + private static RequestStreamEvent payloadEvent(String s) { + return RequestStreamEvent.payloadPartBuilder() + .bytes(SdkBytes.fromUtf8String(s)) + .build(); + } + + private static Subscriber noopSubscriber() { + return new Subscriber() { + @Override + public void onSubscribe(Subscription s) {} + @Override + public void onNext(RequestStreamEvent e) {} + @Override + public void onError(Throwable t) {} + @Override + public void onComplete() {} + }; + } + + // RNG stubs for backoff tests — must live on outer class because @Nested inner classes can't + // have static members on Java 11. + private static final java.util.function.LongBinaryOperator MAX_RNG = (origin, bound) -> bound - 1; + private static final java.util.function.LongBinaryOperator MIN_RNG = (origin, bound) -> origin; + private static final java.util.function.LongBinaryOperator MID_RNG = + (origin, bound) -> origin + (bound - origin) / 2; + + private static class CapturingSubscriber implements Subscriber { + final List received = new ArrayList<>(); + final CountDownLatch completed = new CountDownLatch(1); + + @Override + public void onSubscribe(Subscription s) {} + @Override + public void onNext(RequestStreamEvent e) { + received.add(e); + } + @Override + public void onError(Throwable t) {} + @Override + public void onComplete() { + completed.countDown(); + } + } + + @Nested + @DisplayName("classify(Throwable)") + class ClassifyTests { + @Test + @DisplayName("TimeoutException is retryable") + void timeoutIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new TimeoutException("acquire 
timeout"))); + } + + @Test + @DisplayName("ConnectException is retryable") + void connectExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new ConnectException("connection refused"))); + } + + @Test + @DisplayName("IOException is retryable") + void ioExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new IOException("network error"))); + } + + @Test + @DisplayName("CancellationException is retryable (covers self-induced retry-reset cancels)") + void cancellationExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new CancellationException())); + } + + @Test + @DisplayName("FutureCancelledException-style wrapper (RuntimeException with CancellationException cause) is retryable") + void cancellationWrappedInRuntimeIsRetryable() { + // Mirrors AWS Netty's FutureCancelledException, which extends RuntimeException and + // wraps a CancellationException as its cause. 
+ RuntimeException wrapper = new RuntimeException("future cancelled", new CancellationException()); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(wrapper)); + } + + @Test + @DisplayName("AWS 429 (Too Many Requests) is retryable") + void aws429IsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("Rate exceeded") + .statusCode(429) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 5xx is retryable") + void aws5xxIsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("internal") + .statusCode(503) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS error code containing 'throttl' is retryable regardless of status") + void awsThrottlingErrorCodeIsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("Rate exceeded") + .statusCode(400) + .awsErrorDetails(AwsErrorDetails.builder() + .errorCode("ThrottlingException") + .build()) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 401 (Unauthorized) is terminal") + void aws401IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("Unauthorized") + .statusCode(401) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 403 (Forbidden) is terminal") + void aws403IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("Forbidden") + .statusCode(403) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("SdkException with 'pool exhausted' message is retryable (defensive belt)") + void
sdkExceptionWithPoolKeywordIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Connection pool exhausted") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("'Unable to load credentials' is retryable when an SSO/STS provider hit Status Code: 429") + void credentialLoadFailureWithSsoThrottleIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Unable to load credentials from any of the providers in the chain " + + "AwsCredentialsProviderChain(...): [..., " + + "ProfileCredentialsProvider(profileName=shared-dev, ...): " + + "HTTP 429 Unknown Code (Service: Sso, Status Code: 429, Request ID: abc) " + + "(SDK Attempt Count: 4), ...]") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("'Unable to load credentials' is retryable when a credential backend returned 5xx") + void credentialLoadFailureWith5xxIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Unable to load credentials from any of the providers in the chain " + + "AwsCredentialsProviderChain(...): [..., " + + "InstanceProfileCredentialsProvider(): " + + "HTTP 503 (Service: Imds, Status Code: 503, Request ID: xyz)]") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("Walks the cause chain — IOException wrapped in RuntimeException is retryable") + void walksCauseChain() { + RuntimeException wrapper = new RuntimeException("oops", new IOException("netty")); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(wrapper)); + } + + @Test + @DisplayName("Unknown exception defaults to RETRYABLE (budget is the safety net)") + void unknownDefaultsToRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new 
RuntimeException("mystery"))); + } + + @Test + @DisplayName("Netty WriteTimeoutException-style RuntimeException is retryable by default") + void nettyWriteTimeoutIsRetryable() { + // io.netty.handler.timeout.WriteTimeoutException extends Netty's own TimeoutException + // (NOT java.util.concurrent.TimeoutException) which extends RuntimeException. We don't + // want to take a direct Netty compile dep just to instanceof-check it, so the new + // default-RETRYABLE policy covers this organically. + class WriteTimeoutException extends RuntimeException { + WriteTimeoutException() { super(); } + } + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new WriteTimeoutException())); + } + + @Test + @DisplayName("AWS 400 (ValidationException) is terminal — caller-side rejection") + void aws400ValidationIsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("invalid input") + .statusCode(400) + .awsErrorDetails(AwsErrorDetails.builder() + .errorCode("ValidationException") + .build()) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 404 (ResourceNotFound) is terminal — won't appear on retry") + void aws404IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("endpoint not found") + .statusCode(404) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 424 (Failed Dependency, SageMaker ModelError) is retryable — upstream container transient") + void aws424IsRetryable() { + // Mirrors the actual SageMaker burst-load error: + // ModelErrorException: Received server error (424) from primary with message + // "Failed to establish WebSocket connection" + AwsServiceException ase = AwsServiceException.builder() + .message("Received server error (424) from primary with message " + + "\"Failed to establish WebSocket 
connection\"") + .statusCode(424) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + } + + @Nested + @DisplayName("StreamPublisher") + class StreamPublisherTests { + @Test + @DisplayName("awaitSubscription returns false on timeout") + void awaitSubscriptionTimeout() throws InterruptedException { + ConcurrentLinkedQueue q = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(q); + assertFalse(pub.awaitSubscription(50, TimeUnit.MILLISECONDS), + "no subscriber within timeout — must report false so callers can fail fast"); + } + + @Test + @DisplayName("awaitSubscription returns true once a subscriber arrives") + void awaitSubscriptionSuccess() throws InterruptedException { + ConcurrentLinkedQueue q = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(q); + pub.subscribe(noopSubscriber()); + assertTrue(pub.awaitSubscription(100, TimeUnit.MILLISECONDS)); + } + + @Test + @DisplayName("Pending queue is shared across publisher instances — surviving an internal reset") + void pendingPersistsAcrossPublishers() { + ConcurrentLinkedQueue shared = new ConcurrentLinkedQueue<>(); + + // First publisher: send 3 events before any subscriber arrives, then complete (simulating reset). + StreamPublisher first = new StreamPublisher(shared); + first.send(payloadEvent("a")); + first.send(payloadEvent("b")); + first.send(payloadEvent("c")); + first.complete(); + + assertEquals(3, shared.size(), "events must remain in the shared queue across publishers"); + + // Second publisher: subscribes and drains the queued events. 
+ StreamPublisher second = new StreamPublisher(shared); + CapturingSubscriber sub = new CapturingSubscriber(); + second.subscribe(sub); + + assertEquals(3, sub.received.size(), "events queued before reset must arrive on the new stream"); + assertTrue(shared.isEmpty(), "queue must be drained after subscription"); + } + + @Test + @DisplayName("send() with subscriber present forwards immediately, no queueing") + void sendForwardsWhenSubscribed() { + ConcurrentLinkedQueue shared = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(shared); + CapturingSubscriber sub = new CapturingSubscriber(); + pub.subscribe(sub); + + pub.send(payloadEvent("hello")); + + assertEquals(1, sub.received.size()); + assertTrue(shared.isEmpty()); + } + } + + @Nested + @DisplayName("Replay buffer") + class ReplayBufferTests { + private SageMakerTransport newTransport(long maxReplayBufferBytes) { + SageMakerConfig cfg = SageMakerConfig.builder() + .endpointName("test") + .region("us-east-1") + .maxReplayBufferBytes(maxReplayBufferBytes) + .build(); + // null AWS client is fine — these tests only exercise the in-memory buffer helpers, + // never attempt a real bidi stream. 
+ return new SageMakerTransport(null, cfg, "v1/listen", ""); + } + + @Test + @DisplayName("buffer accumulates events with running byte count") + void bufferAccumulates() { + SageMakerTransport t = newTransport(1024); + t.bufferForReplayForTest(payloadEvent("aaa"), 3); + t.bufferForReplayForTest(payloadEvent("bbbbb"), 5); + assertEquals(2, t.replayBufferSize()); + assertEquals(8L, t.replayBufferBytes()); + } + + @Test + @DisplayName("clearReplayBuffer drops everything (the AWS-acked path)") + void clearDropsAll() { + SageMakerTransport t = newTransport(1024); + t.bufferForReplayForTest(payloadEvent("a"), 1); + t.bufferForReplayForTest(payloadEvent("b"), 1); + t.clearReplayBufferForTest(); + assertEquals(0, t.replayBufferSize()); + assertEquals(0L, t.replayBufferBytes()); + } + + @Test + @DisplayName("FIFO eviction once cap exceeded — newest events kept, oldest dropped") + void evictionFifo() { + SageMakerTransport t = newTransport(10); // cap = 10 bytes + RequestStreamEvent a = payloadEvent("aaaa"); + RequestStreamEvent b = payloadEvent("bbbb"); + RequestStreamEvent c = payloadEvent("cccc"); + RequestStreamEvent d = payloadEvent("dddd"); + t.bufferForReplayForTest(a, 4); // total 4 + t.bufferForReplayForTest(b, 4); // total 8 + t.bufferForReplayForTest(c, 4); // total 12 → evict a, total 8 + t.bufferForReplayForTest(d, 4); // total 12 → evict b, total 8 + + // Latest two events should survive, oldest two evicted, ordering preserved. 
+ java.util.List remaining = t.drainReplayBufferForTest(); + assertEquals(2, remaining.size()); + assertSame(c, remaining.get(0), "oldest surviving event should be 'cccc'"); + assertSame(d, remaining.get(1), "newest event should be 'dddd'"); + assertEquals(8L, t.replayBufferBytes()); + } + + @Test + @DisplayName("maxReplayBufferBytes=0 disables buffering entirely") + void disabledByZeroCap() { + SageMakerTransport t = newTransport(0L); + t.bufferForReplayForTest(payloadEvent("a"), 1); + t.bufferForReplayForTest(payloadEvent("b"), 1); + assertEquals(0, t.replayBufferSize(), "events must not accumulate when cap=0"); + assertEquals(0L, t.replayBufferBytes()); + } + + @Test + @DisplayName("oversized single event is dropped immediately by the eviction loop") + void oversizedEventCantStick() { + SageMakerTransport t = newTransport(10); + // A single 16-byte event exceeds the 10-byte cap; eviction loop should remove it. + t.bufferForReplayForTest(payloadEvent("0123456789ABCDEF"), 16); + assertEquals(0, t.replayBufferSize()); + assertEquals(0L, t.replayBufferBytes()); + } + } + + @Nested + @DisplayName("computeBackoff(initial, max, multiplier, attempt) — full jitter") + class ComputeBackoffTests { + @Test + @DisplayName("ceiling grows exponentially up to max; range is [initial, ceiling] inclusive") + void exponentialCeilingGrowth() { + // initial=100, multiplier=2 → ceilings 100, 200, 400, 800, 1600, capped at 1000. + // attempt=0: ceiling=initial=100 → degenerate range, returns 100 without RNG. + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 0, MID_RNG)); + // attempt=1: ceiling=200. Range=[100,201). MID_RNG returns origin + (bound-origin)/2 = 100 + 50 = 150. + assertEquals(150L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 1, MID_RNG)); + // attempt=4: scaled=1600, capped to ceiling=1000. Range=[100,1001). MID_RNG returns 100 + 450 = 550. 
+ assertEquals(550L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 4, MID_RNG)); + } + + @Test + @DisplayName("MIN_RNG returns the initial floor, MAX_RNG returns the ceiling") + void rngBoundsRespected() { + // attempt=2 with initial=100, mult=2: scaled=400, ceiling=400. Range=[100,401). + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 2, MIN_RNG)); + assertEquals(400L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 2, MAX_RNG)); + } + + @Test + @DisplayName("ceiling caps at max regardless of attempt") + void ceilingCappedAtMax() { + // High attempt count would overflow without the cap. + assertEquals(5000L, SageMakerTransport.computeBackoff(100, 5000, 2.0, 100, MAX_RNG)); + // Even infinity scaling caps cleanly. + assertEquals(5000L, SageMakerTransport.computeBackoff(100, 5000, 2.0, 10_000, MAX_RNG)); + } + + @Test + @DisplayName("when ceiling == initial (attempt=0 or multiplier degenerate), returns ceiling without invoking RNG") + void degenerateRangeReturnsCeiling() { + java.util.concurrent.atomic.AtomicInteger rngCalls = new java.util.concurrent.atomic.AtomicInteger(0); + java.util.function.LongBinaryOperator countingRng = (o, b) -> { rngCalls.incrementAndGet(); return o; }; + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 0, countingRng)); + // multiplier=1.0 means scaled never grows beyond initial. + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 1.0, 5, countingRng)); + assertEquals(0, rngCalls.get(), "RNG must not be invoked when range collapses to a single value"); + } + + @Test + @DisplayName("with real ThreadLocalRandom: 1000 samples spread continuously across [initial, ceiling]") + void productionRngSpreadsRetries() { + // The whole point of this fix: in production, N concurrent retries should NOT cluster + // at the same ceiling value. Sample many times and assert the spread is meaningful. 
+ int trials = 1000; + long min = Long.MAX_VALUE, max = Long.MIN_VALUE; + long sum = 0; + for (int i = 0; i < trials; i++) { + long b = SageMakerTransport.computeBackoff( + 100, 1000, 2.0, /*attempt*/ 4, + java.util.concurrent.ThreadLocalRandom.current()::nextLong); + min = Math.min(min, b); + max = Math.max(max, b); + sum += b; + } + // attempt=4 → ceiling=1000, range=[100,1000]. Expected spread is large. + assertTrue(min < 200, "min sample should land near initial floor; got " + min); + assertTrue(max > 900, "max sample should land near ceiling; got " + max); + long mean = sum / trials; + assertTrue(mean > 400 && mean < 700, + "mean of uniform [100,1000] should be near 550; got " + mean); + } + } +}
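The `ComputeBackoffTests` assertions above pin down the expected behaviour of `computeBackoff` closely enough to sketch a function that would satisfy them. The body of `SageMakerTransport.computeBackoff` is not part of this section, so the class below (`BackoffSketch`, a hypothetical name) is only an illustration under that assumption: it draws a delay from `[initial, ceiling]`, where the ceiling grows exponentially with the attempt number and is capped at `max`, and it skips the RNG when the range collapses to a single value.

```java
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.LongBinaryOperator;

// Hypothetical sketch, NOT the transport's actual implementation. It is written
// only to be consistent with the ComputeBackoffTests assertions in this diff.
final class BackoffSketch {

    /**
     * Returns a delay in milliseconds drawn from [initialMs, ceiling], where
     * ceiling = min(initialMs * multiplier^attempt, maxMs).
     * The rng is called as rng(originInclusive, boundExclusive).
     */
    static long computeBackoff(long initialMs, long maxMs, double multiplier,
                               int attempt, LongBinaryOperator rng) {
        // Scale in double arithmetic so huge attempt counts become Infinity
        // instead of overflowing a long; Math.min then clamps to maxMs.
        double scaled = initialMs * Math.pow(multiplier, attempt);
        long ceiling = (long) Math.min(scaled, (double) maxMs);
        if (ceiling <= initialMs) {
            // Degenerate range (attempt 0, or multiplier 1.0): no jitter, no RNG call.
            return ceiling;
        }
        // Bound is exclusive, so add 1 to make the ceiling itself reachable.
        return rng.applyAsLong(initialMs, ceiling + 1);
    }

    public static void main(String[] args) {
        // Production-style RNG: ThreadLocalRandom.nextLong(origin, bound).
        LongBinaryOperator rng = ThreadLocalRandom.current()::nextLong;
        for (int attempt = 0; attempt <= 5; attempt++) {
            long delay = computeBackoff(100, 5000, 2.0, attempt, rng);
            System.out.println("attempt " + attempt + " -> " + delay + " ms");
        }
    }
}
```

Injecting the RNG as a `LongBinaryOperator`, as the tests do with `MIN_RNG`, `MID_RNG` and `MAX_RNG`, keeps the jitter deterministic under test while production code can pass `ThreadLocalRandom.current()::nextLong`.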