From e819738ffd6d7e1c1668694440615cdc1777cc8a Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Tue, 28 Apr 2026 15:33:11 -0700 Subject: [PATCH 01/10] feat: configurable timeouts and concurrency with lenient defaults for high-burst workloads Expose `connectionTimeout`, `connectionAcquireTimeout`, `subscriptionTimeout`, and `maxConcurrency` on `SageMakerConfig.Builder`. Wire them through `SageMakerTransportFactory` (Netty client) and `SageMakerTransport` (input-publisher subscription wait) so high-concurrency callers can tune the chokepoints that the AWS Netty defaults expose under burst load. Move the defaults to values tuned for high-burst workloads: connectionTimeout AWS Netty default 2s -> 30s connectionAcquireTimeout AWS Netty default 10s -> 60s subscriptionTimeout (was hardcoded 30s) -> 60s maxConcurrency (was hardcoded 500) -> 500 (now tunable) Empirically validated against a 400-concurrent-stream burst test on a 10x ml.g6.2xlarge endpoint: with the previous AWS defaults, customers hit a wave of `connection acquire` and `connect timed out` errors in the first few seconds of the run that look like server-side problems but are really client-side fail-fast tripping early. With the new defaults, the underlying connect/acquire layer absorbs the burst and those error categories go to near-zero. (Note: this fix only addresses the transport layer; SDK-level retry-storm remains until paired with `reconnect(false)` in `deepgram-java-sdk` >= 0.4.0.) Adds 4 unit tests covering defaults, custom values, and validation of non-positive arguments. README updated with a "High-concurrency notes" section and the new parameter table. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 27 ++++++ .../deepgram/sagemaker/SageMakerConfig.java | 95 +++++++++++++++++++ .../sagemaker/SageMakerTransport.java | 3 +- .../sagemaker/SageMakerTransportFactory.java | 12 +-- .../SageMakerTransportFactoryTest.java | 59 ++++++++++++ 5 files changed, 186 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 1f59c5c..004287b 100755 --- a/README.md +++ b/README.md @@ -95,6 +95,10 @@ The transport is transparent — the SDK API is identical whether using Deepgram |-----------|----------|---------|-------------| | `endpointName` | Yes | — | SageMaker endpoint name | | `region` | No | `us-west-2` | AWS region | +| `connectionTimeout` | No | `30s` | Max time for the underlying TCP/TLS connect (AWS Netty default is 2 s — bumped here so cold-start endpoints under burst load have time to accept TLS handshakes). | +| `connectionAcquireTimeout` | No | `60s` | Max time to acquire a connection from the Netty pool (AWS Netty default is 10 s — bumped so a 200–500-stream burst doesn't drain the acquire pool). | +| `subscriptionTimeout` | No | `60s` | Max time the transport waits for the AWS SDK to subscribe to the bidi-stream input publisher before failing the first send. | +| `maxConcurrency` | No | `500` | Max simultaneous in-flight HTTP/2 streams across the shared Netty pool. With `maxStreams=1` this is the cap on simultaneous bidirectional streams. | ```java SageMakerConfig config = SageMakerConfig.builder() @@ -103,6 +107,29 @@ SageMakerConfig config = SageMakerConfig.builder() .build(); ``` +#### High-concurrency notes + +The transport's defaults are tuned for high-burst workloads (large numbers of +streams opened in a tight loop against an endpoint that may need to scale up). +If you're opening 200–500 streams simultaneously against a cold endpoint, +the AWS Netty defaults (2 s connect / 10 s acquire) will fire before +the load balancer has accepted all of the inbound TLS handshakes — you'll +see a wave of `connection acquire` and `connect timed out` errors that look +like server-side problems but are really client-side fail-fast tripping early. + +This transport ships with more lenient defaults (30 s / 60 s) so the +common high-concurrency path works out of the box. Tighten them if you need +fail-fast behavior in low-latency pipelines: + +```java +SageMakerConfig config = SageMakerConfig.builder() + .endpointName("my-deepgram-endpoint") + .region("us-east-2") + .connectionTimeout(Duration.ofSeconds(5)) + .connectionAcquireTimeout(Duration.ofSeconds(15)) + .build(); +``` + ### Custom AWS Client For custom credential providers, proxy configuration, or testing: diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java index f4996d7..ee5b472 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java @@ -1,28 +1,66 @@ package com.deepgram.sagemaker; +import java.time.Duration; + import software.amazon.awssdk.regions.Region; /** * Configuration for connecting to a Deepgram model hosted on SageMaker. + * + *

Defaults are tuned for high-burst workloads (large numbers of streams + * opened in a tight loop against an endpoint that may need to scale up). They + * are intentionally more lenient than the AWS SDK Netty defaults so that + * 200–500-stream bursts don't trip connect-acquire / connect-handshake + * timeouts before the endpoint has had a chance to accept the inbound TLS + * handshakes. Tighten them if you want fail-fast behavior in low-latency + * pipelines. */ public class SageMakerConfig { + /** AWS Netty default is 2 s. Set to 30 s so cold endpoints under burst load can accept TLS. */ + public static final Duration DEFAULT_CONNECTION_TIMEOUT = Duration.ofSeconds(30); + + /** AWS Netty default is 10 s. Set to 60 s so a 400-stream burst doesn't drain the acquire pool. */ + public static final Duration DEFAULT_CONNECTION_ACQUIRE_TIMEOUT = Duration.ofSeconds(60); + + /** Time to wait for the AWS SDK to subscribe to the bidi-stream input publisher before failing. */ + public static final Duration DEFAULT_SUBSCRIPTION_TIMEOUT = Duration.ofSeconds(60); + + /** + * Max simultaneous in-flight HTTP/2 streams across the shared Netty connection + * pool. Combined with {@code maxStreams=1} (set by {@link SageMakerTransportFactory}), + * this is the cap on simultaneous bidirectional streams. + */ + public static final int DEFAULT_MAX_CONCURRENCY = 500; + private final String endpointName; private final Region region; private final String contentType; private final String acceptType; + private final Duration connectionTimeout; + private final Duration connectionAcquireTimeout; + private final Duration subscriptionTimeout; + private final int maxConcurrency; private SageMakerConfig(Builder builder) { this.endpointName = builder.endpointName; this.region = builder.region; this.contentType = builder.contentType; this.acceptType = builder.acceptType; + this.connectionTimeout = builder.connectionTimeout; + this.connectionAcquireTimeout = builder.connectionAcquireTimeout; + this.subscriptionTimeout = builder.subscriptionTimeout; + this.maxConcurrency = builder.maxConcurrency; } public String endpointName() { return endpointName; } public Region region() { return region; } public String contentType() { return contentType; } public String acceptType() { return acceptType; } + public Duration connectionTimeout() { return connectionTimeout; } + public Duration connectionAcquireTimeout() { return connectionAcquireTimeout; } + public Duration subscriptionTimeout() { return subscriptionTimeout; } + public int maxConcurrency() { return maxConcurrency; } public static Builder builder() { return new Builder(); @@ -33,6 +71,10 @@ public static class Builder { private Region region = Region.US_WEST_2; private String contentType = "application/octet-stream"; private String acceptType = "application/json"; + private Duration connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; + private Duration connectionAcquireTimeout = DEFAULT_CONNECTION_ACQUIRE_TIMEOUT; + private Duration subscriptionTimeout = DEFAULT_SUBSCRIPTION_TIMEOUT; + private int maxConcurrency = DEFAULT_MAX_CONCURRENCY; public Builder endpointName(String endpointName) { this.endpointName = endpointName; @@ -59,6 +101,59 @@ public Builder acceptType(String acceptType) { return this; } + /** + * Max time to wait for the underlying TCP/TLS connect to complete. + * Forwards to {@code NettyNioAsyncHttpClient.Builder.connectionTimeout}. + */ + public Builder connectionTimeout(Duration connectionTimeout) { + if (connectionTimeout == null || connectionTimeout.isNegative() || connectionTimeout.isZero()) { + throw new IllegalArgumentException("connectionTimeout must be positive"); + } + this.connectionTimeout = connectionTimeout; + return this; + } + + /** + * Max time to wait when acquiring a connection from the Netty pool. + * Forwards to {@code NettyNioAsyncHttpClient.Builder.connectionAcquisitionTimeout}. + */ + public Builder connectionAcquireTimeout(Duration connectionAcquireTimeout) { + if (connectionAcquireTimeout == null + || connectionAcquireTimeout.isNegative() + || connectionAcquireTimeout.isZero()) { + throw new IllegalArgumentException("connectionAcquireTimeout must be positive"); + } + this.connectionAcquireTimeout = connectionAcquireTimeout; + return this; + } + + /** + * Max time the transport waits for the AWS SDK to subscribe to the + * bidirectional input publisher before failing the first send. + */ + public Builder subscriptionTimeout(Duration subscriptionTimeout) { + if (subscriptionTimeout == null + || subscriptionTimeout.isNegative() + || subscriptionTimeout.isZero()) { + throw new IllegalArgumentException("subscriptionTimeout must be positive"); + } + this.subscriptionTimeout = subscriptionTimeout; + return this; + } + + /** + * Max simultaneous in-flight HTTP/2 streams across the shared Netty pool. + * With {@code maxStreams=1} this equals the maximum number of concurrent + * bidirectional streams the factory can support. + */ + public Builder maxConcurrency(int maxConcurrency) { + if (maxConcurrency <= 0) { + throw new IllegalArgumentException("maxConcurrency must be positive"); + } + this.maxConcurrency = maxConcurrency; + return this; + } + public SageMakerConfig build() { if (endpointName == null || endpointName.isBlank()) { throw new IllegalArgumentException("endpointName is required"); diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index 4a3bf10..c7ab977 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -118,7 +118,8 @@ private void ensureConnected() { // Wait for the SDK to subscribe to our publisher before sending events try { - inputPublisher.awaitSubscription(30, TimeUnit.SECONDS); + inputPublisher.awaitSubscription( + config.subscriptionTimeout().toMillis(), TimeUnit.MILLISECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException("Interrupted waiting for stream subscription", e); diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index 710ff57..ac5517b 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -39,14 +39,6 @@ public class SageMakerTransportFactory implements DeepgramTransportFactory { private final SageMakerConfig config; private final SageMakerRuntimeHttp2AsyncClient smClient; - /** - * Default max concurrent HTTP/2 streams (in-flight requests) across the - * shared connection pool. With {@code maxStreams=1} each stream gets its - * own TCP connection, so this value equals the maximum number of - * simultaneous bidirectional streams the factory can support. - */ - private static final int DEFAULT_MAX_CONCURRENCY = 500; - public SageMakerTransportFactory(SageMakerConfig config) { this.config = config; this.smClient = SageMakerRuntimeHttp2AsyncClient.builder() @@ -54,7 +46,9 @@ public SageMakerTransportFactory(SageMakerConfig config) { .httpClientBuilder( NettyNioAsyncHttpClient.builder() .protocol(Protocol.HTTP2) - .maxConcurrency(DEFAULT_MAX_CONCURRENCY) + .maxConcurrency(config.maxConcurrency()) + .connectionTimeout(config.connectionTimeout()) + .connectionAcquisitionTimeout(config.connectionAcquireTimeout()) .http2Configuration( Http2Configuration.builder() .maxStreams(1L) diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java index 2a4e10a..56d2204 100755 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java @@ -4,6 +4,7 @@ import static org.mockito.Mockito.*; import com.deepgram.core.transport.DeepgramTransport; +import java.time.Duration; import java.util.Map; import software.amazon.awssdk.regions.Region; @@ -62,4 +63,62 @@ void factoryAcceptsCustomRegion() { assertEquals(Region.EU_WEST_1, config.region()); } + + @Test + void configDefaultsAreLenientForHighConcurrency() { + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + + assertEquals(Duration.ofSeconds(30), config.connectionTimeout()); + assertEquals(Duration.ofSeconds(60), config.connectionAcquireTimeout()); + assertEquals(Duration.ofSeconds(60), config.subscriptionTimeout()); + assertEquals(500, config.maxConcurrency()); + } + + @Test + void configAcceptsCustomTimeoutsAndConcurrency() { + SageMakerConfig config = + SageMakerConfig.builder() + .endpointName("my-endpoint") + .connectionTimeout(Duration.ofSeconds(5)) + .connectionAcquireTimeout(Duration.ofSeconds(15)) + .subscriptionTimeout(Duration.ofSeconds(45)) + .maxConcurrency(1000) + .build(); + + assertEquals(Duration.ofSeconds(5), config.connectionTimeout()); + assertEquals(Duration.ofSeconds(15), config.connectionAcquireTimeout()); + assertEquals(Duration.ofSeconds(45), config.subscriptionTimeout()); + assertEquals(1000, config.maxConcurrency()); + } + + @Test + void configRejectsNonPositiveTimeouts() { + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").connectionTimeout(Duration.ZERO).build()); + assertThrows( + IllegalArgumentException.class, + () -> + SageMakerConfig.builder() + .endpointName("e") + .connectionAcquireTimeout(Duration.ofSeconds(-1)) + .build()); + assertThrows( + IllegalArgumentException.class, + () -> + SageMakerConfig.builder() + .endpointName("e") + .subscriptionTimeout(null) + .build()); + } + + @Test + void configRejectsNonPositiveMaxConcurrency() { + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxConcurrency(0).build()); + assertThrows( + IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxConcurrency(-1).build()); + } } From 93b472f3cdf92ab6d14d4d0d3a209f4b804dc5b8 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Wed, 29 Apr 2026 15:00:52 -0700 Subject: [PATCH 02/10] feat: absorb AWS storm internally with retry/classification + persistent pending queue MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bundles the four storm-handling asks from @lukeocodes' review of #14 into the existing PR. The plugin now owns end-to-end retry/backoff/classification for transient AWS failures so they never reach transport.onError(...) and the SDK's wrapper-level reconnect can be disabled via the new DeepgramTransportFactory.reconnectOptions() hook (paired with the SDK fix in deepgram/deepgram-java-sdk#45). SageMakerConfig — new retry knobs (defaults tuned for high-burst workloads): - maxRetries = 5 (set 0 to disable internal retry) - initialBackoff = 100 ms - maxBackoff = 5 s - backoffMultiplier = 2.0 (validated >= 1.0) - retryBudget = 30 s wall-clock cap across all retries - build() rejects initialBackoff > maxBackoff SageMakerTransport — internal storm absorption: - ensureConnected() wraps attemptConnect() in a budgeted retry loop with exponential backoff. Successful subscription resets the budget. - handleStreamError(): the response-handler onError gate from line 100 is now a classify-and-decide step. RETRYABLE + budget left triggers an internal reset (cancel future, complete publisher, mark disconnected) so the next send re-enters the loop on a fresh stream. TERMINAL or budget-exhausted surfaces to errorListeners as before. - classify(Throwable) walks the cause chain. Retryable: TimeoutException, ConnectException, IOException, AwsServiceException with status 429 or 5xx or error code containing "throttl", and SdkException whose message contains "acquire"/"pool"/"throttl"/"timeout". Terminal: 4xx other than 429, anything unmatched (default-deny for retry). - StreamPublisher.pending is hoisted up to a SageMakerTransport instance field so events queued during an internal reset survive across the publisher swap. The new publisher drains the shared queue on subscribe. - awaitSubscription now returns the boolean. Timeout throws TimeoutException, which flows into the retry loop as RETRYABLE — the new subscriptionTimeout knob now actually fails fast as advertised. SageMakerTransportFactory: - Overrides the new reconnectOptions() default to return ReconnectOptions.builder().maxRetries(0).build(). Combined with the SDK's auto-wiring in TransportWebSocketFactory, this disables the SDK's wrapper-level reconnect for any per-resource WebSocket client constructed against this factory — no user wiring required. build.gradle: - Bumps deepgram-java-sdk dep from 0.2.1 to 0.3.0 to pick up the new DeepgramTransportFactory.reconnectOptions() default method, ReconnectOptions.connectionTimeoutMs, the maxRetries(0) bug fix, and the applyOptionsOverride hook. Tests: - New SageMakerTransportRetryTest covers classify() across all branches (Timeout/Connect/IOException; AWS 401/403/429/5xx/throttling code; SdkException pool-keyword; cause-chain walking; default-terminal) and StreamPublisher behaviour (awaitSubscription false-on-timeout, true-on-subscribe; pending-queue persistence across publisher instances; immediate forward when subscriber present). - SageMakerTransportFactoryTest gains: retry-config defaults + customisation + validation, initialBackoff > maxBackoff rejection, and factoryDeclaresMaxRetriesZeroForReconnectOptions verifying the storm-suppression policy is published correctly. README: - Bumps the SDK dep version reference and notes the v0.3.0+ requirement. - Adds the new retry knobs to the configuration table. - New "Retry & storm absorption" section explaining the internal classification/budget design and how factory.reconnectOptions() auto-wires the SDK's reconnect-disable. End-to-end mock-AWS retry coverage isn't included here — the AWS reactive streams handler indirection makes deterministic stubbing fragile. Those paths are exercised by the existing burst test described in the PR. Co-Authored-By: Claude Opus 4.7 (1M context) --- README.md | 41 ++- sagemaker-transport/build.gradle | 2 +- .../deepgram/sagemaker/SageMakerConfig.java | 89 +++++++ .../sagemaker/SageMakerTransport.java | 244 ++++++++++++++---- .../sagemaker/SageMakerTransportFactory.java | 14 + .../SageMakerTransportFactoryTest.java | 68 +++++ .../SageMakerTransportRetryTest.java | 231 +++++++++++++++++ 7 files changed, 636 insertions(+), 53 deletions(-) create mode 100644 sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java diff --git a/README.md b/README.md index 004287b..f7c0e52 100755 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ SageMaker transport for the [Deepgram Java SDK](https://github.com/deepgram/deep ```groovy dependencies { - implementation 'com.deepgram:deepgram-java-sdk:0.2.1' + implementation 'com.deepgram:deepgram-java-sdk:0.3.0' implementation 'com.deepgram:deepgram-sagemaker:0.1.2' // x-release-please-version } ``` @@ -29,7 +29,7 @@ dependencies { ## Requirements - Java 11+ -- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.2.1+ +- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.3.0+ (the `default ReconnectOptions reconnectOptions()` hook on `DeepgramTransportFactory` is required for storm absorption) - AWS credentials configured (environment variables, shared credentials file, or IAM role) - A Deepgram model deployed to an AWS SageMaker endpoint @@ -97,8 +97,13 @@ The transport is transparent — the SDK API is identical whether using Deepgram | `region` | No | `us-west-2` | AWS region | | `connectionTimeout` | No | `30s` | Max time for the underlying TCP/TLS connect (AWS Netty default is 2 s — bumped here so cold-start endpoints under burst load have time to accept TLS handshakes). | | `connectionAcquireTimeout` | No | `60s` | Max time to acquire a connection from the Netty pool (AWS Netty default is 10 s — bumped so a 200–500-stream burst doesn't drain the acquire pool). | -| `subscriptionTimeout` | No | `60s` | Max time the transport waits for the AWS SDK to subscribe to the bidi-stream input publisher before failing the first send. | +| `subscriptionTimeout` | No | `60s` | Max time the transport waits for the AWS SDK to subscribe to the bidi-stream input publisher before failing. A timeout here is treated as a transient connect failure and counts against `maxRetries` / `retryBudget`. | | `maxConcurrency` | No | `500` | Max simultaneous in-flight HTTP/2 streams across the shared Netty pool. With `maxStreams=1` this is the cap on simultaneous bidirectional streams. | +| `maxRetries` | No | `5` | Max retries on transient AWS errors (throttling, pool-exhausted, transient connect/timeout). Set to `0` to disable internal retry. Terminal errors (auth, validation) bypass this. | +| `initialBackoff` | No | `100ms` | First backoff delay applied after the initial failure. | +| `maxBackoff` | No | `5s` | Cap on the per-attempt backoff delay regardless of multiplier. | +| `backoffMultiplier` | No | `2.0` | Exponential growth factor between retry attempts. Must be `>= 1.0`. | +| `retryBudget` | No | `30s` | Total wall-clock cap across all retry attempts before giving up and surfacing the error to listeners. | ```java SageMakerConfig config = SageMakerConfig.builder() @@ -130,6 +135,36 @@ SageMakerConfig config = SageMakerConfig.builder() .build(); ``` +#### Retry & storm absorption + +Transient AWS-side failures (`ThrottlingException`, connection-pool exhaustion, transient +connect/timeout failures) are absorbed by the transport itself: classified as retryable, retried +with exponential backoff up to `maxRetries` and `retryBudget`, with messages enqueued during the +reset window persisted across the reconnect so audio isn't dropped. Only **terminal** errors (auth, +validation) and budget-exhausted retryable errors propagate to `transport.onError(...)` and reach +the application's error handler. + +This means the SDK's wrapper-level reconnect (`ReconnectingWebSocketListener`) would compound the +plugin's internal retries into a Throttling-on-Throttling storm under burst load, so the plugin +declares `ReconnectOptions.builder().maxRetries(0).build()` via the +`DeepgramTransportFactory.reconnectOptions()` hook. The SDK applies it automatically when it sees +a `transportFactory` in use; no user wiring required. + +To tune retry behavior: + +```java +SageMakerConfig config = SageMakerConfig.builder() + .endpointName("my-deepgram-endpoint") + .maxRetries(10) + .initialBackoff(Duration.ofMillis(200)) + .maxBackoff(Duration.ofSeconds(10)) + .retryBudget(Duration.ofMinutes(1)) + .build(); +``` + +Set `maxRetries(0)` to disable internal retry entirely (every transient AWS error then surfaces +immediately to the application). + ### Custom AWS Client For custom credential providers, proxy configuration, or testing: diff --git a/sagemaker-transport/build.gradle b/sagemaker-transport/build.gradle index e1e0a0a..544aa30 100755 --- a/sagemaker-transport/build.gradle +++ b/sagemaker-transport/build.gradle @@ -1,6 +1,6 @@ dependencies { // Deepgram Java SDK — provides DeepgramTransport / DeepgramTransportFactory interfaces - api 'com.deepgram:deepgram-java-sdk:0.2.1' + api 'com.deepgram:deepgram-java-sdk:0.3.0' // AWS SDK v2 — SageMaker Runtime HTTP/2 bidirectional streaming api platform('software.amazon.awssdk:bom:2.42.0') diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java index ee5b472..3d9fe6c 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java @@ -33,6 +33,21 @@ public class SageMakerConfig { */ public static final int DEFAULT_MAX_CONCURRENCY = 500; + /** Max retries on transient AWS errors (throttling, pool-exhausted, transient connect) per stream. */ + public static final int DEFAULT_MAX_RETRIES = 5; + + /** First backoff delay after the initial failure. */ + public static final Duration DEFAULT_INITIAL_BACKOFF = Duration.ofMillis(100); + + /** Cap on the per-attempt backoff delay regardless of multiplier. */ + public static final Duration DEFAULT_MAX_BACKOFF = Duration.ofSeconds(5); + + /** Exponential growth factor applied between retry attempts. */ + public static final double DEFAULT_BACKOFF_MULTIPLIER = 2.0; + + /** Total wall-clock budget across all retry attempts before giving up and surfacing the error. */ + public static final Duration DEFAULT_RETRY_BUDGET = Duration.ofSeconds(30); + private final String endpointName; private final Region region; private final String contentType; @@ -41,6 +56,11 @@ public class SageMakerConfig { private final Duration connectionAcquireTimeout; private final Duration subscriptionTimeout; private final int maxConcurrency; + private final int maxRetries; + private final Duration initialBackoff; + private final Duration maxBackoff; + private final double backoffMultiplier; + private final Duration retryBudget; private SageMakerConfig(Builder builder) { this.endpointName = builder.endpointName; @@ -51,6 +71,11 @@ private SageMakerConfig(Builder builder) { this.connectionAcquireTimeout = builder.connectionAcquireTimeout; this.subscriptionTimeout = builder.subscriptionTimeout; this.maxConcurrency = builder.maxConcurrency; + this.maxRetries = builder.maxRetries; + this.initialBackoff = builder.initialBackoff; + this.maxBackoff = builder.maxBackoff; + this.backoffMultiplier = builder.backoffMultiplier; + this.retryBudget = builder.retryBudget; } public String endpointName() { return endpointName; } @@ -61,6 +86,11 @@ private SageMakerConfig(Builder builder) { public Duration connectionAcquireTimeout() { return connectionAcquireTimeout; } public Duration subscriptionTimeout() { return subscriptionTimeout; } public int maxConcurrency() { return maxConcurrency; } + public int maxRetries() { return maxRetries; } + public Duration initialBackoff() { return initialBackoff; } + public Duration maxBackoff() { return maxBackoff; } + public double backoffMultiplier() { return backoffMultiplier; } + public Duration retryBudget() { return retryBudget; } public static Builder builder() { return new Builder(); @@ -75,6 +105,11 @@ public static class Builder { private Duration connectionAcquireTimeout = DEFAULT_CONNECTION_ACQUIRE_TIMEOUT; private Duration subscriptionTimeout = DEFAULT_SUBSCRIPTION_TIMEOUT; private int maxConcurrency = DEFAULT_MAX_CONCURRENCY; + private int maxRetries = DEFAULT_MAX_RETRIES; + private Duration initialBackoff = DEFAULT_INITIAL_BACKOFF; + private Duration maxBackoff = DEFAULT_MAX_BACKOFF; + private double backoffMultiplier = DEFAULT_BACKOFF_MULTIPLIER; + private Duration retryBudget = DEFAULT_RETRY_BUDGET; public Builder endpointName(String endpointName) { this.endpointName = endpointName; @@ -154,10 +189,64 @@ public Builder maxConcurrency(int maxConcurrency) { return this; } + /** + * Max retries on transient AWS errors per stream invocation. Set to {@code 0} to disable + * internal retry. Transient errors include throttling, connection-pool exhaustion, and + * transient connect/timeout failures; terminal errors (auth, validation) bypass this and + * surface to the application immediately. + */ + public Builder maxRetries(int maxRetries) { + if (maxRetries < 0) { + throw new IllegalArgumentException("maxRetries must be non-negative"); + } + this.maxRetries = maxRetries; + return this; + } + + /** First backoff delay applied after the initial failure. */ + public Builder initialBackoff(Duration initialBackoff) { + if (initialBackoff == null || initialBackoff.isNegative() || initialBackoff.isZero()) { + throw new IllegalArgumentException("initialBackoff must be positive"); + } + this.initialBackoff = initialBackoff; + return this; + } + + /** Cap on the per-attempt backoff delay regardless of multiplier. */ + public Builder maxBackoff(Duration maxBackoff) { + if (maxBackoff == null || maxBackoff.isNegative() || maxBackoff.isZero()) { + throw new IllegalArgumentException("maxBackoff must be positive"); + } + this.maxBackoff = maxBackoff; + return this; + } + + /** Exponential growth factor applied between retry attempts. Must be {@code >= 1.0}. */ + public Builder backoffMultiplier(double backoffMultiplier) { + if (backoffMultiplier < 1.0) { + throw new IllegalArgumentException("backoffMultiplier must be >= 1.0"); + } + this.backoffMultiplier = backoffMultiplier; + return this; + } + + /** Total wall-clock budget across all retry attempts before giving up. */ + public Builder retryBudget(Duration retryBudget) { + if (retryBudget == null || retryBudget.isNegative() || retryBudget.isZero()) { + throw new IllegalArgumentException("retryBudget must be positive"); + } + this.retryBudget = retryBudget; + return this; + } + public SageMakerConfig build() { if (endpointName == null || endpointName.isBlank()) { throw new IllegalArgumentException("endpointName is required"); } + if (initialBackoff.compareTo(maxBackoff) > 0) { + throw new IllegalArgumentException("initialBackoff (" + initialBackoff + + ") must not exceed maxBackoff (" + maxBackoff + ")"); + } return new SageMakerConfig(this); } } diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index c7ab977..bb7b7cf 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -2,13 +2,17 @@ import com.deepgram.core.transport.DeepgramTransport; +import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.sagemakerruntimehttp2.SageMakerRuntimeHttp2AsyncClient; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.InvokeEndpointWithBidirectionalStreamRequest; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.InvokeEndpointWithBidirectionalStreamResponseHandler; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.RequestStreamEvent; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.ResponsePayloadPart; +import java.io.IOException; +import java.net.ConnectException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -16,7 +20,9 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import org.reactivestreams.Publisher; @@ -59,6 +65,14 @@ public class SageMakerTransport implements DeepgramTransport { private volatile CompletableFuture streamFuture; private final Object connectLock = new Object(); + // Hoisted out of StreamPublisher so messages queued during an internal reset survive into the + // next stream attempt instead of being dropped with the discarded publisher. + private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); + + // Retry budget tracking. Reset to 0 once a stream successfully establishes (subscription). + private final AtomicInteger retryAttempt = new AtomicInteger(0); + private volatile long retryWindowStart = 0L; + SageMakerTransport( SageMakerRuntimeHttp2AsyncClient smClient, SageMakerConfig config, @@ -71,62 +85,180 @@ public class SageMakerTransport implements DeepgramTransport { } /** - * Establish the bidirectional stream if not already connected. - * Blocks until the SDK has subscribed to the event publisher. + * Establish the bidirectional stream if not already connected. Blocks until the AWS SDK + * subscribes to the event publisher. + * + *

Internally retries with exponential backoff on transient AWS errors (throttling, + * connection-pool exhaustion, transient connect/timeout failures) bounded by + * {@link SageMakerConfig#maxRetries()} and {@link SageMakerConfig#retryBudget()}. Terminal + * errors (auth, validation) and budget exhaustion bubble out and surface to {@code errorListeners} + * via the caller's {@code send*} path. */ private void ensureConnected() { if (connected.get()) return; synchronized (connectLock) { if (connected.get()) return; - inputPublisher = new StreamPublisher(); + if (retryWindowStart == 0L) { + retryWindowStart = System.currentTimeMillis(); + } - InvokeEndpointWithBidirectionalStreamRequest.Builder requestBuilder = - InvokeEndpointWithBidirectionalStreamRequest.builder() - .endpointName(config.endpointName()) - .modelInvocationPath(invocationPath); - if (queryString != null && !queryString.isEmpty()) { - requestBuilder.modelQueryString(queryString); + Throwable lastError = null; + while (true) { + try { + attemptConnect(); + // Success: reset retry budget for any future internal reconnects on this transport. + retryAttempt.set(0); + retryWindowStart = 0L; + connected.set(true); + return; + } catch (Throwable t) { + lastError = t; + Classification c = classify(t); + int attempt = retryAttempt.get(); + long elapsed = System.currentTimeMillis() - retryWindowStart; + boolean budgetLeft = attempt < config.maxRetries() + && elapsed < config.retryBudget().toMillis(); + if (c == Classification.TERMINAL || !budgetLeft) { + if (t instanceof RuntimeException) throw (RuntimeException) t; + throw new RuntimeException(t); + } + long backoff = computeBackoff(attempt); + retryAttempt.incrementAndGet(); + try { + Thread.sleep(backoff); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException( + "Interrupted during retry backoff after " + (attempt + 1) + " attempts", lastError); + } + } } - InvokeEndpointWithBidirectionalStreamRequest request = requestBuilder.build(); - - InvokeEndpointWithBidirectionalStreamResponseHandler handler = - InvokeEndpointWithBidirectionalStreamResponseHandler.builder() - .onResponse(response -> { }) - .subscriber(InvokeEndpointWithBidirectionalStreamResponseHandler - .Visitor.builder() - .onPayloadPart(this::handlePayloadPart) - .build()) - .onError(error -> { - if (closeSent.get()) { - // Model idle timeout after CloseStream — treat as normal close. - inputPublisher.complete(); - notifyClose(1000, "Normal"); - } else { - for (Consumer l : errorListeners) { - l.accept(error); - } - } - }) - .onComplete(() -> { - notifyClose(1000, "Normal"); - }) - .build(); - - streamFuture = smClient.invokeEndpointWithBidirectionalStream( - request, inputPublisher, handler); - - // Wait for the SDK to subscribe to our publisher before sending events + } + } + + /** Single connect attempt — invokes the bidi stream and waits for subscription. */ + private void attemptConnect() throws TimeoutException, InterruptedException { + StreamPublisher publisher = new StreamPublisher(pending); + inputPublisher = publisher; + + InvokeEndpointWithBidirectionalStreamRequest.Builder requestBuilder = + InvokeEndpointWithBidirectionalStreamRequest.builder() + .endpointName(config.endpointName()) + .modelInvocationPath(invocationPath); + if (queryString != null && !queryString.isEmpty()) { + requestBuilder.modelQueryString(queryString); + } + InvokeEndpointWithBidirectionalStreamRequest request = requestBuilder.build(); + + InvokeEndpointWithBidirectionalStreamResponseHandler handler = + InvokeEndpointWithBidirectionalStreamResponseHandler.builder() + .onResponse(response -> { }) + .subscriber(InvokeEndpointWithBidirectionalStreamResponseHandler + .Visitor.builder() + .onPayloadPart(this::handlePayloadPart) + .build()) + .onError(this::handleStreamError) + .onComplete(() -> notifyClose(1000, "Normal")) + .build(); + + streamFuture = smClient.invokeEndpointWithBidirectionalStream(request, publisher, handler); + + if (!publisher.awaitSubscription(config.subscriptionTimeout().toMillis(), TimeUnit.MILLISECONDS)) { + // Subscription never landed — treat as a transient connect failure so the retry loop + // classifies and (if budget allows) tries again on a fresh stream. try { - inputPublisher.awaitSubscription( - config.subscriptionTimeout().toMillis(), TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted waiting for stream subscription", e); + streamFuture.cancel(true); + } catch (Throwable ignored) { + // best-effort + } + throw new TimeoutException( + "Timed out waiting for AWS SDK to subscribe to stream publisher after " + + config.subscriptionTimeout()); + } + } + + /** + * Async error gate: AWS SDK reports stream errors via the response handler's onError. Classify + * here so transient AWS errors trigger an internal reset (next send re-enters ensureConnected + * via the retry loop) instead of bubbling straight to {@code errorListeners}. + */ + private void handleStreamError(Throwable error) { + if (closeSent.get()) { + // Model idle timeout after CloseStream — treat as normal close. + if (inputPublisher != null) inputPublisher.complete(); + notifyClose(1000, "Normal"); + return; + } + + Classification c = classify(error); + int attempt = retryAttempt.get(); + long windowStart = retryWindowStart; + long elapsed = windowStart == 0L ? 0L : System.currentTimeMillis() - windowStart; + boolean budgetLeft = attempt < config.maxRetries() + && elapsed < config.retryBudget().toMillis(); + + if (c == Classification.RETRYABLE && budgetLeft) { + // Internal reset: drop current stream, mark disconnected. Next send re-enters + // ensureConnected → attemptConnect, which will drain `pending` into the new stream. + connected.set(false); + if (inputPublisher != null) inputPublisher.complete(); + if (streamFuture != null) { + try { + streamFuture.cancel(true); + } catch (Throwable ignored) { + // best-effort + } } + return; + } + + // Terminal or budget-exhausted: surface to listeners. + for (Consumer l : errorListeners) { + l.accept(error); + } + } + + private long computeBackoff(int attempt) { + long initial = config.initialBackoff().toMillis(); + long max = config.maxBackoff().toMillis(); + double scaled = initial * Math.pow(config.backoffMultiplier(), attempt); + if (scaled > max || Double.isInfinite(scaled)) { + return max; + } + return Math.max(initial, (long) scaled); + } - connected.set(true); + enum Classification { RETRYABLE, TERMINAL } + + /** + * Classify an AWS-side exception as transient (retry internally, don't surface) vs terminal + * (surface to {@code errorListeners}). Walks the cause chain so SDK-wrapped exceptions are + * inspected too. + */ + static Classification classify(Throwable error) { + for (Throwable t = error; t != null; t = t.getCause()) { + if (t instanceof TimeoutException) return Classification.RETRYABLE; + if (t instanceof ConnectException) return Classification.RETRYABLE; + if (t instanceof IOException) return Classification.RETRYABLE; + if (t instanceof AwsServiceException) { + AwsServiceException ase = (AwsServiceException) t; + int status = ase.statusCode(); + if (status == 429 || (status >= 500 && status < 600)) return Classification.RETRYABLE; + String code = ase.awsErrorDetails() != null ? ase.awsErrorDetails().errorCode() : null; + if (code != null && code.toLowerCase().contains("throttl")) return Classification.RETRYABLE; + return Classification.TERMINAL; + } + if (t instanceof SdkException) { + String msg = t.getMessage() == null ? "" : t.getMessage().toLowerCase(); + if (msg.contains("acquire") || msg.contains("pool") || msg.contains("throttl") + || msg.contains("timeout")) { + return Classification.RETRYABLE; + } + } + if (t == t.getCause()) break; } + return Classification.TERMINAL; } @Override @@ -224,6 +356,8 @@ public boolean isOpen() { @Override public void close() { if (!open.compareAndSet(true, false)) return; + // Terminal close — drop any messages that were queued during a reset window. + pending.clear(); if (inputPublisher != null) { inputPublisher.complete(); } @@ -235,13 +369,21 @@ public void close() { /** * Reactive Streams publisher that buffers events until the SDK subscribes, * then delivers them in order. After subscription, events are forwarded immediately. + * + *

The {@code pending} queue is owned by the enclosing {@link SageMakerTransport} and + * shared across reconnect cycles, so events queued during an internal reset are drained + * onto whichever stream subscribes next. */ static class StreamPublisher implements Publisher { private volatile Subscriber subscriber; private final AtomicBoolean completed = new AtomicBoolean(false); - private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue pending; private final CountDownLatch subscribed = new CountDownLatch(1); + StreamPublisher(ConcurrentLinkedQueue sharedPending) { + this.pending = sharedPending; + } + @Override public void subscribe(Subscriber s) { this.subscriber = s; @@ -256,7 +398,8 @@ public void cancel() { completed.set(true); } }); - // Flush any events that were queued before subscription + // Flush any events that were queued before subscription (including events that + // survived a previous internal reset). RequestStreamEvent event; while ((event = pending.poll()) != null) { s.onNext(event); @@ -270,7 +413,7 @@ void send(RequestStreamEvent event) { if (s != null) { s.onNext(event); } else { - // Buffer until the SDK subscribes + // Buffer until the SDK subscribes (this stream or the next one after a reset) pending.add(event); } } @@ -284,8 +427,11 @@ void complete() { } } - void awaitSubscription(long timeout, TimeUnit unit) throws InterruptedException { - subscribed.await(timeout, unit); + /** + * @return {@code true} if subscription happened within the timeout, {@code false} on timeout. + */ + boolean awaitSubscription(long timeout, TimeUnit unit) throws InterruptedException { + return subscribed.await(timeout, unit); } } } diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index ac5517b..0a7ef14 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -1,5 +1,6 @@ package com.deepgram.sagemaker; +import com.deepgram.core.ReconnectingWebSocketListener; import com.deepgram.core.transport.DeepgramTransport; import com.deepgram.core.transport.DeepgramTransportFactory; @@ -80,6 +81,19 @@ public DeepgramTransport create(String url, Map headers) { return new SageMakerTransport(smClient, config, invocationPath, queryString); } + /** + * Disable the SDK's wrapper-level reconnect loop. {@link SageMakerTransport} owns its own + * retry/backoff/classification (see {@link SageMakerConfig#maxRetries()}, + * {@link SageMakerConfig#retryBudget()}); wrapping it in another retry layer compounds + * transient AWS errors into Throttling-on-Throttling storms under burst load. + */ + @Override + public ReconnectingWebSocketListener.ReconnectOptions reconnectOptions() { + return ReconnectingWebSocketListener.ReconnectOptions.builder() + .maxRetries(0) + .build(); + } + /** Shut down the underlying AWS SDK client. */ public void shutdown() { smClient.close(); diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java index 56d2204..4ea9ec5 100755 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.*; import static org.mockito.Mockito.*; +import com.deepgram.core.ReconnectingWebSocketListener; import com.deepgram.core.transport.DeepgramTransport; import java.time.Duration; import java.util.Map; @@ -121,4 +122,71 @@ void configRejectsNonPositiveMaxConcurrency() { IllegalArgumentException.class, () -> SageMakerConfig.builder().endpointName("e").maxConcurrency(-1).build()); } + + @Test + void retryConfigDefaults() { + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + + assertEquals(5, config.maxRetries()); + assertEquals(Duration.ofMillis(100), config.initialBackoff()); + assertEquals(Duration.ofSeconds(5), config.maxBackoff()); + assertEquals(2.0, config.backoffMultiplier()); + assertEquals(Duration.ofSeconds(30), config.retryBudget()); + } + + @Test + void retryConfigAcceptsCustomValues() { + SageMakerConfig config = + SageMakerConfig.builder() + .endpointName("my-endpoint") + .maxRetries(10) + .initialBackoff(Duration.ofMillis(50)) + .maxBackoff(Duration.ofSeconds(20)) + .backoffMultiplier(3.0) + .retryBudget(Duration.ofMinutes(2)) + .build(); + + assertEquals(10, config.maxRetries()); + assertEquals(Duration.ofMillis(50), config.initialBackoff()); + assertEquals(Duration.ofSeconds(20), config.maxBackoff()); + assertEquals(3.0, config.backoffMultiplier()); + assertEquals(Duration.ofMinutes(2), config.retryBudget()); + } + + @Test + void retryConfigValidates() { + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxRetries(-1).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").initialBackoff(Duration.ZERO).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").maxBackoff(Duration.ofSeconds(-1)).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").backoffMultiplier(0.5).build()); + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder().endpointName("e").retryBudget(null).build()); + } + + @Test + void retryConfigRejectsInitialGreaterThanMax() { + assertThrows(IllegalArgumentException.class, + () -> SageMakerConfig.builder() + .endpointName("e") + .initialBackoff(Duration.ofSeconds(10)) + .maxBackoff(Duration.ofSeconds(5)) + .build()); + } + + @Test + void factoryDeclaresMaxRetriesZeroForReconnectOptions() { + // The plugin declares maxRetries(0) so the SDK's wrapper-level reconnect doesn't compound + // SageMaker's internal retries into a storm. + SageMakerConfig config = SageMakerConfig.builder().endpointName("my-endpoint").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + SageMakerTransportFactory factory = new SageMakerTransportFactory(config, mockClient); + + ReconnectingWebSocketListener.ReconnectOptions opts = factory.reconnectOptions(); + assertNotNull(opts); + assertEquals(0, opts.maxRetries); + } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java new file mode 100644 index 0000000..52ec77e --- /dev/null +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -0,0 +1,231 @@ +package com.deepgram.sagemaker; + +import static org.junit.jupiter.api.Assertions.*; + +import com.deepgram.sagemaker.SageMakerTransport.StreamPublisher; +import java.io.IOException; +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.services.sagemakerruntimehttp2.model.RequestStreamEvent; + +/** + * Unit tests for {@link SageMakerTransport}'s retry/classification logic and the hoisted pending-queue + * machinery. End-to-end retry against the AWS reactive-streams handler isn't covered here (the + * handler indirection makes it hard to deterministically stub); those paths are exercised by the + * burst test in the README. + */ +class SageMakerTransportRetryTest { + + // Helpers live on the outer class because @Nested inner classes can't have static members on Java 11. + + private static RequestStreamEvent payloadEvent(String s) { + return RequestStreamEvent.payloadPartBuilder() + .bytes(SdkBytes.fromUtf8String(s)) + .build(); + } + + private static Subscriber noopSubscriber() { + return new Subscriber() { + @Override + public void onSubscribe(Subscription s) {} + @Override + public void onNext(RequestStreamEvent e) {} + @Override + public void onError(Throwable t) {} + @Override + public void onComplete() {} + }; + } + + private static class CapturingSubscriber implements Subscriber { + final List received = new ArrayList<>(); + final CountDownLatch completed = new CountDownLatch(1); + + @Override + public void onSubscribe(Subscription s) {} + @Override + public void onNext(RequestStreamEvent e) { + received.add(e); + } + @Override + public void onError(Throwable t) {} + @Override + public void onComplete() { + completed.countDown(); + } + } + + @Nested + @DisplayName("classify(Throwable)") + class ClassifyTests { + @Test + @DisplayName("TimeoutException is retryable") + void timeoutIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new TimeoutException("acquire timeout"))); + } + + @Test + @DisplayName("ConnectException is retryable") + void connectExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new ConnectException("connection refused"))); + } + + @Test + @DisplayName("IOException is retryable") + void ioExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new IOException("network error"))); + } + + @Test + @DisplayName("AWS 429 (Too Many Requests) is retryable") + void aws429IsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("Rate exceeded") + .statusCode(429) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 5xx is retryable") + void aws5xxIsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("internal") + .statusCode(503) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS error code containing 'throttl' is retryable regardless of status") + void awsThrottlingErrorCodeIsRetryable() { + AwsServiceException ase = AwsServiceException.builder() + .message("Rate exceeded") + .statusCode(400) + .awsErrorDetails(AwsErrorDetails.builder() + .errorCode("ThrottlingException") + .build()) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 401 (Unauthorized) is terminal") + void aws401IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("Forbidden") + .statusCode(401) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 403 (Forbidden) is terminal") + void aws403IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("Forbidden") + .statusCode(403) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("SdkException with 'pool exhausted' message is retryable (defensive belt)") + void sdkExceptionWithPoolKeywordIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Connection pool exhausted") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("Walks the cause chain — IOException wrapped in RuntimeException is retryable") + void walksCauseChain() { + RuntimeException wrapper = new RuntimeException("oops", new IOException("netty")); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(wrapper)); + } + + @Test + @DisplayName("Unknown exception defaults to terminal") + void unknownDefaultsToTerminal() { + assertEquals(SageMakerTransport.Classification.TERMINAL, + SageMakerTransport.classify(new RuntimeException("mystery"))); + } + } + + @Nested + @DisplayName("StreamPublisher") + class StreamPublisherTests { + @Test + @DisplayName("awaitSubscription returns false on timeout") + void awaitSubscriptionTimeout() throws InterruptedException { + ConcurrentLinkedQueue q = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(q); + assertFalse(pub.awaitSubscription(50, TimeUnit.MILLISECONDS), + "no subscriber within timeout — must report false so callers can fail fast"); + } + + @Test + @DisplayName("awaitSubscription returns true once a subscriber arrives") + void awaitSubscriptionSuccess() throws InterruptedException { + ConcurrentLinkedQueue q = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(q); + pub.subscribe(noopSubscriber()); + assertTrue(pub.awaitSubscription(100, TimeUnit.MILLISECONDS)); + } + + @Test + @DisplayName("Pending queue is shared across publisher instances — surviving an internal reset") + void pendingPersistsAcrossPublishers() { + ConcurrentLinkedQueue shared = new ConcurrentLinkedQueue<>(); + + // First publisher: send 3 events before any subscriber arrives, then complete (simulating reset). + StreamPublisher first = new StreamPublisher(shared); + first.send(payloadEvent("a")); + first.send(payloadEvent("b")); + first.send(payloadEvent("c")); + first.complete(); + + assertEquals(3, shared.size(), "events must remain in the shared queue across publishers"); + + // Second publisher: subscribes and drains the queued events. + StreamPublisher second = new StreamPublisher(shared); + CapturingSubscriber sub = new CapturingSubscriber(); + second.subscribe(sub); + + assertEquals(3, sub.received.size(), "events queued before reset must arrive on the new stream"); + assertTrue(shared.isEmpty(), "queue must be drained after subscription"); + } + + @Test + @DisplayName("send() with subscriber present forwards immediately, no queueing") + void sendForwardsWhenSubscribed() { + ConcurrentLinkedQueue shared = new ConcurrentLinkedQueue<>(); + StreamPublisher pub = new StreamPublisher(shared); + CapturingSubscriber sub = new CapturingSubscriber(); + pub.subscribe(sub); + + pub.send(payloadEvent("hello")); + + assertEquals(1, sub.received.size()); + assertTrue(shared.isEmpty()); + } + } +} From 113309db1b0ea4d0e8bc7d84c5e0b3708fdcd757 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Fri, 1 May 2026 17:32:39 -0700 Subject: [PATCH 03/10] Manage throttling errors when handling payload parts rather than just connection establishment --- README.md | 29 ++++ .../sagemaker/SageMakerTransport.java | 127 +++++++++++++++++- .../sagemaker/SageMakerTransportFactory.java | 94 +++++++++++-- .../SageMakerTransportFactoryTest.java | 124 +++++++++++++++++ .../SageMakerTransportRetryTest.java | 62 +++++++++ 5 files changed, 418 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index f7c0e52..95e1126 100755 --- a/README.md +++ b/README.md @@ -165,6 +165,35 @@ SageMakerConfig config = SageMakerConfig.builder() Set `maxRetries(0)` to disable internal retry entirely (every transient AWS error then surfaces immediately to the application). +#### Connection-pool sharing + +The default `new SageMakerTransportFactory(config)` constructor backs every factory instance with +a **process-wide shared** `SageMakerRuntimeHttp2AsyncClient`, keyed by the parts of +`SageMakerConfig` that affect the underlying Netty HTTP/2 client (region, max concurrency, +connect/acquire timeouts). Multiple factories built with the same config fingerprint reuse one +Netty event loop group and one connection pool — so naive code that constructs a fresh factory +per stream still gets a single, well-behaved client underneath. + +Without sharing, every factory instantiates its own Netty pool, and a burst of N factories +triggers N simultaneous TLS handshakes from N distinct Netty clients against the same SageMaker +endpoint. Under high concurrency (100+ streams) the SageMaker HTTP/2 frontline silently drops a +large fraction of those streams before they ever reach the model container — verified +end-to-end with CloudWatch logs from a 400-stream burst test against a 1× ml.g6.2xlarge endpoint: +without sharing, ~65% of streams never appeared in the Deepgram container's listen log; with +sharing, the burst behaves the same as the canonical Python load-test harness. + +Lifecycle: + +| Constructor | Client backing | `factory.shutdown()` | +|---|---|---| +| `SageMakerTransportFactory(config)` | shared (lazy-init, keyed by config fingerprint) | no-op — call `SageMakerTransportFactory.shutdownAllSharedClients()` once at app shutdown to release Netty resources | +| `SageMakerTransportFactory(config, smClient)` | caller-provided (BYO, used for testing or custom credential providers) | no-op — caller owns the client lifecycle | + +```java +// At app shutdown — releases all shared Netty pools the plugin lazily created. +Runtime.getRuntime().addShutdownHook(new Thread(SageMakerTransportFactory::shutdownAllSharedClients)); +``` + ### Custom AWS Client For custom credential providers, proxy configuration, or testing: diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index bb7b7cf..fc38cf7 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -16,6 +16,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; @@ -28,6 +29,8 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * {@link DeepgramTransport} implementation that routes Deepgram API calls through @@ -47,6 +50,9 @@ */ public class SageMakerTransport implements DeepgramTransport { + private static final Logger log = LoggerFactory.getLogger(SageMakerTransport.class); + private final String transportId = String.format("%08x", System.identityHashCode(this)); + private final SageMakerRuntimeHttp2AsyncClient smClient; private final SageMakerConfig config; private final String invocationPath; @@ -69,9 +75,15 @@ public class SageMakerTransport implements DeepgramTransport { // next stream attempt instead of being dropped with the discarded publisher. private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); - // Retry budget tracking. Reset to 0 once a stream successfully establishes (subscription). + // Retry budget tracking. Reset to 0 once real downstream data flows back to the application + // (handlePayloadPart) — NOT on subscription success. Subscription succeeds in TLS+HTTP/2 + // setup terms even when the bidi-stream request will be throttled milliseconds later. private final AtomicInteger retryAttempt = new AtomicInteger(0); private volatile long retryWindowStart = 0L; + // Earliest wall-clock at which the next attemptConnect is allowed to proceed. Set by + // handleStreamError so post-subscription throttles (which never reach the ensureConnected + // catch-block backoff) still pace the next attempt. + private volatile long retryNotBeforeMs = 0L; SageMakerTransport( SageMakerRuntimeHttp2AsyncClient smClient, @@ -105,11 +117,34 @@ private void ensureConnected() { Throwable lastError = null; while (true) { + // Honor backoff scheduled by handleStreamError so post-subscription throttles + // pace the next attempt (the catch-block backoff below never fires for those — + // attemptConnect succeeds, the throttle hits later). + long now = System.currentTimeMillis(); + if (retryNotBeforeMs > now) { + long sleepMs = retryNotBeforeMs - now; + log.info("[{}] ensureConnected: honoring scheduled backoff ({}ms) before next attemptConnect", + transportId, sleepMs); + try { + Thread.sleep(sleepMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during scheduled backoff", lastError); + } + } + int attemptBefore = retryAttempt.get(); + log.info("[{}] ensureConnected: starting attemptConnect (attempt={}/{}, elapsed={}ms/{}ms)", + transportId, attemptBefore, config.maxRetries(), + System.currentTimeMillis() - retryWindowStart, config.retryBudget().toMillis()); try { attemptConnect(); - // Success: reset retry budget for any future internal reconnects on this transport. - retryAttempt.set(0); - retryWindowStart = 0L; + // Subscription succeeded — but DO NOT reset retryAttempt here. Subscription + // success only proves TLS+HTTP/2 setup; the actual bidi-stream request can + // still be throttled. We reset retryAttempt only when real downstream data + // arrives, in handlePayloadPart. + log.info("[{}] ensureConnected: attemptConnect SUCCEEDED (attempt={} — counter NOT reset; " + + "waits for handlePayloadPart to confirm real data flow before resetting)", + transportId, attemptBefore); connected.set(true); return; } catch (Throwable t) { @@ -119,11 +154,18 @@ private void ensureConnected() { long elapsed = System.currentTimeMillis() - retryWindowStart; boolean budgetLeft = attempt < config.maxRetries() && elapsed < config.retryBudget().toMillis(); + log.info("[{}] ensureConnected: attemptConnect FAILED — class={} attempt={}/{} elapsed={}ms/{}ms budgetLeft={} err={}", + transportId, c, attempt, config.maxRetries(), elapsed, + config.retryBudget().toMillis(), budgetLeft, summarize(t)); if (c == Classification.TERMINAL || !budgetLeft) { + log.warn("[{}] ensureConnected: SURFACING (class={} budgetLeft={}) — err={}", + transportId, c, budgetLeft, summarize(t)); if (t instanceof RuntimeException) throw (RuntimeException) t; throw new RuntimeException(t); } long backoff = computeBackoff(attempt); + log.info("[{}] ensureConnected: backoff={}ms before retry attempt {}", + transportId, backoff, attempt + 1); retryAttempt.incrementAndGet(); try { Thread.sleep(backoff); @@ -186,6 +228,7 @@ private void attemptConnect() throws TimeoutException, InterruptedException { private void handleStreamError(Throwable error) { if (closeSent.get()) { // Model idle timeout after CloseStream — treat as normal close. + log.info("[{}] handleStreamError: closeSent=true → treating as normal close", transportId); if (inputPublisher != null) inputPublisher.complete(); notifyClose(1000, "Normal"); return; @@ -198,9 +241,27 @@ private void handleStreamError(Throwable error) { boolean budgetLeft = attempt < config.maxRetries() && elapsed < config.retryBudget().toMillis(); + log.info("[{}] handleStreamError: class={} attempt={}/{} elapsed={}ms/{}ms budgetLeft={} err={}", + transportId, c, attempt, config.maxRetries(), elapsed, + config.retryBudget().toMillis(), budgetLeft, summarize(error)); + if (c == Classification.RETRYABLE && budgetLeft) { - // Internal reset: drop current stream, mark disconnected. Next send re-enters - // ensureConnected → attemptConnect, which will drain `pending` into the new stream. + // Internal reset: drop current stream, mark disconnected, and SCHEDULE a backoff so + // the next ensureConnected pause-then-reconnect rather than immediately hammering + // the AWS frontline. We don't sleep here (this runs on a Netty event-loop thread); + // we just set retryNotBeforeMs and the next ensureConnected honors it. + // + // We also advance retryAttempt and start the budget window if not already started, + // so under repeated post-subscription throttles the retry budget actually gets + // consumed and eventually surfaces a terminal error to the application. + if (retryWindowStart == 0L) { + retryWindowStart = System.currentTimeMillis(); + } + int attemptForBackoff = retryAttempt.getAndIncrement(); + long backoff = computeBackoff(attemptForBackoff); + retryNotBeforeMs = System.currentTimeMillis() + backoff; + log.info("[{}] handleStreamError: RETRYABLE → internal reset, attempt {} → scheduled backoff {}ms", + transportId, attemptForBackoff + 1, backoff); connected.set(false); if (inputPublisher != null) inputPublisher.complete(); if (streamFuture != null) { @@ -214,11 +275,31 @@ private void handleStreamError(Throwable error) { } // Terminal or budget-exhausted: surface to listeners. + log.warn("[{}] handleStreamError: SURFACING (class={} budgetLeft={}) → invoking {} errorListener(s)", + transportId, c, budgetLeft, errorListeners.size()); for (Consumer l : errorListeners) { l.accept(error); } } + /** One-line summary of a Throwable for log lines. */ + private static String summarize(Throwable t) { + if (t == null) return "null"; + String msg = t.getMessage(); + if (msg == null) msg = ""; + if (msg.length() > 160) msg = msg.substring(0, 157) + "..."; + StringBuilder sb = new StringBuilder(t.getClass().getSimpleName()).append(": ").append(msg); + if (t instanceof AwsServiceException) { + AwsServiceException ase = (AwsServiceException) t; + sb.append(" [status=").append(ase.statusCode()); + if (ase.awsErrorDetails() != null && ase.awsErrorDetails().errorCode() != null) { + sb.append(" code=").append(ase.awsErrorDetails().errorCode()); + } + sb.append("]"); + } + return sb.toString(); + } + private long computeBackoff(int attempt) { long initial = config.initialBackoff().toMillis(); long max = config.maxBackoff().toMillis(); @@ -238,6 +319,14 @@ enum Classification { RETRYABLE, TERMINAL } */ static Classification classify(Throwable error) { for (Throwable t = error; t != null; t = t.getCause()) { + // CancellationException is RETRYABLE because the cancel was either (a) induced by our + // own retry-reset path (in which case the next attempt will run cleanly) or (b) caused + // by some upstream terminal condition (in which case the retry attempt will hit the + // underlying error and classify it as TERMINAL on its own). Either way, treating the + // cancel itself as TERMINAL would surface a self-inflicted error to listeners. + // Covers AWS Netty's FutureCancelledException too — it wraps CancellationException as + // its cause. + if (t instanceof CancellationException) return Classification.RETRYABLE; if (t instanceof TimeoutException) return Classification.RETRYABLE; if (t instanceof ConnectException) return Classification.RETRYABLE; if (t instanceof IOException) return Classification.RETRYABLE; @@ -255,6 +344,19 @@ static Classification classify(Throwable error) { || msg.contains("timeout")) { return Classification.RETRYABLE; } + // Credential-loading failures from the AWS SDK provider chain + // (`SdkClientException: Unable to load credentials from any of the providers...`). + // Retry only when at least one provider hit a transient AWS-side condition — + // typically AWS IAM Identity Center (SSO) or STS rate-limiting credential + // refreshes under burst load (Status Code: 429), or a 5xx from a credential + // backend. Pure misconfig (no provider has credentials at all) still surfaces + // fast — retrying won't conjure credentials that don't exist. + if (msg.contains("unable to load credentials") + && (msg.contains("status code: 429") + || msg.contains("status code: 5") + || msg.contains("rate exceeded"))) { + return Classification.RETRYABLE; + } } if (t == t.getCause()) break; } @@ -308,6 +410,19 @@ private void notifyClose(int code, String reason) { private void handlePayloadPart(ResponsePayloadPart part) { byte[] bytes = part.bytes().asByteArray(); + // First downstream data on this transport (or first since the last retry-loop reset) → + // the stream is genuinely working end-to-end, not just subscription-established. Reset + // the retry budget so a future transient failure gets a fresh budget. + if (retryAttempt.get() != 0 || retryWindowStart != 0L || retryNotBeforeMs != 0L) { + log.info("[{}] handlePayloadPart: data received ({}B) → resetting retry counters " + + "(was attempt={}, windowStart={}, notBeforeMs={})", + transportId, bytes.length, + retryAttempt.get(), retryWindowStart, retryNotBeforeMs); + retryAttempt.set(0); + retryWindowStart = 0L; + retryNotBeforeMs = 0L; + } + // JSON messages start with '{"' (0x7B 0x22). Checking two bytes avoids // false positives from binary audio chunks that happen to start with 0x7B. if (bytes.length > 1 && bytes[0] == '{' && bytes[1] == '"') { diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index 0a7ef14..ae70e5b 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -11,6 +11,7 @@ import java.net.URI; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; /** * Factory that creates SageMaker bidirectional streaming transports. @@ -34,15 +35,73 @@ * .transportFactory(factory) * .build(); * } + * + *

Connection-pool sharing

+ * + *

The default constructor backs the factory with a process-wide shared + * {@link SageMakerRuntimeHttp2AsyncClient} keyed by the parts of {@link SageMakerConfig} that + * affect the underlying Netty HTTP/2 client (region, max concurrency, connect/acquire timeouts). + * Multiple factories built with the same config fingerprint reuse one Netty event loop group and + * one connection pool — so naive code that constructs a fresh factory per stream still gets a + * single, well-behaved client underneath. + * + *

Without sharing, every factory instantiates its own Netty pool, and a burst of N factories + * triggers N simultaneous TLS handshakes from N distinct Netty clients against the same SageMaker + * endpoint — the SageMaker HTTP/2 frontline silently drops a large fraction of those streams + * before they ever reach the model container. Sharing matches the behavior of the canonical + * Python load-test harness, which has been verified to handle 400+ concurrent streams cleanly. + * + *

Lifecycle: + *

*/ public class SageMakerTransportFactory implements DeepgramTransportFactory { + /** + * Process-wide shared clients keyed by config fingerprint. Subsequent factories with the + * same fingerprint reuse the existing client, so one Netty event loop group + connection + * pool serves all of them. + */ + private static final ConcurrentHashMap SHARED_CLIENTS = + new ConcurrentHashMap<>(); + private final SageMakerConfig config; private final SageMakerRuntimeHttp2AsyncClient smClient; public SageMakerTransportFactory(SageMakerConfig config) { this.config = config; - this.smClient = SageMakerRuntimeHttp2AsyncClient.builder() + this.smClient = SHARED_CLIENTS.computeIfAbsent( + sharedClientKey(config), + k -> buildClient(config)); + } + + /** + * Create with a pre-configured SageMaker HTTP/2 client (for testing or custom credential + * providers). The provided client is not closed by {@link #shutdown()}; + * the caller owns its lifecycle. + */ + public SageMakerTransportFactory(SageMakerConfig config, SageMakerRuntimeHttp2AsyncClient smClient) { + this.config = config; + this.smClient = smClient; + } + + /** + * Cache key for the shared-client pool. Includes only the fields that affect the underlying + * Netty client; per-stream config (endpointName, contentType, retry knobs) doesn't. + */ + private static String sharedClientKey(SageMakerConfig c) { + return c.region().id() + + "|" + c.maxConcurrency() + + "|" + c.connectionTimeout().toMillis() + + "|" + c.connectionAcquireTimeout().toMillis(); + } + + private static SageMakerRuntimeHttp2AsyncClient buildClient(SageMakerConfig config) { + return SageMakerRuntimeHttp2AsyncClient.builder() .region(config.region()) .httpClientBuilder( NettyNioAsyncHttpClient.builder() @@ -59,15 +118,6 @@ public SageMakerTransportFactory(SageMakerConfig config) { .build(); } - /** - * Create with a pre-configured SageMaker HTTP/2 client (for testing or - * custom credential providers). - */ - public SageMakerTransportFactory(SageMakerConfig config, SageMakerRuntimeHttp2AsyncClient smClient) { - this.config = config; - this.smClient = smClient; - } - @Override public DeepgramTransport create(String url, Map headers) { // Parse the WebSocket URL to extract invocation path and query string. @@ -94,8 +144,28 @@ public ReconnectingWebSocketListener.ReconnectOptions reconnectOptions() { .build(); } - /** Shut down the underlying AWS SDK client. */ + /** + * No-op for factories backed by the shared client pool or a caller-owned (BYO) client. + * Use {@link #shutdownAllSharedClients()} to close shared clients at app shutdown; close + * your own client directly if you provided one. + */ public void shutdown() { - smClient.close(); + // Intentionally no-op. See class-level Javadoc for lifecycle semantics. + } + + /** + * Close all process-wide shared {@link SageMakerRuntimeHttp2AsyncClient} instances. Call + * once at app shutdown if you want to release Netty resources cleanly before JVM exit. + * Subsequent default-constructor factories will lazily build new shared clients. + */ + public static void shutdownAllSharedClients() { + for (SageMakerRuntimeHttp2AsyncClient client : SHARED_CLIENTS.values()) { + try { + client.close(); + } catch (Exception ignored) { + // best-effort + } + } + SHARED_CLIENTS.clear(); } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java index 4ea9ec5..e9a5871 100755 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java @@ -189,4 +189,128 @@ void factoryDeclaresMaxRetriesZeroForReconnectOptions() { assertNotNull(opts); assertEquals(0, opts.maxRetries); } + + // ------------------------------------------------------------------------- + // Shared-client pool tests. The default constructor backs the factory with a process-wide + // shared SageMakerRuntimeHttp2AsyncClient keyed by config fingerprint, so naive code that + // builds a fresh factory per stream still benefits from a single Netty pool underneath. + // ------------------------------------------------------------------------- + + @Test + void defaultConstructorReusesSharedClientAcrossFactoriesWithSameConfig() { + // Reset the shared pool to isolate this test from other tests in the class. + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig configA = SageMakerConfig.builder() + .endpointName("endpoint-A") // endpoint name does NOT affect the shared client + .region("us-east-1") + .build(); + SageMakerConfig configB = SageMakerConfig.builder() + .endpointName("endpoint-B") // different endpoint, same Netty-relevant config + .region("us-east-1") + .build(); + + SageMakerTransportFactory f1 = new SageMakerTransportFactory(configA); + SageMakerTransportFactory f2 = new SageMakerTransportFactory(configB); + + // Use reflection sparingly — verify both factories point at the same underlying smClient. + assertSame(getSmClient(f1), getSmClient(f2), + "factories with same Netty-relevant config must share one smClient"); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void defaultConstructorBuildsDistinctSharedClientsForDifferentConfigs() { + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig configEast = SageMakerConfig.builder() + .endpointName("e").region("us-east-1").build(); + SageMakerConfig configWest = SageMakerConfig.builder() + .endpointName("e").region("us-west-2").build(); + SageMakerConfig configEastBigPool = SageMakerConfig.builder() + .endpointName("e").region("us-east-1").maxConcurrency(1000).build(); + + SageMakerTransportFactory east1 = new SageMakerTransportFactory(configEast); + SageMakerTransportFactory east2 = new SageMakerTransportFactory(configEast); + SageMakerTransportFactory west = new SageMakerTransportFactory(configWest); + SageMakerTransportFactory eastBig = new SageMakerTransportFactory(configEastBigPool); + + assertSame(getSmClient(east1), getSmClient(east2)); + assertNotSame(getSmClient(east1), getSmClient(west)); + assertNotSame(getSmClient(east1), getSmClient(eastBig)); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void byoClientConstructorIsNotPooled() { + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + + SageMakerTransportFactory byo = new SageMakerTransportFactory(config, mockClient); + SageMakerTransportFactory shared = new SageMakerTransportFactory(config); + + assertSame(mockClient, getSmClient(byo), "BYO factory must use the provided client"); + assertNotSame(mockClient, getSmClient(shared), + "Shared factory must build/lookup its own client, not steal the BYO mock"); + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void shutdownIsNoopForBoth() { + // factory.shutdown() must not close shared or BYO clients — lifecycle belongs elsewhere. + SageMakerTransportFactory.shutdownAllSharedClients(); + try { + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + SageMakerRuntimeHttp2AsyncClient mockClient = mock(SageMakerRuntimeHttp2AsyncClient.class); + + new SageMakerTransportFactory(config, mockClient).shutdown(); + verify(mockClient, never()).close(); + + SageMakerTransportFactory shared = new SageMakerTransportFactory(config); + shared.shutdown(); + // We can't easily verify .close() wasn't called on a real client without re-fetching; + // the assertion above on mockClient is the strong guarantee. The shared variant is + // documented as no-op via Javadoc. + } finally { + SageMakerTransportFactory.shutdownAllSharedClients(); + } + } + + @Test + void shutdownAllSharedClientsClearsThePool() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig config = SageMakerConfig.builder().endpointName("e").build(); + + SageMakerTransportFactory before = new SageMakerTransportFactory(config); + SageMakerRuntimeHttp2AsyncClient firstClient = getSmClient(before); + + SageMakerTransportFactory.shutdownAllSharedClients(); + + SageMakerTransportFactory after = new SageMakerTransportFactory(config); + SageMakerRuntimeHttp2AsyncClient secondClient = getSmClient(after); + + assertNotSame(firstClient, secondClient, + "shutdownAllSharedClients must clear the pool so the next factory builds a fresh client"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + /** Reflection helper — read the package-private smClient field set by the constructor. */ + private static SageMakerRuntimeHttp2AsyncClient getSmClient(SageMakerTransportFactory f) { + try { + java.lang.reflect.Field fld = SageMakerTransportFactory.class.getDeclaredField("smClient"); + fld.setAccessible(true); + return (SageMakerRuntimeHttp2AsyncClient) fld.get(f); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java index 52ec77e..164f2e5 100644 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -7,6 +7,7 @@ import java.net.ConnectException; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -93,6 +94,22 @@ void ioExceptionIsRetryable() { SageMakerTransport.classify(new IOException("network error"))); } + @Test + @DisplayName("CancellationException is retryable (covers self-induced retry-reset cancels)") + void cancellationExceptionIsRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new CancellationException())); + } + + @Test + @DisplayName("FutureCancelledException-style wrapper (RuntimeException with CancellationException cause) is retryable") + void cancellationWrappedInRuntimeIsRetryable() { + // Mirrors AWS Netty's FutureCancelledException, which extends RuntimeException and + // wraps a CancellationException as its cause. + RuntimeException wrapper = new RuntimeException("future cancelled", new CancellationException()); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(wrapper)); + } + @Test @DisplayName("AWS 429 (Too Many Requests) is retryable") void aws429IsRetryable() { @@ -155,6 +172,51 @@ void sdkExceptionWithPoolKeywordIsRetryable() { assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(sdke)); } + @Test + @DisplayName("'Unable to load credentials' is retryable when an SSO/STS provider hit Status Code: 429") + void credentialLoadFailureWithSsoThrottleIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Unable to load credentials from any of the providers in the chain " + + "AwsCredentialsProviderChain(...): [..., " + + "ProfileCredentialsProvider(profileName=shared-dev, ...): " + + "HTTP 429 Unknown Code (Service: Sso, Status Code: 429, Request ID: abc) " + + "(SDK Attempt Count: 4), ...]") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("'Unable to load credentials' is retryable when a credential backend returned 5xx") + void credentialLoadFailureWith5xxIsRetryable() { + SdkException sdke = SdkException.builder() + .message("Unable to load credentials from any of the providers in the chain " + + "AwsCredentialsProviderChain(...): [..., " + + "InstanceProfileCredentialsProvider(): " + + "HTTP 503 (Service: Imds, Status Code: 503, Request ID: xyz)]") + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(sdke)); + } + + @Test + @DisplayName("'Unable to load credentials' is terminal when no provider had a transient cause") + void credentialLoadFailurePureMisconfigIsTerminal() { + SdkException sdke = SdkException.builder() + .message("Unable to load credentials from any of the providers in the chain " + + "AwsCredentialsProviderChain(...): [" + + "SystemPropertyCredentialsProvider(): Access key must be specified..., " + + "EnvironmentVariableCredentialsProvider(): Access key must be specified..., " + + "WebIdentityTokenFileCredentialsProvider(): " + + "Either the environment variable AWS_WEB_IDENTITY_TOKEN_FILE or the " + + "javaproperty aws.webIdentityTokenFile must be set., " + + "ContainerCredentialsProvider(): Cannot fetch credentials from container, " + + "InstanceProfileCredentialsProvider(): Failed to load credentials from IMDS.]") + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, + SageMakerTransport.classify(sdke)); + } + @Test @DisplayName("Walks the cause chain — IOException wrapped in RuntimeException is retryable") void walksCauseChain() { From 87ae582a981cdab878d6a53b9c9e26e1186a9de3 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Fri, 1 May 2026 18:07:15 -0700 Subject: [PATCH 04/10] Untested aws do not retry strategy --- .../deepgram/sagemaker/SageMakerTransportFactory.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index ae70e5b..71e4437 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -4,6 +4,7 @@ import com.deepgram.core.transport.DeepgramTransport; import com.deepgram.core.transport.DeepgramTransportFactory; +import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; import software.amazon.awssdk.http.Protocol; import software.amazon.awssdk.http.nio.netty.Http2Configuration; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; @@ -103,6 +104,16 @@ private static String sharedClientKey(SageMakerConfig c) { private static SageMakerRuntimeHttp2AsyncClient buildClient(SageMakerConfig config) { return SageMakerRuntimeHttp2AsyncClient.builder() .region(config.region()) + // Disable the AWS SDK's internal retry strategy. SageMakerTransport owns the + // retry policy (handleStreamError + ensureConnected backoff). The AWS SDK's + // default 3-attempt strategy compounds on top: every "1 retry" in our schedule + // becomes ~4 hits on the SageMaker frontline (the original attempt + 3 SDK + // retries with their own ~25/100/400 ms backoffs). Under a per-LB-IP throttle + // ceiling measured in requests/sec, that amplification keeps the conn pinned + // in the throttle window. AwsRetryStrategy.doNotRetry() removes the SDK layer; + // transient TLS / connection-reset hiccups are still caught by our IOException + // → RETRYABLE classification and the same backoff path. + .overrideConfiguration(c -> c.retryStrategy(AwsRetryStrategy.doNotRetry())) .httpClientBuilder( NettyNioAsyncHttpClient.builder() .protocol(Protocol.HTTP2) From 4fc0bce9f5409e0d937a903968725b465d3f127c Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Sun, 3 May 2026 21:49:44 -0400 Subject: [PATCH 05/10] Update retry classification logic --- .../sagemaker/SageMakerTransport.java | 60 +++++++----------- .../SageMakerTransportRetryTest.java | 61 ++++++++++++------- 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index fc38cf7..bd6b3da 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -4,19 +4,15 @@ import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.sagemakerruntimehttp2.SageMakerRuntimeHttp2AsyncClient; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.InvokeEndpointWithBidirectionalStreamRequest; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.InvokeEndpointWithBidirectionalStreamResponseHandler; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.RequestStreamEvent; import software.amazon.awssdk.services.sagemakerruntimehttp2.model.ResponsePayloadPart; -import java.io.IOException; -import java.net.ConnectException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; @@ -314,53 +310,39 @@ enum Classification { RETRYABLE, TERMINAL } /** * Classify an AWS-side exception as transient (retry internally, don't surface) vs terminal - * (surface to {@code errorListeners}). Walks the cause chain so SDK-wrapped exceptions are - * inspected too. + * (surface to {@code errorListeners}). + * + *

Default is RETRYABLE — the retry budget ({@link SageMakerConfig#retryBudget()}) is the + * safety net, not the classifier. Under high-burst load against SageMaker we've seen a + * long tail of transient failure modes (Netty WriteTimeout, HTTP/2 stream resets, SSO 429s, + * pool-acquire timeouts, AWS-frontline ThrottlingException) that the classifier kept missing + * one-by-one; flipping the default means a new transient hiccup retries instead of surfacing + * as a hard failure. The budget caps the worst case at the configured wall-clock anyway. + * + *

The narrow set we genuinely consider TERMINAL is caller-side rejections from AWS: + * an {@link AwsServiceException} with a 4xx status code (other than 429 and other than + * "throttling"-coded errors). Validation errors, AccessDenied, ResourceNotFound, etc. are + * authoritative — retrying won't change the outcome. Everything else, including the cause + * chain we don't recognize, defaults to RETRYABLE. */ static Classification classify(Throwable error) { for (Throwable t = error; t != null; t = t.getCause()) { - // CancellationException is RETRYABLE because the cancel was either (a) induced by our - // own retry-reset path (in which case the next attempt will run cleanly) or (b) caused - // by some upstream terminal condition (in which case the retry attempt will hit the - // underlying error and classify it as TERMINAL on its own). Either way, treating the - // cancel itself as TERMINAL would surface a self-inflicted error to listeners. - // Covers AWS Netty's FutureCancelledException too — it wraps CancellationException as - // its cause. - if (t instanceof CancellationException) return Classification.RETRYABLE; - if (t instanceof TimeoutException) return Classification.RETRYABLE; - if (t instanceof ConnectException) return Classification.RETRYABLE; - if (t instanceof IOException) return Classification.RETRYABLE; if (t instanceof AwsServiceException) { AwsServiceException ase = (AwsServiceException) t; int status = ase.statusCode(); - if (status == 429 || (status >= 500 && status < 600)) return Classification.RETRYABLE; String code = ase.awsErrorDetails() != null ? ase.awsErrorDetails().errorCode() : null; + // Throttling is sometimes coded as 4xx (e.g. SageMaker frontline returns 400 with + // errorCode=ThrottlingException); always retry these regardless of status. if (code != null && code.toLowerCase().contains("throttl")) return Classification.RETRYABLE; - return Classification.TERMINAL; - } - if (t instanceof SdkException) { - String msg = t.getMessage() == null ? "" : t.getMessage().toLowerCase(); - if (msg.contains("acquire") || msg.contains("pool") || msg.contains("throttl") - || msg.contains("timeout")) { - return Classification.RETRYABLE; - } - // Credential-loading failures from the AWS SDK provider chain - // (`SdkClientException: Unable to load credentials from any of the providers...`). - // Retry only when at least one provider hit a transient AWS-side condition — - // typically AWS IAM Identity Center (SSO) or STS rate-limiting credential - // refreshes under burst load (Status Code: 429), or a 5xx from a credential - // backend. Pure misconfig (no provider has credentials at all) still surfaces - // fast — retrying won't conjure credentials that don't exist. - if (msg.contains("unable to load credentials") - && (msg.contains("status code: 429") - || msg.contains("status code: 5") - || msg.contains("rate exceeded"))) { - return Classification.RETRYABLE; + // 4xx (except 429) = caller-side rejection: validation, auth, notfound — won't + // fix on retry. 429 and 5xx fall through to the RETRYABLE default below. + if (status >= 400 && status < 500 && status != 429) { + return Classification.TERMINAL; } } if (t == t.getCause()) break; } - return Classification.TERMINAL; + return Classification.RETRYABLE; } @Override diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java index 164f2e5..a7b1159 100644 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -199,24 +199,6 @@ void credentialLoadFailureWith5xxIsRetryable() { SageMakerTransport.classify(sdke)); } - @Test - @DisplayName("'Unable to load credentials' is terminal when no provider had a transient cause") - void credentialLoadFailurePureMisconfigIsTerminal() { - SdkException sdke = SdkException.builder() - .message("Unable to load credentials from any of the providers in the chain " - + "AwsCredentialsProviderChain(...): [" - + "SystemPropertyCredentialsProvider(): Access key must be specified..., " - + "EnvironmentVariableCredentialsProvider(): Access key must be specified..., " - + "WebIdentityTokenFileCredentialsProvider(): " - + "Either the environment variable AWS_WEB_IDENTITY_TOKEN_FILE or the " - + "javaproperty aws.webIdentityTokenFile must be set., " - + "ContainerCredentialsProvider(): Cannot fetch credentials from container, " - + "InstanceProfileCredentialsProvider(): Failed to load credentials from IMDS.]") - .build(); - assertEquals(SageMakerTransport.Classification.TERMINAL, - SageMakerTransport.classify(sdke)); - } - @Test @DisplayName("Walks the cause chain — IOException wrapped in RuntimeException is retryable") void walksCauseChain() { @@ -225,11 +207,48 @@ void walksCauseChain() { } @Test - @DisplayName("Unknown exception defaults to terminal") - void unknownDefaultsToTerminal() { - assertEquals(SageMakerTransport.Classification.TERMINAL, + @DisplayName("Unknown exception defaults to RETRYABLE (budget is the safety net)") + void unknownDefaultsToRetryable() { + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(new RuntimeException("mystery"))); } + + @Test + @DisplayName("Netty WriteTimeoutException-style RuntimeException is retryable by default") + void nettyWriteTimeoutIsRetryable() { + // io.netty.handler.timeout.WriteTimeoutException extends Netty's own TimeoutException + // (NOT java.util.concurrent.TimeoutException) which extends RuntimeException. We don't + // want to take a direct Netty compile dep just to instanceof-check it, so the new + // default-RETRYABLE policy covers this organically. + class WriteTimeoutException extends RuntimeException { + WriteTimeoutException() { super(); } + } + assertEquals(SageMakerTransport.Classification.RETRYABLE, + SageMakerTransport.classify(new WriteTimeoutException())); + } + + @Test + @DisplayName("AWS 400 (ValidationException) is terminal — caller-side rejection") + void aws400ValidationIsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("invalid input") + .statusCode(400) + .awsErrorDetails(AwsErrorDetails.builder() + .errorCode("ValidationException") + .build()) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } + + @Test + @DisplayName("AWS 404 (ResourceNotFound) is terminal — won't appear on retry") + void aws404IsTerminal() { + AwsServiceException ase = AwsServiceException.builder() + .message("endpoint not found") + .statusCode(404) + .build(); + assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); + } } @Nested From a6f4f41f5c4807eddc4657a33865e1a18ba2662a Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Mon, 4 May 2026 00:37:30 -0400 Subject: [PATCH 06/10] Update buffer management --- .../deepgram/sagemaker/SageMakerConfig.java | 33 +++ .../sagemaker/SageMakerTransport.java | 208 ++++++++++++++++-- .../SageMakerTransportRetryTest.java | 91 ++++++++ 3 files changed, 309 insertions(+), 23 deletions(-) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java index 3d9fe6c..a246e78 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java @@ -48,6 +48,22 @@ public class SageMakerConfig { /** Total wall-clock budget across all retry attempts before giving up and surfacing the error. */ public static final Duration DEFAULT_RETRY_BUDGET = Duration.ofSeconds(30); + /** + * Cap on the in-memory replay buffer that holds sent-but-unacked stream events for the + * current bidi stream attempt. If the SDK has to retry (throttling, post-subscription + * stream reset), this buffer is drained onto the new stream so AWS sees a continuous + * audio sequence instead of the gap created by the discarded events. + * + *

The buffer is trimmed when {@code handlePayloadPart} fires (the model just produced + * a transcript, so prior audio is acked from our perspective), so under steady-state + * operation it stays small (≤ a few hundred KB). It only grows during a throttle/reset + * window where no payload parts come back. 8 MiB ≈ 256 s of 16 kHz mono + * 16-bit PCM, which covers the longest throttle storms we've seen in practice with + * margin to spare. Lower the cap for tight memory budgets; raise it if you expect + * longer retry windows per stream. + */ + public static final long DEFAULT_MAX_REPLAY_BUFFER_BYTES = 8L * 1024 * 1024; + private final String endpointName; private final Region region; private final String contentType; @@ -61,6 +77,7 @@ public class SageMakerConfig { private final Duration maxBackoff; private final double backoffMultiplier; private final Duration retryBudget; + private final long maxReplayBufferBytes; private SageMakerConfig(Builder builder) { this.endpointName = builder.endpointName; @@ -76,6 +93,7 @@ private SageMakerConfig(Builder builder) { this.maxBackoff = builder.maxBackoff; this.backoffMultiplier = builder.backoffMultiplier; this.retryBudget = builder.retryBudget; + this.maxReplayBufferBytes = builder.maxReplayBufferBytes; } public String endpointName() { return endpointName; } @@ -91,6 +109,7 @@ private SageMakerConfig(Builder builder) { public Duration maxBackoff() { return maxBackoff; } public double backoffMultiplier() { return backoffMultiplier; } public Duration retryBudget() { return retryBudget; } + public long maxReplayBufferBytes() { return maxReplayBufferBytes; } public static Builder builder() { return new Builder(); @@ -110,6 +129,7 @@ public static class Builder { private Duration maxBackoff = DEFAULT_MAX_BACKOFF; private double backoffMultiplier = DEFAULT_BACKOFF_MULTIPLIER; private Duration retryBudget = DEFAULT_RETRY_BUDGET; + private long maxReplayBufferBytes = DEFAULT_MAX_REPLAY_BUFFER_BYTES; public Builder endpointName(String endpointName) { this.endpointName = endpointName; @@ -239,6 +259,19 @@ public Builder retryBudget(Duration retryBudget) { return this; } + /** + * Cap on the in-memory replay buffer that holds sent-but-unacked stream events. Set to + * {@code 0} to disable replay (sent events are dropped on internal reset, matching the + * pre-replay-buffer behavior). See {@link #DEFAULT_MAX_REPLAY_BUFFER_BYTES}. + */ + public Builder maxReplayBufferBytes(long maxReplayBufferBytes) { + if (maxReplayBufferBytes < 0) { + throw new IllegalArgumentException("maxReplayBufferBytes must be non-negative"); + } + this.maxReplayBufferBytes = maxReplayBufferBytes; + return this; + } + public SageMakerConfig build() { if (endpointName == null || endpointName.isBlank()) { throw new IllegalArgumentException("endpointName is required"); diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index bd6b3da..ba2cc9d 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -14,12 +14,14 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import org.reactivestreams.Publisher; @@ -71,6 +73,15 @@ public class SageMakerTransport implements DeepgramTransport { // next stream attempt instead of being dropped with the discarded publisher. private final ConcurrentLinkedQueue pending = new ConcurrentLinkedQueue<>(); + // Replay buffer: events sent on the current stream that AWS hasn't acked yet (no payload part + // received since they were sent). On internal reset (handleStreamError → RETRYABLE), the next + // attemptConnect drains this buffer onto the new stream so audio sent on the rejected stream + // isn't lost. Trimmed in handlePayloadPart (a transcript proves AWS has consumed prior audio) + // and capped at config.maxReplayBufferBytes() with FIFO eviction so unbounded throttle storms + // don't OOM the JVM. + private final ConcurrentLinkedDeque replayBuffer = new ConcurrentLinkedDeque<>(); + private final AtomicLong replayBufferBytes = new AtomicLong(0L); + // Retry budget tracking. Reset to 0 once real downstream data flows back to the application // (handlePayloadPart) — NOT on subscription success. Subscription succeeds in TLS+HTTP/2 // setup terms even when the bidi-stream request will be throttled milliseconds later. @@ -103,9 +114,22 @@ public class SageMakerTransport implements DeepgramTransport { * via the caller's {@code send*} path. */ private void ensureConnected() { - if (connected.get()) return; + // Fast path requires BOTH connected=true AND a live publisher. If either is false we + // re-enter the synchronized reconnect loop. Without the publisher liveness check, a + // race between handleStreamError (which sets connected=false then completes the + // publisher) and ensureConnected's tail (connected.set(true)) can leave us with + // connected=true pointing at a dead publisher — sends silently NOOP forever and the + // conn never recovers (root cause of the 44 silent conns observed in the + // 2026-05-03 replay-buffer-only test). + StreamPublisher pub = inputPublisher; + if (connected.get() && pub != null && !pub.isCompleted()) return; synchronized (connectLock) { - if (connected.get()) return; + pub = inputPublisher; + if (connected.get() && pub != null && !pub.isCompleted()) return; + // If we got here because the publisher died but connected was still true (the race + // above), tell the loop to start a fresh attempt rather than treat ourselves as + // already-connected. + connected.set(false); if (retryWindowStart == 0L) { retryWindowStart = System.currentTimeMillis(); @@ -180,6 +204,20 @@ private void attemptConnect() throws TimeoutException, InterruptedException { StreamPublisher publisher = new StreamPublisher(pending); inputPublisher = publisher; + // Replay any unacked events from a prior failed stream attempt. They go through the + // publisher's pre-subscription pending queue and are flushed onto the new AWS subscriber + // as soon as it arrives — so AWS sees the buffered audio first, in order, before any + // new chunks the caller sends after this attempt succeeds. + if (!replayBuffer.isEmpty()) { + int count = replayBuffer.size(); + long bytes = replayBufferBytes.get(); + log.info("[{}] attemptConnect: replaying {} buffered events ({} bytes) onto new stream", + transportId, count, bytes); + for (BufferedEvent be : replayBuffer) { + publisher.send(be.event); + } + } + InvokeEndpointWithBidirectionalStreamRequest.Builder requestBuilder = InvokeEndpointWithBidirectionalStreamRequest.builder() .endpointName(config.endpointName()) @@ -214,6 +252,21 @@ private void attemptConnect() throws TimeoutException, InterruptedException { "Timed out waiting for AWS SDK to subscribe to stream publisher after " + config.subscriptionTimeout()); } + + // Race guard: under high-concurrency burst-throttle, AWS sometimes accepts the + // subscription and IMMEDIATELY rejects the stream (ThrottlingException, ping-timeout). + // The handleStreamError callback fires on a Netty event-loop thread BEFORE attemptConnect + // returns from awaitSubscription, sets connected=false, and completes this publisher. + // If we then unconditionally let ensureConnected set connected=true, the audio thread + // takes the fast path and pushes chunks into a dead publisher — silently dropped, no + // transcripts ever come back. Detect the race here by re-checking the publisher's + // completion state and surfacing it as a transient failure so ensureConnected re-enters + // the retry loop on the next send. + if (publisher.isCompleted()) { + throw new IllegalStateException( + "Stream rejected immediately after subscription (handleStreamError already " + + "fired and completed the publisher); ensureConnected will re-attempt"); + } } /** @@ -320,10 +373,16 @@ enum Classification { RETRYABLE, TERMINAL } * as a hard failure. The budget caps the worst case at the configured wall-clock anyway. * *

The narrow set we genuinely consider TERMINAL is caller-side rejections from AWS: - * an {@link AwsServiceException} with a 4xx status code (other than 429 and other than + * an {@link AwsServiceException} with a 4xx status code (other than 429, 424, and other than * "throttling"-coded errors). Validation errors, AccessDenied, ResourceNotFound, etc. are * authoritative — retrying won't change the outcome. Everything else, including the cause * chain we don't recognize, defaults to RETRYABLE. + * + *

424 (Failed Dependency) is treated as RETRYABLE because under burst load we observed + * SageMaker emitting {@code ModelErrorException: Received server error (424) from primary + * with message "Failed to establish WebSocket connection"} when the upstream model container + * couldn't accept a new stream at the moment of the request. The next attempt usually + * succeeds, so it doesn't belong in the caller-side-rejection bucket. */ static Classification classify(Throwable error) { for (Throwable t = error; t != null; t = t.getCause()) { @@ -334,9 +393,10 @@ static Classification classify(Throwable error) { // Throttling is sometimes coded as 4xx (e.g. SageMaker frontline returns 400 with // errorCode=ThrottlingException); always retry these regardless of status. if (code != null && code.toLowerCase().contains("throttl")) return Classification.RETRYABLE; - // 4xx (except 429) = caller-side rejection: validation, auth, notfound — won't - // fix on retry. 429 and 5xx fall through to the RETRYABLE default below. - if (status >= 400 && status < 500 && status != 429) { + // 4xx caller-side rejections: validation, auth, notfound — won't fix on retry. + // Exclusions: 429 (rate-limit, retry-with-backoff) and 424 (Failed Dependency, + // SageMaker upstream "Failed to establish WebSocket connection" — transient). + if (status >= 400 && status < 500 && status != 429 && status != 424) { return Classification.TERMINAL; } } @@ -356,6 +416,7 @@ public CompletableFuture sendBinary(byte[] data) { .bytes(SdkBytes.fromByteArray(data)) .build(); inputPublisher.send(event); + bufferForReplay(event, data.length); return CompletableFuture.completedFuture(null); } @@ -366,11 +427,13 @@ public CompletableFuture sendText(String data) { new IllegalStateException("Transport is closed")); } ensureConnected(); + byte[] payload = data.getBytes(StandardCharsets.UTF_8); RequestStreamEvent event = RequestStreamEvent.payloadPartBuilder() - .bytes(SdkBytes.fromByteArray(data.getBytes(StandardCharsets.UTF_8))) + .bytes(SdkBytes.fromByteArray(payload)) .dataType("UTF8") .build(); inputPublisher.send(event); + bufferForReplay(event, payload.length); // Track that we've signaled end-of-audio so we can treat the model's // idle timeout as a normal close rather than an error. @@ -381,6 +444,44 @@ public CompletableFuture sendText(String data) { return CompletableFuture.completedFuture(null); } + /** + * Append {@code event} to the replay buffer for re-delivery on a future internal reset. + * Evicts oldest events FIFO once the configured byte cap is exceeded so an unbounded + * throttle storm can't OOM the JVM. Caller passes the payload byte length so we don't + * have to re-walk the SdkBytes wrapper (which would double-allocate on every send). + */ + private void bufferForReplay(RequestStreamEvent event, int eventBytes) { + long cap = config.maxReplayBufferBytes(); + if (cap == 0L) return; // replay disabled + replayBuffer.addLast(new BufferedEvent(event, eventBytes)); + long total = replayBufferBytes.addAndGet(eventBytes); + while (total > cap) { + BufferedEvent dropped = replayBuffer.pollFirst(); + if (dropped == null) break; + total = replayBufferBytes.addAndGet(-dropped.bytes); + } + } + + /** Drain the replay buffer (e.g. after AWS acks via {@code handlePayloadPart}). */ + private void clearReplayBuffer() { + replayBuffer.clear(); + replayBufferBytes.set(0L); + } + + // Test-only accessors. Package-private so unit tests in the same package can verify buffer + // behavior without going through the full AWS reactive-streams stack. + int replayBufferSize() { return replayBuffer.size(); } + long replayBufferBytes() { return replayBufferBytes.get(); } + void bufferForReplayForTest(RequestStreamEvent event, int eventBytes) { + bufferForReplay(event, eventBytes); + } + void clearReplayBufferForTest() { clearReplayBuffer(); } + java.util.List drainReplayBufferForTest() { + java.util.List out = new java.util.ArrayList<>(replayBuffer.size()); + for (BufferedEvent be : replayBuffer) out.add(be.event); + return out; + } + private void notifyClose(int code, String reason) { if (closeNotified.compareAndSet(false, true)) { for (CloseListener l : closeListeners) { @@ -392,23 +493,58 @@ private void notifyClose(int code, String reason) { private void handlePayloadPart(ResponsePayloadPart part) { byte[] bytes = part.bytes().asByteArray(); - // First downstream data on this transport (or first since the last retry-loop reset) → - // the stream is genuinely working end-to-end, not just subscription-established. Reset - // the retry budget so a future transient failure gets a fresh budget. - if (retryAttempt.get() != 0 || retryWindowStart != 0L || retryNotBeforeMs != 0L) { - log.info("[{}] handlePayloadPart: data received ({}B) → resetting retry counters " - + "(was attempt={}, windowStart={}, notBeforeMs={})", - transportId, bytes.length, - retryAttempt.get(), retryWindowStart, retryNotBeforeMs); - retryAttempt.set(0); - retryWindowStart = 0L; - retryNotBeforeMs = 0L; + // Decode once. JSON messages start with '{"' (0x7B 0x22) — checking two bytes avoids + // false positives from binary audio chunks that happen to start with 0x7B. + boolean isJson = bytes.length > 1 && bytes[0] == '{' && bytes[1] == '"'; + String text = isJson ? new String(bytes, StandardCharsets.UTF_8) : null; + + // Decide whether this payload counts as "the model consumed input" — the signal that + // makes it safe to reset retry counters and trim the replay buffer. + // + // Rather than enumerate every downstream message type per Deepgram product (Listen STT + // emits `Results`, Flux emits `TurnInfo`, Speak/Aura emits binary audio chunks, ...), + // we invert the check: assume ANY downstream payload counts as a real ack EXCEPT for + // end-of-stream-only message types that the container can emit without having consumed + // input. Two types are documented to fire at close regardless of consumption: + // + // - `Metadata` — end-of-stream summary. Under burst-replay we observed the container + // emitting `Metadata` (sha256="incomplete", duration=0) when the model errored out + // before producing any transcript; the original code treated this as an ack and + // cleared the replay buffer, so the next retry had nothing to replay and the + // replayed audio was lost (root cause of conn 125-style front-loss in the + // 2026-05-03 RC run). + // - `Error` — model-side fatal error message, also typically sent at close. + // + // Anything else — Results, TurnInfo, EndOfTurn, Warning, SpeechStarted, UtteranceEnd, + // raw binary audio (TTS) — represents real model activity and is safe to trust. + boolean isCloseOnlyMessage = isJson + && (text.contains("\"type\":\"Metadata\"") || text.contains("\"type\":\"Error\"")); + boolean countsAsAck = !isCloseOnlyMessage; + + if (countsAsAck) { + // Real downstream content on this transport (or first since the last retry-loop + // reset) → the stream is genuinely working end-to-end. Reset the retry budget so a + // future transient failure gets a fresh budget. + if (retryAttempt.get() != 0 || retryWindowStart != 0L || retryNotBeforeMs != 0L) { + log.info("[{}] handlePayloadPart: ack-class data received ({}B) → resetting retry counters " + + "(was attempt={}, windowStart={}, notBeforeMs={})", + transportId, bytes.length, + retryAttempt.get(), retryWindowStart, retryNotBeforeMs); + retryAttempt.set(0); + retryWindowStart = 0L; + retryNotBeforeMs = 0L; + } + + // The model produced output → it consumed input up to (at least) the moment this + // part's window started. Drop the replay buffer; if a retry happens later, we only + // need to replay events sent SINCE this ack. The model's internal latency means a + // small leading window (< 1 s, typically) of audio sent just before the failure + // might be re-transcribed on the new stream — net impact is a brief duplicate, far + // better than dropping 30+ s of audio outright. + clearReplayBuffer(); } - // JSON messages start with '{"' (0x7B 0x22). Checking two bytes avoids - // false positives from binary audio chunks that happen to start with 0x7B. - if (bytes.length > 1 && bytes[0] == '{' && bytes[1] == '"') { - String text = new String(bytes, StandardCharsets.UTF_8); + if (isJson) { for (Consumer l : messageListeners) { l.accept(text); } @@ -453,8 +589,10 @@ public boolean isOpen() { @Override public void close() { if (!open.compareAndSet(true, false)) return; - // Terminal close — drop any messages that were queued during a reset window. + // Terminal close — drop any messages that were queued during a reset window, and free the + // replay buffer so we don't pin its memory after the transport is done. pending.clear(); + clearReplayBuffer(); if (inputPublisher != null) { inputPublisher.complete(); } @@ -463,6 +601,20 @@ public void close() { } } + /** + * Replay-buffer entry — keeps the byte length alongside the event so the cap-and-evict path + * doesn't have to peek inside SdkBytes (which would incur an extra copy on every send). + */ + static final class BufferedEvent { + final RequestStreamEvent event; + final int bytes; + + BufferedEvent(RequestStreamEvent event, int bytes) { + this.event = event; + this.bytes = bytes; + } + } + /** * Reactive Streams publisher that buffers events until the SDK subscribes, * then delivers them in order. After subscription, events are forwarded immediately. @@ -530,5 +682,15 @@ void complete() { boolean awaitSubscription(long timeout, TimeUnit unit) throws InterruptedException { return subscribed.await(timeout, unit); } + + /** + * @return {@code true} if {@link #complete()} or the subscriber's cancel hook has fired. + * Used by the {@code attemptConnect} race-guard to detect when handleStreamError + * has already torn down the publisher between subscription and the + * {@code connected.set(true)} fast-path. + */ + boolean isCompleted() { + return completed.get(); + } } } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java index a7b1159..4e6b2bf 100644 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -249,6 +249,20 @@ void aws404IsTerminal() { .build(); assertEquals(SageMakerTransport.Classification.TERMINAL, SageMakerTransport.classify(ase)); } + + @Test + @DisplayName("AWS 424 (Failed Dependency, SageMaker ModelError) is retryable — upstream container transient") + void aws424IsRetryable() { + // Mirrors the actual SageMaker burst-load error: + // ModelErrorException: Received server error (424) from primary with message + // "Failed to establish WebSocket connection" + AwsServiceException ase = AwsServiceException.builder() + .message("Received server error (424) from primary with message " + + "\"Failed to establish WebSocket connection\"") + .statusCode(424) + .build(); + assertEquals(SageMakerTransport.Classification.RETRYABLE, SageMakerTransport.classify(ase)); + } } @Nested @@ -309,4 +323,81 @@ void sendForwardsWhenSubscribed() { assertTrue(shared.isEmpty()); } } + + @Nested + @DisplayName("Replay buffer") + class ReplayBufferTests { + private SageMakerTransport newTransport(long maxReplayBufferBytes) { + SageMakerConfig cfg = SageMakerConfig.builder() + .endpointName("test") + .region("us-east-1") + .maxReplayBufferBytes(maxReplayBufferBytes) + .build(); + // null AWS client is fine — these tests only exercise the in-memory buffer helpers, + // never attempt a real bidi stream. + return new SageMakerTransport(null, cfg, "v1/listen", ""); + } + + @Test + @DisplayName("buffer accumulates events with running byte count") + void bufferAccumulates() { + SageMakerTransport t = newTransport(1024); + t.bufferForReplayForTest(payloadEvent("aaa"), 3); + t.bufferForReplayForTest(payloadEvent("bbbbb"), 5); + assertEquals(2, t.replayBufferSize()); + assertEquals(8L, t.replayBufferBytes()); + } + + @Test + @DisplayName("clearReplayBuffer drops everything (the AWS-acked path)") + void clearDropsAll() { + SageMakerTransport t = newTransport(1024); + t.bufferForReplayForTest(payloadEvent("a"), 1); + t.bufferForReplayForTest(payloadEvent("b"), 1); + t.clearReplayBufferForTest(); + assertEquals(0, t.replayBufferSize()); + assertEquals(0L, t.replayBufferBytes()); + } + + @Test + @DisplayName("FIFO eviction once cap exceeded — newest events kept, oldest dropped") + void evictionFifo() { + SageMakerTransport t = newTransport(10); // cap = 10 bytes + RequestStreamEvent a = payloadEvent("aaaa"); + RequestStreamEvent b = payloadEvent("bbbb"); + RequestStreamEvent c = payloadEvent("cccc"); + RequestStreamEvent d = payloadEvent("dddd"); + t.bufferForReplayForTest(a, 4); // total 4 + t.bufferForReplayForTest(b, 4); // total 8 + t.bufferForReplayForTest(c, 4); // total 12 → evict a, total 8 + t.bufferForReplayForTest(d, 4); // total 12 → evict b, total 8 + + // Latest two events should survive, oldest two evicted, ordering preserved. + java.util.List remaining = t.drainReplayBufferForTest(); + assertEquals(2, remaining.size()); + assertSame(c, remaining.get(0), "oldest surviving event should be 'cccc'"); + assertSame(d, remaining.get(1), "newest event should be 'dddd'"); + assertEquals(8L, t.replayBufferBytes()); + } + + @Test + @DisplayName("maxReplayBufferBytes=0 disables buffering entirely") + void disabledByZeroCap() { + SageMakerTransport t = newTransport(0L); + t.bufferForReplayForTest(payloadEvent("a"), 1); + t.bufferForReplayForTest(payloadEvent("b"), 1); + assertEquals(0, t.replayBufferSize(), "events must not accumulate when cap=0"); + assertEquals(0L, t.replayBufferBytes()); + } + + @Test + @DisplayName("oversized single event is dropped immediately by the eviction loop") + void oversizedEventCantStick() { + SageMakerTransport t = newTransport(10); + // A single 16-byte event exceeds the 10-byte cap; eviction loop should remove it. + t.bufferForReplayForTest(payloadEvent("0123456789ABCDEF"), 16); + assertEquals(0, t.replayBufferSize()); + assertEquals(0L, t.replayBufferBytes()); + } + } } From 91754f7cf81b144f7b5ac4f2471e21c503bb0afb Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Mon, 4 May 2026 02:11:06 -0400 Subject: [PATCH 07/10] Add Jitter for backoff --- .../com/deepgram/sagemaker/SageMakerTransport.java | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index ba2cc9d..2bc82c2 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -19,6 +19,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -353,10 +354,15 @@ private long computeBackoff(int attempt) { long initial = config.initialBackoff().toMillis(); long max = config.maxBackoff().toMillis(); double scaled = initial * Math.pow(config.backoffMultiplier(), attempt); - if (scaled > max || Double.isInfinite(scaled)) { - return max; + long ceiling = (scaled > max || Double.isInfinite(scaled)) ? max : Math.max(initial, (long) scaled); + // Full jitter: random in [initial, ceiling]. Without this, N conns failing simultaneously + // all compute the same backoff and retry in lockstep, hammering the endpoint in waves + // (worst case: every maxBackoff seconds the entire fleet retries together). Jitter spreads + // the retry load continuously over the backoff window. + if (ceiling <= initial) { + return ceiling; } - return Math.max(initial, (long) scaled); + return ThreadLocalRandom.current().nextLong(initial, ceiling + 1); } enum Classification { RETRYABLE, TERMINAL } From 3e2fc61222eeeb112a4542d0a6d0011b0ffc6622 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Mon, 4 May 2026 02:18:54 -0400 Subject: [PATCH 08/10] add tests for jitter --- .../sagemaker/SageMakerTransport.java | 41 +++++++--- .../SageMakerTransportRetryTest.java | 75 +++++++++++++++++++ 2 files changed, 106 insertions(+), 10 deletions(-) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java index 2bc82c2..efcbbdf 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransport.java @@ -351,18 +351,39 @@ private static String summarize(Throwable t) { } private long computeBackoff(int attempt) { - long initial = config.initialBackoff().toMillis(); - long max = config.maxBackoff().toMillis(); - double scaled = initial * Math.pow(config.backoffMultiplier(), attempt); - long ceiling = (scaled > max || Double.isInfinite(scaled)) ? max : Math.max(initial, (long) scaled); - // Full jitter: random in [initial, ceiling]. Without this, N conns failing simultaneously - // all compute the same backoff and retry in lockstep, hammering the endpoint in waves - // (worst case: every maxBackoff seconds the entire fleet retries together). Jitter spreads - // the retry load continuously over the backoff window. - if (ceiling <= initial) { + // Lambda re-resolves ThreadLocalRandom.current() at each invocation, so each calling thread + // gets its own per-thread RNG state — safe even if the LongBinaryOperator is cached and + // invoked from a different thread later (which the current call path doesn't do, but this + // is defensive). Method-reference form `ThreadLocalRandom.current()::nextLong` would + // capture whichever thread called computeBackoff first and wedge it as the RNG source. + return computeBackoff( + config.initialBackoff().toMillis(), + config.maxBackoff().toMillis(), + config.backoffMultiplier(), + attempt, + (origin, bound) -> ThreadLocalRandom.current().nextLong(origin, bound)); + } + + /** + * Pure-function backoff calculator with full jitter. Package-private + static for testability. + * + *

Without jitter, N conns failing simultaneously all compute the same exponential backoff + * and retry in lockstep, hammering the endpoint in waves (worst case: every {@code maxMs} the + * entire fleet retries together). Full jitter — random uniform in {@code [initialMs, ceiling]} — + * spreads the retry load continuously over the backoff window. See AWS Architecture Blog + * "Exponential Backoff and Jitter". + * + * @param randomLong injected RNG: {@code (originInclusive, boundExclusive) -> long}, allowing + * deterministic tests via a stub. + */ + static long computeBackoff(long initialMs, long maxMs, double multiplier, int attempt, + java.util.function.LongBinaryOperator randomLong) { + double scaled = initialMs * Math.pow(multiplier, attempt); + long ceiling = (scaled > maxMs || Double.isInfinite(scaled)) ? maxMs : Math.max(initialMs, (long) scaled); + if (ceiling <= initialMs) { return ceiling; } - return ThreadLocalRandom.current().nextLong(initial, ceiling + 1); + return randomLong.applyAsLong(initialMs, ceiling + 1); } enum Classification { RETRYABLE, TERMINAL } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java index 4e6b2bf..e956394 100644 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportRetryTest.java @@ -52,6 +52,13 @@ public void onComplete() {} }; } + // RNG stubs for backoff tests — must live on outer class because @Nested inner classes can't + // have static members on Java 11. + private static final java.util.function.LongBinaryOperator MAX_RNG = (origin, bound) -> bound - 1; + private static final java.util.function.LongBinaryOperator MIN_RNG = (origin, bound) -> origin; + private static final java.util.function.LongBinaryOperator MID_RNG = + (origin, bound) -> origin + (bound - origin) / 2; + private static class CapturingSubscriber implements Subscriber { final List received = new ArrayList<>(); final CountDownLatch completed = new CountDownLatch(1); @@ -400,4 +407,72 @@ void oversizedEventCantStick() { assertEquals(0L, t.replayBufferBytes()); } } + + @Nested + @DisplayName("computeBackoff(initial, max, multiplier, attempt) — full jitter") + class ComputeBackoffTests { + @Test + @DisplayName("ceiling grows exponentially up to max; range is [initial, ceiling] inclusive") + void exponentialCeilingGrowth() { + // initial=100, multiplier=2 → ceilings 100, 200, 400, 800, 1600, capped at 1000. + // attempt=0: ceiling=initial=100 → degenerate range, returns 100 without RNG. + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 0, MID_RNG)); + // attempt=1: ceiling=200. Range=[100,201). MID_RNG returns origin + (bound-origin)/2 = 100 + 50 = 150. + assertEquals(150L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 1, MID_RNG)); + // attempt=4: scaled=1600, capped to ceiling=1000. Range=[100,1001). MID_RNG returns 100 + 450 = 550. + assertEquals(550L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 4, MID_RNG)); + } + + @Test + @DisplayName("MIN_RNG returns the initial floor, MAX_RNG returns the ceiling") + void rngBoundsRespected() { + // attempt=2 with initial=100, mult=2: scaled=400, ceiling=400. Range=[100,401). + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 2, MIN_RNG)); + assertEquals(400L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 2, MAX_RNG)); + } + + @Test + @DisplayName("ceiling caps at max regardless of attempt") + void ceilingCappedAtMax() { + // High attempt count would overflow without the cap. + assertEquals(5000L, SageMakerTransport.computeBackoff(100, 5000, 2.0, 100, MAX_RNG)); + // Even infinity scaling caps cleanly. + assertEquals(5000L, SageMakerTransport.computeBackoff(100, 5000, 2.0, 10_000, MAX_RNG)); + } + + @Test + @DisplayName("when ceiling == initial (attempt=0 or multiplier degenerate), returns ceiling without invoking RNG") + void degenerateRangeReturnsCeiling() { + java.util.concurrent.atomic.AtomicInteger rngCalls = new java.util.concurrent.atomic.AtomicInteger(0); + java.util.function.LongBinaryOperator countingRng = (o, b) -> { rngCalls.incrementAndGet(); return o; }; + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 2.0, 0, countingRng)); + // multiplier=1.0 means scaled never grows beyond initial. + assertEquals(100L, SageMakerTransport.computeBackoff(100, 1000, 1.0, 5, countingRng)); + assertEquals(0, rngCalls.get(), "RNG must not be invoked when range collapses to a single value"); + } + + @Test + @DisplayName("with real ThreadLocalRandom: 1000 samples spread continuously across [initial, ceiling]") + void productionRngSpreadsRetries() { + // The whole point of this fix: in production, N concurrent retries should NOT cluster + // at the same ceiling value. Sample many times and assert the spread is meaningful. + int trials = 1000; + long min = Long.MAX_VALUE, max = Long.MIN_VALUE; + long sum = 0; + for (int i = 0; i < trials; i++) { + long b = SageMakerTransport.computeBackoff( + 100, 1000, 2.0, /*attempt*/ 4, + java.util.concurrent.ThreadLocalRandom.current()::nextLong); + min = Math.min(min, b); + max = Math.max(max, b); + sum += b; + } + // attempt=4 → ceiling=1000, range=[100,1000]. Expected spread is large. + assertTrue(min < 200, "min sample should land near initial floor; got " + min); + assertTrue(max > 900, "max sample should land near ceiling; got " + max); + long mean = sum / trials; + assertTrue(mean > 400 && mean < 700, + "mean of uniform [100,1000] should be near 550; got " + mean); + } + } } From 87bad38a00b5a429f9a9e9cdd94dc8ae201f94a5 Mon Sep 17 00:00:00 2001 From: Victor Wang Date: Mon, 4 May 2026 13:51:16 -0400 Subject: [PATCH 09/10] Allow option tuning of health check ping and netty event loop threads --- .../deepgram/sagemaker/SageMakerConfig.java | 84 +++++++++++++++++++ .../sagemaker/SageMakerTransportFactory.java | 35 +++++--- .../SageMakerTransportFactoryTest.java | 52 ++++++++++++ 3 files changed, 158 insertions(+), 13 deletions(-) diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java index a246e78..42329f5 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerConfig.java @@ -64,6 +64,38 @@ public class SageMakerConfig { */ public static final long DEFAULT_MAX_REPLAY_BUFFER_BYTES = 8L * 1024 * 1024; + /** + * Max concurrent HTTP/2 streams multiplexed onto a single underlying TCP connection. Defaults + * to 1, which gives each bidi stream its own dedicated TCP connection — preventing slow-stream + * starvation but creating one HTTP/2 keep-alive ping cycle per logical stream. Under heavy + * concurrent load (hundreds of simultaneous streams from one process), the resulting flood of + * pings can saturate the Netty event-loop pool and trigger spurious {@code PingFailedException} + * connection teardowns. Raise this (e.g. 50–200) to multiplex many streams onto fewer + * connections and slash the ping load. + */ + public static final long DEFAULT_MAX_STREAMS_PER_CONNECTION = 1L; + + /** + * Number of Netty event-loop worker threads handling HTTP/2 frames for the shared client. + * {@code null} (the default) lets the AWS SDK Netty client pick — currently {@code 2 * NCPU}. + * Override when running large numbers of concurrent streams on hardware where the default + * leaves event loops saturated by inbound transcript frames + ping ACK bookkeeping. + */ + public static final Integer DEFAULT_NETTY_EVENT_LOOP_THREADS = null; + + /** + * Period between HTTP/2 keep-alive PING frames sent to the server, and the timeout for the + * PING ACK. {@code null} (the default) leaves the AWS SDK Netty client default in place + * (currently 5 s). Under heavy single-process load (hundreds of concurrent streams + * sharing a small event-loop pool), 5 s is too tight: an event loop briefly busy processing + * inbound transcript frames can fail to read the PING ACK in time, causing + * {@code PingFailedException} → connection-death → cascading retries even though the + * underlying connection is healthy. Bump to 30 s+ to tolerate moderate event-loop + * stalls; pass {@code Duration.ZERO} to disable PING frames entirely (the SDK's own retry + * + stream-error path will still detect genuinely-dead connections). + */ + public static final Duration DEFAULT_HEALTH_CHECK_PING_PERIOD = null; + private final String endpointName; private final Region region; private final String contentType; @@ -78,6 +110,9 @@ public class SageMakerConfig { private final double backoffMultiplier; private final Duration retryBudget; private final long maxReplayBufferBytes; + private final long maxStreamsPerConnection; + private final Integer nettyEventLoopThreads; + private final Duration healthCheckPingPeriod; private SageMakerConfig(Builder builder) { this.endpointName = builder.endpointName; @@ -94,6 +129,9 @@ private SageMakerConfig(Builder builder) { this.backoffMultiplier = builder.backoffMultiplier; this.retryBudget = builder.retryBudget; this.maxReplayBufferBytes = builder.maxReplayBufferBytes; + this.maxStreamsPerConnection = builder.maxStreamsPerConnection; + this.nettyEventLoopThreads = builder.nettyEventLoopThreads; + this.healthCheckPingPeriod = builder.healthCheckPingPeriod; } public String endpointName() { return endpointName; } @@ -110,6 +148,9 @@ private SageMakerConfig(Builder builder) { public double backoffMultiplier() { return backoffMultiplier; } public Duration retryBudget() { return retryBudget; } public long maxReplayBufferBytes() { return maxReplayBufferBytes; } + public long maxStreamsPerConnection() { return maxStreamsPerConnection; } + public Integer nettyEventLoopThreads() { return nettyEventLoopThreads; } + public Duration healthCheckPingPeriod() { return healthCheckPingPeriod; } public static Builder builder() { return new Builder(); @@ -130,6 +171,9 @@ public static class Builder { private double backoffMultiplier = DEFAULT_BACKOFF_MULTIPLIER; private Duration retryBudget = DEFAULT_RETRY_BUDGET; private long maxReplayBufferBytes = DEFAULT_MAX_REPLAY_BUFFER_BYTES; + private long maxStreamsPerConnection = DEFAULT_MAX_STREAMS_PER_CONNECTION; + private Integer nettyEventLoopThreads = DEFAULT_NETTY_EVENT_LOOP_THREADS; + private Duration healthCheckPingPeriod = DEFAULT_HEALTH_CHECK_PING_PERIOD; public Builder endpointName(String endpointName) { this.endpointName = endpointName; @@ -272,6 +316,46 @@ public Builder maxReplayBufferBytes(long maxReplayBufferBytes) { return this; } + /** + * Max concurrent HTTP/2 streams per underlying TCP connection. See + * {@link #DEFAULT_MAX_STREAMS_PER_CONNECTION}. Raise above 1 to multiplex many bidi + * streams onto fewer connections (slashes ping load); leave at 1 for one-stream-per-TCP + * isolation. + */ + public Builder maxStreamsPerConnection(long maxStreamsPerConnection) { + if (maxStreamsPerConnection <= 0) { + throw new IllegalArgumentException("maxStreamsPerConnection must be positive"); + } + this.maxStreamsPerConnection = maxStreamsPerConnection; + return this; + } + + /** + * Number of Netty event-loop worker threads. {@code null} (default) uses the AWS SDK's + * default ({@code 2 * NCPU}). Override for high-burst single-process workloads where + * the default leaves event loops saturated by inbound frame processing. + */ + public Builder nettyEventLoopThreads(Integer nettyEventLoopThreads) { + if (nettyEventLoopThreads != null && nettyEventLoopThreads <= 0) { + throw new IllegalArgumentException("nettyEventLoopThreads must be positive or null"); + } + this.nettyEventLoopThreads = nettyEventLoopThreads; + return this; + } + + /** + * Period between HTTP/2 keep-alive PING frames (and ACK timeout). {@code null} (default) + * uses the AWS SDK Netty client default (5 s). Pass {@code Duration.ZERO} to disable + * PING frames entirely. See {@link #DEFAULT_HEALTH_CHECK_PING_PERIOD}. + */ + public Builder healthCheckPingPeriod(Duration healthCheckPingPeriod) { + if (healthCheckPingPeriod != null && healthCheckPingPeriod.isNegative()) { + throw new IllegalArgumentException("healthCheckPingPeriod must be non-negative or null"); + } + this.healthCheckPingPeriod = healthCheckPingPeriod; + return this; + } + public SageMakerConfig build() { if (endpointName == null || endpointName.isBlank()) { throw new IllegalArgumentException("endpointName is required"); diff --git a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java index 71e4437..9474865 100755 --- a/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java +++ b/sagemaker-transport/src/main/java/com/deepgram/sagemaker/SageMakerTransportFactory.java @@ -98,10 +98,30 @@ private static String sharedClientKey(SageMakerConfig c) { return c.region().id() + "|" + c.maxConcurrency() + "|" + c.connectionTimeout().toMillis() - + "|" + c.connectionAcquireTimeout().toMillis(); + + "|" + c.connectionAcquireTimeout().toMillis() + + "|" + c.maxStreamsPerConnection() + + "|" + (c.nettyEventLoopThreads() == null ? "default" : c.nettyEventLoopThreads()) + + "|" + (c.healthCheckPingPeriod() == null ? "default" : c.healthCheckPingPeriod().toMillis()); } private static SageMakerRuntimeHttp2AsyncClient buildClient(SageMakerConfig config) { + Http2Configuration.Builder http2Builder = Http2Configuration.builder() + .maxStreams(config.maxStreamsPerConnection()); + if (config.healthCheckPingPeriod() != null) { + http2Builder.healthCheckPingPeriod(config.healthCheckPingPeriod()); + } + NettyNioAsyncHttpClient.Builder httpBuilder = NettyNioAsyncHttpClient.builder() + .protocol(Protocol.HTTP2) + .maxConcurrency(config.maxConcurrency()) + .connectionTimeout(config.connectionTimeout()) + .connectionAcquisitionTimeout(config.connectionAcquireTimeout()) + .http2Configuration(http2Builder.build()); + if (config.nettyEventLoopThreads() != null) { + httpBuilder.eventLoopGroupBuilder( + software.amazon.awssdk.http.nio.netty.SdkEventLoopGroup.builder() + .numberOfThreads(config.nettyEventLoopThreads()) + ); + } return SageMakerRuntimeHttp2AsyncClient.builder() .region(config.region()) // Disable the AWS SDK's internal retry strategy. SageMakerTransport owns the @@ -114,18 +134,7 @@ private static SageMakerRuntimeHttp2AsyncClient buildClient(SageMakerConfig conf // transient TLS / connection-reset hiccups are still caught by our IOException // → RETRYABLE classification and the same backoff path. .overrideConfiguration(c -> c.retryStrategy(AwsRetryStrategy.doNotRetry())) - .httpClientBuilder( - NettyNioAsyncHttpClient.builder() - .protocol(Protocol.HTTP2) - .maxConcurrency(config.maxConcurrency()) - .connectionTimeout(config.connectionTimeout()) - .connectionAcquisitionTimeout(config.connectionAcquireTimeout()) - .http2Configuration( - Http2Configuration.builder() - .maxStreams(1L) - .build() - ) - ) + .httpClientBuilder(httpBuilder) .build(); } diff --git a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java index e9a5871..27047b6 100755 --- a/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java +++ b/sagemaker-transport/src/test/java/com/deepgram/sagemaker/SageMakerTransportFactoryTest.java @@ -303,6 +303,58 @@ void shutdownAllSharedClientsClearsThePool() { SageMakerTransportFactory.shutdownAllSharedClients(); } + @Test + void differentMaxStreamsForcesDistinctSharedClients() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig one = SageMakerConfig.builder() + .endpointName("e").maxStreamsPerConnection(1).build(); + SageMakerConfig hundred = SageMakerConfig.builder() + .endpointName("e").maxStreamsPerConnection(100).build(); + + SageMakerRuntimeHttp2AsyncClient c1 = getSmClient(new SageMakerTransportFactory(one)); + SageMakerRuntimeHttp2AsyncClient c100 = getSmClient(new SageMakerTransportFactory(hundred)); + + assertNotSame(c1, c100, + "different maxStreamsPerConnection must produce distinct shared clients (different cache keys)"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + @Test + void differentNettyEventLoopThreadsForcesDistinctSharedClients() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig dflt = SageMakerConfig.builder() + .endpointName("e").build(); + SageMakerConfig eighty = SageMakerConfig.builder() + .endpointName("e").nettyEventLoopThreads(80).build(); + + SageMakerRuntimeHttp2AsyncClient cDefault = getSmClient(new SageMakerTransportFactory(dflt)); + SageMakerRuntimeHttp2AsyncClient c80 = getSmClient(new SageMakerTransportFactory(eighty)); + + assertNotSame(cDefault, c80, + "different nettyEventLoopThreads must produce distinct shared clients (different cache keys)"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + + @Test + void sameStreamAndEventLoopConfigSharesClient() { + SageMakerTransportFactory.shutdownAllSharedClients(); + SageMakerConfig configA = SageMakerConfig.builder() + .endpointName("e-a").maxStreamsPerConnection(50).nettyEventLoopThreads(40).build(); + SageMakerConfig configB = SageMakerConfig.builder() + .endpointName("e-b").maxStreamsPerConnection(50).nettyEventLoopThreads(40).build(); + + SageMakerRuntimeHttp2AsyncClient cA = getSmClient(new SageMakerTransportFactory(configA)); + SageMakerRuntimeHttp2AsyncClient cB = getSmClient(new SageMakerTransportFactory(configB)); + + // endpointName differs but it's not in the cache key — the underlying Netty client should be shared. + assertSame(cA, cB, + "configs with same Netty-affecting fields share the underlying client even with different endpointName"); + + SageMakerTransportFactory.shutdownAllSharedClients(); + } + /** Reflection helper — read the package-private smClient field set by the constructor. */ private static SageMakerRuntimeHttp2AsyncClient getSmClient(SageMakerTransportFactory f) { try { From 7fc22a9d952beac00f241c69363184ae589fc452 Mon Sep 17 00:00:00 2001 From: Greg Holmes Date: Wed, 6 May 2026 13:56:12 +0100 Subject: [PATCH 10/10] Update Deepgram SDK dependency to 0.4.0 Align the transport metadata and Flux examples with the latest typed Listen V2 model API so the library continues to build cleanly against the current SDK release. --- README.md | 4 ++-- .../java/com/deepgram/examples/FluxSageMakerExample.java | 5 +++-- .../com/deepgram/examples/LiveMicFluxSageMakerExample.java | 3 ++- pom.xml | 2 +- sagemaker-transport/build.gradle | 2 +- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 95e1126..d05b530 100755 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ SageMaker transport for the [Deepgram Java SDK](https://github.com/deepgram/deep ```groovy dependencies { - implementation 'com.deepgram:deepgram-java-sdk:0.3.0' + implementation 'com.deepgram:deepgram-java-sdk:0.4.0' implementation 'com.deepgram:deepgram-sagemaker:0.1.2' // x-release-please-version } ``` @@ -29,7 +29,7 @@ dependencies { ## Requirements - Java 11+ -- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.3.0+ (the `default ReconnectOptions reconnectOptions()` hook on `DeepgramTransportFactory` is required for storm absorption) +- [Deepgram Java SDK](https://github.com/deepgram/deepgram-java-sdk) v0.4.0+ (the `default ReconnectOptions reconnectOptions()` hook on `DeepgramTransportFactory` is required for storm absorption) - AWS credentials configured (environment variables, shared credentials file, or IAM role) - A Deepgram model deployed to an AWS SageMaker endpoint diff --git a/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java b/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java index 89f3d55..ee2dacf 100644 --- a/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java +++ b/examples/src/main/java/com/deepgram/examples/FluxSageMakerExample.java @@ -8,6 +8,7 @@ import com.deepgram.resources.listen.v2.websocket.V2WebSocketClient; import com.deepgram.sagemaker.SageMakerConfig; import com.deepgram.sagemaker.SageMakerTransportFactory; +import com.deepgram.types.ListenV2Model; import java.io.RandomAccessFile; import java.nio.ByteBuffer; @@ -85,10 +86,10 @@ public static void main(String[] args) throws Exception { done.countDown(); }); - // Connect — V2 uses model name as string via additionalProperty + // Connect using the typed Flux model constant from the SDK. CompletableFuture connectFuture = wsClient.connect( V2ConnectOptions.builder() - .model("flux-general-en") + .model(ListenV2Model.FLUX_GENERAL_EN) .build()); connectFuture.get(30, TimeUnit.SECONDS); System.out.println("Connected. Streaming audio...\n"); diff --git a/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java b/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java index b4e7d36..8c05588 100644 --- a/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java +++ b/examples/src/main/java/com/deepgram/examples/LiveMicFluxSageMakerExample.java @@ -9,6 +9,7 @@ import com.deepgram.sagemaker.SageMakerConfig; import com.deepgram.sagemaker.SageMakerTransportFactory; import com.deepgram.types.ListenV2Encoding; +import com.deepgram.types.ListenV2Model; import com.deepgram.types.ListenV2SampleRate; import javax.sound.sampled.AudioFormat; @@ -100,7 +101,7 @@ public static void main(String[] args) throws Exception { CompletableFuture connectFuture = wsClient.connect( V2ConnectOptions.builder() - .model("flux-general-en") + .model(ListenV2Model.FLUX_GENERAL_EN) .encoding(ListenV2Encoding.LINEAR16) .sampleRate(ListenV2SampleRate.of(16000)) .build()); diff --git a/pom.xml b/pom.xml index cfbadb2..1b63052 100644 --- a/pom.xml +++ b/pom.xml @@ -69,7 +69,7 @@ com.deepgram deepgram-java-sdk - 0.2.1 + 0.4.0 software.amazon.awssdk diff --git a/sagemaker-transport/build.gradle b/sagemaker-transport/build.gradle index 544aa30..f7a8536 100755 --- a/sagemaker-transport/build.gradle +++ b/sagemaker-transport/build.gradle @@ -1,6 +1,6 @@ dependencies { // Deepgram Java SDK — provides DeepgramTransport / DeepgramTransportFactory interfaces - api 'com.deepgram:deepgram-java-sdk:0.3.0' + api 'com.deepgram:deepgram-java-sdk:0.4.0' // AWS SDK v2 — SageMaker Runtime HTTP/2 bidirectional streaming api platform('software.amazon.awssdk:bom:2.42.0')