diff --git a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java index 6d4048076a..00974e24d0 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java @@ -547,6 +547,19 @@ public double getCloudFetchSpeedThreshold() { return Double.parseDouble(getParameter(DatabricksJdbcUrlParams.CLOUD_FETCH_SPEED_THRESHOLD)); } + /** Fraction of the JVM max heap used as the default CloudFetch in-memory budget. */ + private static final double CLOUD_FETCH_HEAP_FRACTION = 0.2; + + @Override + public long getCloudFetchMaxBytesInMemory() { + long configured = + Long.parseLong(getParameter(DatabricksJdbcUrlParams.CLOUD_FETCH_MAX_BYTES_IN_MEMORY)); + if (configured > 0) { + return configured; + } + return (long) (Runtime.getRuntime().maxMemory() * CLOUD_FETCH_HEAP_FRACTION); + } + @Override public String getCatalog() { return getParameter(DatabricksJdbcUrlParams.CONN_CATALOG); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java index 17fa6f2531..22f1344f98 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractArrowResultChunk.java @@ -43,6 +43,7 @@ public abstract class AbstractArrowResultChunk { protected final long numRows; protected final long rowOffset; protected final long chunkIndex; + protected final long chunkSizeInBytes; protected final StatementId statementId; protected final BufferAllocator rootAllocator; @@ -89,6 +90,7 @@ protected AbstractArrowResultChunk( long numRows, long rowOffset, long chunkIndex, + long chunkSizeInBytes, StatementId statementId, ChunkStatus initialStatus, ExternalLink chunkLink, @@ -97,6 +99,7 @@ protected AbstractArrowResultChunk( this.numRows = numRows; this.rowOffset = rowOffset; this.chunkIndex = chunkIndex; + this.chunkSizeInBytes = chunkSizeInBytes; this.statementId = statementId; this.rootAllocator = ArrowBufferAllocator.getBufferAllocator(); this.chunkReadyFuture = new CompletableFuture<>(); @@ -115,6 +118,15 @@ public Long getChunkIndex() { return chunkIndex; } + /** + * Returns the size of this chunk in bytes, as reported by the result manifest, or 0 if unknown. + * + * @return chunk size in bytes + */ + public long getChunkSizeInBytes() { + return chunkSizeInBytes; + } + /** * Returns the start row offset of this chunk in the overall result set. * diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractRemoteChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractRemoteChunkProvider.java index 4717852f48..5a813a29ca 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractRemoteChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/AbstractRemoteChunkProvider.java @@ -51,6 +51,8 @@ public abstract class AbstractRemoteChunkProvider ch } } + /** + * Conservative per-chunk byte size charged against the in-memory budget when the result manifest + * does not report a chunk's size (SEA {@code getByteCount()} null or Thrift {@code getBytesNum()} + * unset, surfacing as {@link AbstractArrowResultChunk#getChunkSizeInBytes()} == 0). Without this + * fallback the budget accounting would never grow for size-less chunks, silently disabling the + * byte budget and degrading to the count limit alone — exactly on the large-result workloads this + * budget is meant to protect. + */ + static final long UNKNOWN_CHUNK_SIZE_ESTIMATE_BYTES = 16 * 1024 * 1024L; + + /** + * Returns the byte cost charged to a chunk for in-memory budgeting: the manifest-reported size + * when known, otherwise {@link #UNKNOWN_CHUNK_SIZE_ESTIMATE_BYTES}. The same value must be used + * when reserving budget (scheduling) and releasing it (consumption) so accounting stays balanced. + */ + protected long effectiveChunkSizeInBytes(long declaredChunkSizeInBytes) { + return declaredChunkSizeInBytes > 0 + ? declaredChunkSizeInBytes + : UNKNOWN_CHUNK_SIZE_ESTIMATE_BYTES; + } + /** Release the memory for previous chunk since it is already consumed */ private void releaseChunk() throws DatabricksSQLException { - if (chunkIndexToChunksMap.get(currentChunkIndex).releaseChunk()) { + T chunk = chunkIndexToChunksMap.get(currentChunkIndex); + if (chunk.releaseChunk()) { totalChunksInMemory--; + totalBytesInMemory -= effectiveChunkSizeInBytes(chunk.getChunkSizeInBytes()); downloadNextChunks(); } } + + /** + * Returns whether a chunk of the given (effective) size can be scheduled for download without + * breaching the count or in-memory byte budgets. Callers must pass the value from {@link + * #effectiveChunkSizeInBytes(long)} so that size-less chunks still consume budget. At least one + * chunk is always allowed so a single oversized chunk cannot stall consumption. + */ + protected boolean canScheduleChunkDownload(long chunkSizeInBytes) { + if (totalChunksInMemory >= allowedChunksInMemory) { + return false; + } + if (maxBytesInMemory <= 0 || totalChunksInMemory == 0) { + return true; + } + return totalBytesInMemory + chunkSizeInBytes <= maxBytesInMemory; + } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java index 67c6e57c70..709e4e5b7d 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunk.java @@ -36,6 +36,7 @@ private ArrowResultChunk(Builder builder) throws DatabricksParsingException { builder.numRows, builder.rowOffset, builder.chunkIndex, + builder.chunkSizeInBytes, builder.statementId, builder.status, builder.chunkLink, @@ -101,15 +102,16 @@ protected void downloadData( chunkLink.getExternalLink(), speedThreshold); - // Decompress (if needed) and parse + // Decompress and parse. The decompression is streamed straight into the Arrow reader so the + // full decompressed payload is never materialized on-heap alongside the compressed bytes. long decompressStart = System.nanoTime(); try { String ctx = String.format( "Data decompression for chunk index [%d] and statement [%s]", this.chunkIndex, this.statementId); - InputStream data = DecompressionUtil.decompressToStream(compressed, compressionCodec, ctx); - initializeData(data); + initializeData( + DecompressionUtil.decompressToInputStream(compressed, compressionCodec, ctx)); } catch (Exception e) { handleFailure(e, ChunkStatus.PROCESSING_FAILED); } @@ -193,6 +195,7 @@ public static class Builder { private long chunkIndex; private long numRows; private long rowOffset; + private long chunkSizeInBytes; private ExternalLink chunkLink; private StatementId statementId; private Instant expiryTime; @@ -216,6 +219,8 @@ public Builder withChunkInfo(BaseChunkInfo baseChunkInfo) { this.chunkIndex = baseChunkInfo.getChunkIndex(); this.numRows = baseChunkInfo.getRowCount(); this.rowOffset = baseChunkInfo.getRowOffset(); + this.chunkSizeInBytes = + baseChunkInfo.getByteCount() != null ? baseChunkInfo.getByteCount() : 0L; this.status = status == null ? ChunkStatus.PENDING : status; return this; } @@ -248,6 +253,7 @@ public Builder withThriftChunkInfo(long chunkIndex, TSparkArrowResultLink chunkI this.chunkIndex = chunkIndex; this.numRows = chunkInfo.getRowCount(); this.rowOffset = chunkInfo.getStartRowOffset(); + this.chunkSizeInBytes = chunkInfo.getBytesNum(); this.expiryTime = Instant.ofEpochMilli(chunkInfo.getExpiryTime()); this.status = status == null diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProvider.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProvider.java index c9eefa9b8e..e3c2f3f9e1 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProvider.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProvider.java @@ -100,7 +100,8 @@ protected ArrowResultChunk createChunk( * *
  • Tracks the total chunks in memory and the next chunk to download * @@ -115,13 +116,19 @@ public void downloadNextChunks() { chunkDownloaderExecutorService = createChunksDownloaderExecutorService(); } - while (!isClosed - && nextChunkToDownload < chunkCount - && totalChunksInMemory < allowedChunksInMemory) { + while (!isClosed && nextChunkToDownload < chunkCount) { ArrowResultChunk chunk = chunkIndexToChunksMap.get(nextChunkToDownload); + long chunkSizeInBytes = effectiveChunkSizeInBytes(chunk.getChunkSizeInBytes()); + if (!canScheduleChunkDownload(chunkSizeInBytes)) { + // Budget is full; leave nextChunkToDownload unadvanced so this chunk is retried the next + // time downloadNextChunks() runs (invoked from releaseChunk() once a consumed chunk frees + // budget). The always-allow-one rule in canScheduleChunkDownload guarantees progress. + break; + } chunkDownloaderExecutorService.submit( new ChunkDownloadTask(chunk, httpClient, this, linkDownloadService)); totalChunksInMemory++; + totalBytesInMemory += chunkSizeInBytes; nextChunkToDownload++; } } diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/ArrowResultChunkV2.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/ArrowResultChunkV2.java index fc7052cb7e..4d98b6856a 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/ArrowResultChunkV2.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/ArrowResultChunkV2.java @@ -107,6 +107,7 @@ private ArrowResultChunkV2(Builder builder) { builder.numRows, builder.rowOffset, builder.chunkIndex, + builder.chunkSizeInBytes, builder.statementId, builder.status, builder.chunkLink, @@ -270,6 +271,7 @@ public static class Builder { private long chunkIndex; private long numRows; private long rowOffset; + private long chunkSizeInBytes; private ExternalLink chunkLink; private StatementId statementId; private Instant expiryTime; @@ -286,6 +288,8 @@ public Builder withChunkInfo(BaseChunkInfo baseChunkInfo) { this.chunkIndex = baseChunkInfo.getChunkIndex(); this.numRows = baseChunkInfo.getRowCount(); this.rowOffset = baseChunkInfo.getRowOffset(); + this.chunkSizeInBytes = + baseChunkInfo.getByteCount() != null ? baseChunkInfo.getByteCount() : 0L; this.status = ChunkStatus.PENDING; return this; } @@ -294,6 +298,7 @@ public Builder withThriftChunkInfo(long chunkIndex, TSparkArrowResultLink chunkI this.chunkIndex = chunkIndex; this.numRows = chunkInfo.getRowCount(); this.rowOffset = chunkInfo.getStartRowOffset(); + this.chunkSizeInBytes = chunkInfo.getBytesNum(); this.expiryTime = Instant.ofEpochMilli(chunkInfo.getExpiryTime()); this.status = ChunkStatus.URL_FETCHED; // URL has always been fetched in case of thrift this.chunkLink = createExternalLink(chunkInfo, chunkIndex); diff --git a/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/RemoteChunkProviderV2.java b/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/RemoteChunkProviderV2.java index 848b9be611..daa88b4459 100644 --- a/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/RemoteChunkProviderV2.java +++ b/src/main/java/com/databricks/jdbc/api/impl/arrow/incubator/RemoteChunkProviderV2.java @@ -111,11 +111,17 @@ protected ArrowResultChunkV2 createChunk( */ @Override public void downloadNextChunks() throws DatabricksSQLException { - while (!isClosed - && nextChunkToDownload < chunkCount - && totalChunksInMemory < allowedChunksInMemory) { + while (!isClosed && nextChunkToDownload < chunkCount) { ArrowResultChunkV2 chunk = chunkIndexToChunksMap.get(nextChunkToDownload); + long chunkSizeInBytes = effectiveChunkSizeInBytes(chunk.getChunkSizeInBytes()); + if (!canScheduleChunkDownload(chunkSizeInBytes)) { + // Budget is full; leave nextChunkToDownload unadvanced so this chunk is retried the next + // time downloadNextChunks() runs (invoked from releaseChunk() once a consumed chunk frees + // budget). The always-allow-one rule in canScheduleChunkDownload guarantees progress. + break; + } totalChunksInMemory++; + totalBytesInMemory += chunkSizeInBytes; if (chunk.isChunkLinkInvalid()) { try { ExternalLink link = diff --git a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java index 8ca5ae0cd2..52b532717b 100644 --- a/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java +++ b/src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java @@ -163,6 +163,15 @@ public interface IDatabricksConnectionContext { /** Returns the minimum expected download speed threshold in MB/s for CloudFetch operations */ double getCloudFetchSpeedThreshold(); + /** + * Returns the per-result-set budget, in bytes, for result chunks buffered in memory at once + * during CloudFetch downloads. The budget is compared against the compressed chunk sizes reported + * by the result manifest, so the default (derived from the JVM max heap when the configured value + * is non-positive) is intentionally conservative. Bounds peak memory in addition to the chunk + * download thread-pool limit. + */ + long getCloudFetchMaxBytesInMemory(); + Boolean getDirectResultMode(); Boolean shouldRetryTemporarilyUnavailableError(); diff --git a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java index 639288248d..86295bbb16 100644 --- a/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java +++ b/src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java @@ -74,6 +74,10 @@ public enum DatabricksJdbcUrlParams { IDLE_HTTP_CONNECTION_EXPIRY("IdleHttpConnectionExpiry", "Idle HTTP connection expiry", "60"), SUPPORT_MANY_PARAMETERS("supportManyParameters", "Support many parameters", "0"), CLOUD_FETCH_THREAD_POOL_SIZE("cloudFetchThreadPoolSize", "Cloud fetch thread pool size", "16"), + CLOUD_FETCH_MAX_BYTES_IN_MEMORY( + "cloudFetchMaxBytesInMemory", + "Maximum bytes of result chunks buffered in memory at once (0 = derive from heap)", + "0"), OAUTH_ENDPOINT("OAuth2ConnAuthAuthorizeEndpoint", "OAuth2 authorization endpoint"), AUTH_ENDPOINT( "OAuth2AuthorizationEndPoint", "OAuth2 authorization endpoint"), // Same as OAUTH_ENDPOINT diff --git a/src/main/java/com/databricks/jdbc/common/util/DecompressionUtil.java b/src/main/java/com/databricks/jdbc/common/util/DecompressionUtil.java index 213f15a635..63316261be 100644 --- a/src/main/java/com/databricks/jdbc/common/util/DecompressionUtil.java +++ b/src/main/java/com/databricks/jdbc/common/util/DecompressionUtil.java @@ -60,6 +60,35 @@ public static InputStream decompressToStream( return new ByteArrayInputStream(uncompressed); } + /** + * Returns a stream that decompresses {@code compressedInput} lazily as it is read, so the full + * decompressed payload is never materialized alongside the compressed bytes. + */ + public static InputStream decompressToInputStream( + byte[] compressedInput, CompressionCodec compressionCodec, String context) + throws DatabricksSQLException { + if (compressionCodec == null + || compressedInput == null + || compressionCodec == CompressionCodec.NONE) { + return new ByteArrayInputStream(compressedInput); + } + if (compressionCodec == CompressionCodec.LZ4_FRAME) { + try { + return new LZ4FrameInputStream(new ByteArrayInputStream(compressedInput)); + } catch (IOException e) { + String errorMessage = + String.format("Unable to de-compress LZ4 Frame compressed result %s", context); + LOGGER.error(e, errorMessage); + throw new DatabricksParsingException( + errorMessage, e, DatabricksDriverErrorCode.DECOMPRESSION_ERROR); + } + } + String errorMessage = + String.format("Unknown compression type: %s. Context : %s", compressionCodec, context); + LOGGER.error(errorMessage); + throw new DatabricksSQLException(errorMessage, DatabricksDriverErrorCode.DECOMPRESSION_ERROR); + } + public static InputStream decompress( InputStream compressedStream, CompressionCodec compressionCodec, String context) throws IOException, DatabricksSQLException { diff --git a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java index 6e35930353..b590875b79 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java @@ -392,6 +392,21 @@ public void testEnableCloudFetch() throws DatabricksSQLException { assertTrue(connectionContext.shouldEnableArrow()); } + @Test + public void testCloudFetchMaxBytesInMemory() throws DatabricksSQLException { + // Explicit configured value is honored + IDatabricksConnectionContext ctx = + DatabricksConnectionContext.parse( + TestConstants.VALID_URL_1 + ";cloudFetchMaxBytesInMemory=12345", properties); + assertEquals(12345L, ctx.getCloudFetchMaxBytesInMemory()); + + // Default (0) derives a positive budget bounded by the JVM max heap + IDatabricksConnectionContext defaultCtx = + DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, properties); + assertTrue(defaultCtx.getCloudFetchMaxBytesInMemory() > 0); + assertTrue(defaultCtx.getCloudFetchMaxBytesInMemory() <= Runtime.getRuntime().maxMemory()); + } + @Test public void testShouldEnableArrow_defaultIsTrue() throws DatabricksSQLException { // On non-AIX, Arrow is always enabled regardless of EnableArrow setting diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunkTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunkTest.java index 38db00f44b..db1f830a29 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunkTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/ArrowResultChunkTest.java @@ -104,6 +104,44 @@ public void testGetArrowDataFromThriftInput() throws DatabricksParsingException assertEquals(arrowResultChunk.getChunkIndex(), 0); } + @Test + public void testChunkSizeInBytesIsCaptured() throws DatabricksParsingException { + BaseChunkInfo seaChunkInfo = + new BaseChunkInfo().setChunkIndex(0L).setByteCount(4096L).setRowOffset(0L).setRowCount(1L); + assertEquals( + 4096L, + ArrowResultChunk.builder() + .withStatementId(TEST_STATEMENT_ID) + .withChunkInfo(seaChunkInfo) + .build() + .getChunkSizeInBytes()); + + // A null byte count from the manifest is treated as unknown (0). + BaseChunkInfo seaChunkInfoNoBytes = + new BaseChunkInfo().setChunkIndex(0L).setRowOffset(0L).setRowCount(1L); + assertEquals( + 0L, + ArrowResultChunk.builder() + .withStatementId(TEST_STATEMENT_ID) + .withChunkInfo(seaChunkInfoNoBytes) + .build() + .getChunkSizeInBytes()); + + TSparkArrowResultLink thriftChunkInfo = + new TSparkArrowResultLink() + .setRowCount(1L) + .setFileLink(TEST_STRING) + .setExpiryTime(1000) + .setBytesNum(8192L); + assertEquals( + 8192L, + ArrowResultChunk.builder() + .withStatementId(TEST_STATEMENT_ID) + .withThriftChunkInfo(0, thriftChunkInfo) + .build() + .getChunkSizeInBytes()); + } + private File createTestArrowFile( String fileName, Schema schema, Object[][] testData, RootAllocator allocator) throws IOException { diff --git a/src/test/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProviderTest.java b/src/test/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProviderTest.java index 21cbad1f99..9f64ccc990 100644 --- a/src/test/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProviderTest.java +++ b/src/test/java/com/databricks/jdbc/api/impl/arrow/RemoteChunkProviderTest.java @@ -1,20 +1,25 @@ package com.databricks.jdbc.api.impl.arrow; import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import com.databricks.jdbc.api.internal.IDatabricksConnectionContext; import com.databricks.jdbc.api.internal.IDatabricksSession; +import com.databricks.jdbc.common.CompressionCodec; import com.databricks.jdbc.common.DatabricksClientType; +import com.databricks.jdbc.dbclient.IDatabricksHttpClient; import com.databricks.jdbc.dbclient.impl.common.StatementId; +import com.databricks.jdbc.model.core.ExternalLink; import com.databricks.jdbc.model.core.ResultData; import com.databricks.jdbc.model.core.ResultManifest; import com.databricks.jdbc.model.core.ResultSchema; +import com.databricks.sdk.service.sql.BaseChunkInfo; import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.MockedConstruction; import org.mockito.junit.jupiter.MockitoExtension; @ExtendWith(MockitoExtension.class) @@ -38,4 +43,178 @@ public void testInitEmptyChunkDownloader() { new RemoteChunkProvider( STATEMENT_ID, resultManifest, resultData, mockSession, null, 4)); } + + @Test + public void testByteBudgetLimitsScheduledChunks() + throws com.databricks.jdbc.exception.DatabricksSQLException { + int chunkCount = 5; + long chunkBytes = 10L; + long budget = 25L; // 10 + 10 <= 25, but a third chunk (30) exceeds the budget + List chunks = new ArrayList<>(); + List links = new ArrayList<>(); + for (int i = 0; i < chunkCount; i++) { + chunks.add( + new BaseChunkInfo() + .setChunkIndex((long) i) + .setRowCount(10L) + .setRowOffset(i * 10L) + .setByteCount(chunkBytes)); + links.add( + new ExternalLink() + .setChunkIndex((long) i) + .setExpiration("2099-01-01T00:00:00Z") + .setExternalLink("https://test.databricks.com/chunks/" + i)); + } + ResultManifest manifest = + new ResultManifest() + .setTotalChunkCount((long) chunkCount) + .setTotalRowCount(50L) + .setResultCompression(CompressionCodec.NONE) + .setChunks(chunks) + .setSchema(new ResultSchema().setColumns(new ArrayList<>())); + ResultData resultData = new ResultData().setExternalLinks(links); + + IDatabricksConnectionContext ctx = mock(IDatabricksConnectionContext.class); + when(mockSession.getConnectionContext()).thenReturn(ctx); + lenient().when(ctx.getCloudFetchMaxBytesInMemory()).thenReturn(budget); + + try (MockedConstruction ignored = + mockConstruction(ChunkLinkDownloadService.class)) { + RemoteChunkProvider provider = + new RemoteChunkProvider( + STATEMENT_ID, + manifest, + resultData, + mockSession, + mock(IDatabricksHttpClient.class), + 10 /* maxParallelChunkDownloadsPerQuery, larger than the byte budget allows */); + + // Only two chunks fit within the byte budget even though more parallel slots are available. + assertEquals(2, provider.nextChunkToDownload); + assertEquals(2 * chunkBytes, provider.totalBytesInMemory); + provider.close(); + } + } + + @Test + public void testOversizedChunkIsStillScheduled() + throws com.databricks.jdbc.exception.DatabricksSQLException { + // Budget smaller than a single chunk must still schedule one chunk so consumption can progress. + int scheduled = scheduledChunksFor(/* budget= */ 5L, /* chunkBytes= */ 10L, /* chunks= */ 3); + assertEquals(1, scheduled); + } + + @Test + public void testNonPositiveBudgetMeansNoByteLimit() + throws com.databricks.jdbc.exception.DatabricksSQLException { + // A non-positive budget disables the byte gate; only the parallel-count limit applies. + int scheduled = scheduledChunksFor(/* budget= */ 0L, /* chunkBytes= */ 1_000L, /* chunks= */ 3); + assertEquals(3, scheduled); + } + + @Test + public void testUnknownChunkSizeStillConsumesBudget() + throws com.databricks.jdbc.exception.DatabricksSQLException { + // When the manifest does not report chunk sizes (getByteCount() == null → chunkSizeInBytes 0), + // each chunk is charged the unknown-size estimate against the budget so the byte budget stays + // active instead of silently degrading to the count limit. With a budget just above one + // estimate, only a single chunk can be scheduled (the always-allow-one rule), even though many + // parallel slots are free. + long budget = AbstractRemoteChunkProvider.UNKNOWN_CHUNK_SIZE_ESTIMATE_BYTES + 1; + int chunkCount = 5; + List chunks = new ArrayList<>(); + List links = new ArrayList<>(); + for (int i = 0; i < chunkCount; i++) { + chunks.add( + new BaseChunkInfo() + .setChunkIndex((long) i) + .setRowCount(10L) + .setRowOffset(i * 10L)); // no setByteCount → unknown size + links.add( + new ExternalLink() + .setChunkIndex((long) i) + .setExpiration("2099-01-01T00:00:00Z") + .setExternalLink("https://test.databricks.com/chunks/" + i)); + } + ResultManifest manifest = + new ResultManifest() + .setTotalChunkCount((long) chunkCount) + .setTotalRowCount(chunkCount * 10L) + .setResultCompression(CompressionCodec.NONE) + .setChunks(chunks) + .setSchema(new ResultSchema().setColumns(new ArrayList<>())); + ResultData resultData = new ResultData().setExternalLinks(links); + + IDatabricksConnectionContext ctx = mock(IDatabricksConnectionContext.class); + when(mockSession.getConnectionContext()).thenReturn(ctx); + lenient().when(ctx.getCloudFetchMaxBytesInMemory()).thenReturn(budget); + + try (MockedConstruction ignored = + mockConstruction(ChunkLinkDownloadService.class)) { + RemoteChunkProvider provider = + new RemoteChunkProvider( + STATEMENT_ID, + manifest, + resultData, + mockSession, + mock(IDatabricksHttpClient.class), + 10 /* parallel slots far exceed what the byte budget admits */); + + assertEquals(1, provider.nextChunkToDownload); + assertEquals( + AbstractRemoteChunkProvider.UNKNOWN_CHUNK_SIZE_ESTIMATE_BYTES, + provider.totalBytesInMemory); + provider.close(); + } + } + + /** + * Builds a {@link RemoteChunkProvider} with the given byte budget and uniform-sized chunks and + * returns how many chunks were scheduled for download on construction. + */ + private int scheduledChunksFor(long budget, long chunkBytes, int chunkCount) + throws com.databricks.jdbc.exception.DatabricksSQLException { + List chunks = new ArrayList<>(); + List links = new ArrayList<>(); + for (int i = 0; i < chunkCount; i++) { + chunks.add( + new BaseChunkInfo() + .setChunkIndex((long) i) + .setRowCount(10L) + .setRowOffset(i * 10L) + .setByteCount(chunkBytes)); + links.add( + new ExternalLink() + .setChunkIndex((long) i) + .setExpiration("2099-01-01T00:00:00Z") + .setExternalLink("https://test.databricks.com/chunks/" + i)); + } + ResultManifest manifest = + new ResultManifest() + .setTotalChunkCount((long) chunkCount) + .setTotalRowCount(chunkCount * 10L) + .setResultCompression(CompressionCodec.NONE) + .setChunks(chunks) + .setSchema(new ResultSchema().setColumns(new ArrayList<>())); + ResultData resultData = new ResultData().setExternalLinks(links); + + IDatabricksConnectionContext ctx = mock(IDatabricksConnectionContext.class); + when(mockSession.getConnectionContext()).thenReturn(ctx); + lenient().when(ctx.getCloudFetchMaxBytesInMemory()).thenReturn(budget); + + try (MockedConstruction ignored = + mockConstruction(ChunkLinkDownloadService.class)) { + RemoteChunkProvider provider = + new RemoteChunkProvider( + STATEMENT_ID, + manifest, + resultData, + mockSession, + mock(IDatabricksHttpClient.class), + 10); + int scheduled = (int) provider.nextChunkToDownload; + provider.close(); + return scheduled; + } + } } diff --git a/src/test/java/com/databricks/jdbc/common/util/DecompressionUtilTest.java b/src/test/java/com/databricks/jdbc/common/util/DecompressionUtilTest.java index f8cbe0549e..f0f7bdfa19 100644 --- a/src/test/java/com/databricks/jdbc/common/util/DecompressionUtilTest.java +++ b/src/test/java/com/databricks/jdbc/common/util/DecompressionUtilTest.java @@ -39,6 +39,29 @@ public void testDecompressLZ4Frame() throws Exception { IOUtils.contentEquals(resultStream, new ByteArrayInputStream(INITIAL_STRING.getBytes()))); } + @Test + public void testDecompressToInputStreamLZ4Frame() throws Exception { + byte[] uncompressed = INITIAL_STRING.getBytes(); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(out)) { + lz4.write(uncompressed); + } + InputStream resultStream = + DecompressionUtil.decompressToInputStream( + out.toByteArray(), CompressionCodec.LZ4_FRAME, CONTEXT); + assertTrue( + IOUtils.contentEquals(resultStream, new ByteArrayInputStream(uncompressed)), + "Streaming decompression should yield the original bytes"); + } + + @Test + public void testDecompressToInputStreamNoneReturnsRawBytes() throws Exception { + byte[] raw = INITIAL_STRING.getBytes(); + InputStream resultStream = + DecompressionUtil.decompressToInputStream(raw, CompressionCodec.NONE, CONTEXT); + assertTrue(IOUtils.contentEquals(resultStream, new ByteArrayInputStream(raw))); + } + @Test public void testDecompressLZ4FrameSkipsCompression() throws Exception { assertEquals(