From 10d48b0dcb44d333b91a454af5959e982e7d41a1 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Tue, 3 Mar 2026 11:16:08 -0500 Subject: [PATCH 1/7] NIFI-15669: Refactored ConsumeKinesis to remove dependency on KCL. This provides much faster startup times and drastically reduces heap utilization when using Enhanced Fan-Out (EFO) mode. --- .../nifi-aws-bundle/nifi-aws-kinesis/pom.xml | 59 +- .../aws/kinesis/CheckpointTableUtils.java | 255 +++ .../aws/kinesis/ConsumeKinesis.java | 1624 +++++++++++------ .../aws/kinesis/ConsumeKinesisAttributes.java | 80 - .../aws/kinesis/DeaggregatedRecord.java | 41 + .../aws/kinesis/EfoKinesisClient.java | 596 ++++++ .../aws/kinesis/KinesisConsumerClient.java | 169 ++ .../KinesisRecordMetadata.java | 39 +- .../aws/kinesis/KinesisShardManager.java | 562 ++++++ .../aws/kinesis/KplDeaggregator.java | 194 ++ .../aws/kinesis/LegacyCheckpointMigrator.java | 433 +++++ .../aws/kinesis/MemoryBoundRecordBuffer.java | 725 -------- .../aws/kinesis/PollingKinesisClient.java | 431 +++++ .../aws/kinesis/ReaderRecordProcessor.java | 276 --- .../processors/aws/kinesis/RecordBuffer.java | 96 - .../aws/kinesis/ShardCheckpoint.java | 45 + ...itUriFormat.java => ShardFetchResult.java} | 15 +- .../InjectMetadataRecordConverter.java | 49 - .../converter/KinesisRecordConverter.java | 25 - .../converter/ValueRecordConverter.java | 28 - .../converter/WrapperRecordConverter.java | 53 - .../aws/kinesis/CheckpointTableUtilsTest.java | 133 ++ .../aws/kinesis/ConsumeKinesisIT.java | 1394 ++++++++------ .../aws/kinesis/ConsumeKinesisTest.java | 511 +++++- .../aws/kinesis/JsonRecordAssert.java | 73 - .../kinesis/KinesisConsumerClientTest.java | 600 ++++++ .../aws/kinesis/KinesisRecordPayload.java | 36 - .../aws/kinesis/KinesisShardManagerTest.java | 389 ++++ .../aws/kinesis/KplDeaggregatorTest.java | 412 +++++ .../kinesis/MemoryBoundRecordBufferTest.java | 792 -------- .../aws/kinesis/PollingKinesisClientTest.java | 298 +++ .../ProvenanceTransitUriFormatTest.java | 36 
- .../kinesis/ReaderRecordProcessorTest.java | 397 ---- .../InjectMetadataRecordConverterTest.java | 84 - .../KinesisRecordConverterTestUtil.java | 98 - .../converter/WrapperRecordConverterTest.java | 76 - 36 files changed, 7006 insertions(+), 4118 deletions(-) create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/{converter => }/KinesisRecordMetadata.java (70%) create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java delete mode 100644 
nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessor.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/{ProvenanceTransitUriFormat.java => ShardFetchResult.java} (68%) delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java create mode 100644 
nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java create mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormatTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml index 97e4918a0b80..3f9caf4a3cea 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml @@ -37,10 +37,6 @@ org.apache.nifi nifi-aws-service-api - - org.apache.nifi - nifi-proxy-configuration-api - org.apache.nifi nifi-migration-utils @@ -50,11 +46,50 @@ org.apache.nifi nifi-record-serialization-service-api + 
+ org.apache.nifi + nifi-proxy-configuration-api + + + software.amazon.awssdk + apache-client + + + commons-logging + commons-logging + + + + + software.amazon.awssdk + kinesis + + + commons-logging + commons-logging + + + + + software.amazon.awssdk + dynamodb + + + software.amazon.awssdk + netty-nio-client + + + com.google.protobuf + protobuf-java + + + software.amazon.kinesis amazon-kinesis-client 3.4.1 + test com.google.code.findbugs @@ -78,12 +113,6 @@ - - software.amazon.awssdk - netty-nio-client - - - org.apache.nifi nifi-aws-processors @@ -116,6 +145,16 @@ nifi-ssl-context-service-api test + + org.mockito + mockito-core + test + + + org.testcontainers + testcontainers-junit-jupiter + test + org.testcontainers testcontainers-localstack diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java new file mode 100644 index 000000000000..6661c2cd34d2 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.BillingMode; +import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.ResourceInUseException; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.TableStatus; + +import java.util.List; +import java.util.Map; + +/** + * Shared DynamoDB table lifecycle operations for checkpoint tables. Used by both + * {@link KinesisShardManager} for runtime table management and + * {@link LegacyCheckpointMigrator} for migration and rename operations. 
+ */ +final class CheckpointTableUtils { + + static final String NODE_HEARTBEAT_PREFIX = "__node__#"; + static final String MIGRATION_MARKER_SHARD_ID = "__migration__"; + + private static final long TABLE_POLL_MILLIS = 1_000; + private static final int TABLE_POLL_MAX_ATTEMPTS = 60; + + private CheckpointTableUtils() { } + + enum TableSchema { + NEW, + LEGACY, + UNKNOWN, + NOT_FOUND + } + + static TableSchema getTableSchema(final DynamoDbClient client, final String tableName) { + try { + final DescribeTableResponse describe = client.describeTable( + DescribeTableRequest.builder().tableName(tableName).build()); + final List keySchema = describe.table().keySchema(); + if (keySchema.size() == 2 + && hasKey(keySchema, "streamName", KeyType.HASH) + && hasKey(keySchema, "shardId", KeyType.RANGE)) { + return TableSchema.NEW; + } + if (keySchema.size() == 1 && hasKey(keySchema, "leaseKey", KeyType.HASH)) { + return TableSchema.LEGACY; + } + return TableSchema.UNKNOWN; + } catch (final ResourceNotFoundException notFound) { + return TableSchema.NOT_FOUND; + } + } + + static void createNewSchemaTable(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final TableSchema tableSchema = getTableSchema(client, tableName); + if (tableSchema == TableSchema.NEW) { + logger.info("DynamoDB checkpoint table [{}] already exists", tableName); + return; + } + if (tableSchema == TableSchema.LEGACY || tableSchema == TableSchema.UNKNOWN) { + throw new ProcessException( + "Checkpoint table [%s] exists but does not match expected schema".formatted(tableName)); + } + + logger.info("Creating DynamoDB checkpoint table [{}]", tableName); + try { + final CreateTableRequest request = CreateTableRequest.builder() + .tableName(tableName) + .keySchema( + KeySchemaElement.builder().attributeName("streamName").keyType(KeyType.HASH).build(), + KeySchemaElement.builder().attributeName("shardId").keyType(KeyType.RANGE).build()) + .attributeDefinitions( + 
AttributeDefinition.builder().attributeName("streamName").attributeType(ScalarAttributeType.S).build(), + AttributeDefinition.builder().attributeName("shardId").attributeType(ScalarAttributeType.S).build()) + .billingMode(BillingMode.PAY_PER_REQUEST) + .build(); + client.createTable(request); + } catch (final ResourceInUseException alreadyCreating) { + logger.info("DynamoDB checkpoint table [{}] is already being created by another node", tableName); + } + } + + static void waitForTableActive(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final DescribeTableRequest request = DescribeTableRequest.builder().tableName(tableName).build(); + for (int i = 0; i < TABLE_POLL_MAX_ATTEMPTS; i++) { + final TableStatus status = client.describeTable(request).table().tableStatus(); + if (status == TableStatus.ACTIVE) { + logger.info("DynamoDB checkpoint table [{}] is now ACTIVE", tableName); + return; + } + + try { + Thread.sleep(TABLE_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for DynamoDB table to become ACTIVE", e); + } + } + throw new ProcessException("DynamoDB checkpoint table [%s] did not become ACTIVE within %d seconds" + .formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + } + + static void deleteTable(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + try { + client.deleteTable(DeleteTableRequest.builder().tableName(tableName).build()); + logger.info("Initiated deletion of DynamoDB table [{}]", tableName); + } catch (final ResourceNotFoundException e) { + logger.debug("Table [{}] already deleted", tableName); + } + } + + static void waitForTableDeleted(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final DescribeTableRequest request = DescribeTableRequest.builder().tableName(tableName).build(); + for (int i = 0; i < TABLE_POLL_MAX_ATTEMPTS; i++) { + try { + 
client.describeTable(request); + } catch (final ResourceNotFoundException e) { + logger.info("DynamoDB table [{}] has been deleted", tableName); + return; + } + + try { + Thread.sleep(TABLE_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for DynamoDB table deletion", e); + } + } + throw new ProcessException( + "DynamoDB table [%s] was not deleted within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + } + + static void copyCheckpointItems(final DynamoDbClient client, final ComponentLog logger, + final String sourceTableName, final String destTableName) { + logger.info("Copying checkpoint items from [{}] to [{}]", sourceTableName, destTableName); + final TableSchema destinationSchema = getTableSchema(client, destTableName); + if (destinationSchema == TableSchema.NOT_FOUND || destinationSchema == TableSchema.UNKNOWN) { + throw new ProcessException("Cannot copy checkpoint items to [%s]: destination schema is %s" + .formatted(destTableName, destinationSchema)); + } + + Map exclusiveStartKey = null; + int copied = 0; + do { + final ScanRequest scanRequest = exclusiveStartKey == null + ? 
ScanRequest.builder().tableName(sourceTableName).build() + : ScanRequest.builder().tableName(sourceTableName).exclusiveStartKey(exclusiveStartKey).build(); + final ScanResponse scanResponse = client.scan(scanRequest); + + for (final Map item : scanResponse.items()) { + final AttributeValue shardIdAttr = item.get("shardId"); + if (shardIdAttr != null) { + final String shardId = shardIdAttr.s(); + if (shardId.startsWith(NODE_HEARTBEAT_PREFIX) + || MIGRATION_MARKER_SHARD_ID.equals(shardId)) { + continue; + } + } + + final Map destinationItem = convertItemForDestinationSchema(item, destinationSchema); + if (destinationItem == null) { + logger.debug("Skipping checkpoint item during copy because it cannot be converted for {} schema: keys={}", + destinationSchema, item.keySet()); + continue; + } + + client.putItem(PutItemRequest.builder() + .tableName(destTableName) + .item(destinationItem) + .build()); + copied++; + } + + exclusiveStartKey = scanResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + logger.info("Copied {} checkpoint item(s) from [{}] to [{}]", copied, sourceTableName, destTableName); + } + + private static Map convertItemForDestinationSchema(final Map item, + final TableSchema destinationSchema) { + return switch (destinationSchema) { + case NEW -> item; + case LEGACY -> convertToLegacyItem(item); + case NOT_FOUND, UNKNOWN -> null; + }; + } + + private static Map convertToLegacyItem(final Map item) { + if (item.containsKey("leaseKey")) { + return item; + } + + final AttributeValue streamName = item.get("streamName"); + final AttributeValue shardId = item.get("shardId"); + if (streamName == null || shardId == null) { + return null; + } + + final String shardIdValue = shardId.s(); + if (shardIdValue == null || shardIdValue.isEmpty() + || shardIdValue.startsWith(NODE_HEARTBEAT_PREFIX) + || MIGRATION_MARKER_SHARD_ID.equals(shardIdValue)) { + return null; + } + + final AttributeValue sequenceNumber = 
item.get("sequenceNumber"); + final String leaseKey = streamName.s() + ":" + shardIdValue; + if (sequenceNumber != null && sequenceNumber.s() != null) { + return Map.of( + "leaseKey", AttributeValue.builder().s(leaseKey).build(), + "checkpoint", AttributeValue.builder().s(sequenceNumber.s()).build()); + } + + return Map.of("leaseKey", AttributeValue.builder().s(leaseKey).build()); + } + + private static boolean hasKey(final List keySchema, final String keyName, final KeyType keyType) { + for (final KeySchemaElement element : keySchema) { + if (keyName.equals(element.attributeName()) && keyType == element.keyType()) { + return true; + } + } + return false; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index 52044ee0f478..61c223877bd2 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -16,7 +16,6 @@ */ package org.apache.nifi.processors.aws.kinesis; -import jakarta.annotation.Nullable; import org.apache.nifi.annotation.behavior.InputRequirement; import org.apache.nifi.annotation.behavior.SystemResource; import org.apache.nifi.annotation.behavior.SystemResourceConsideration; @@ -30,154 +29,132 @@ import org.apache.nifi.components.DescribedValue; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.Validator; -import org.apache.nifi.controller.NodeTypeProvider; import org.apache.nifi.flowfile.FlowFile; import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.migration.PropertyConfiguration; -import org.apache.nifi.migration.ProxyServiceMigration; +import 
org.apache.nifi.migration.RelationshipConfiguration; import org.apache.nifi.processor.AbstractProcessor; import org.apache.nifi.processor.DataUnit; import org.apache.nifi.processor.ProcessContext; import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.io.OutputStreamCallback; import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processors.aws.credentials.provider.AwsCredentialsProviderService; -import org.apache.nifi.processors.aws.kinesis.MemoryBoundRecordBuffer.Lease; -import org.apache.nifi.processors.aws.kinesis.ReaderRecordProcessor.ProcessingResult; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.apache.nifi.processors.aws.kinesis.converter.InjectMetadataRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.ValueRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.WrapperRecordConverter; import org.apache.nifi.processors.aws.region.RegionUtil; import org.apache.nifi.proxy.ProxyConfiguration; -import org.apache.nifi.proxy.ProxyConfigurationService; import org.apache.nifi.proxy.ProxySpec; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.RecordSetWriter; import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.SimpleRecordSchema; +import org.apache.nifi.serialization.record.MapRecord; +import org.apache.nifi.serialization.record.RecordField; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.RecordSchema; import 
software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.http.Protocol; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; -import software.amazon.awssdk.http.nio.netty.Http2Configuration; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; -import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClientBuilder; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder; -import software.amazon.kinesis.common.ConfigsBuilder; -import software.amazon.kinesis.common.InitialPositionInStream; -import software.amazon.kinesis.common.InitialPositionInStreamExtended; -import software.amazon.kinesis.coordinator.Scheduler; -import software.amazon.kinesis.coordinator.WorkerStateChangeListener; -import software.amazon.kinesis.lifecycle.events.InitializationInput; -import software.amazon.kinesis.lifecycle.events.LeaseLostInput; -import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; -import software.amazon.kinesis.lifecycle.events.ShardEndedInput; -import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput; -import software.amazon.kinesis.metrics.LogMetricsFactory; -import software.amazon.kinesis.metrics.MetricsFactory; -import software.amazon.kinesis.metrics.NullMetricsFactory; -import software.amazon.kinesis.processor.ShardRecordProcessor; -import software.amazon.kinesis.processor.ShardRecordProcessorFactory; -import software.amazon.kinesis.processor.SingleStreamTracker; 
-import software.amazon.kinesis.retrieval.KinesisClientRecord; -import software.amazon.kinesis.retrieval.RetrievalSpecificConfig; -import software.amazon.kinesis.retrieval.fanout.FanOutConfig; -import software.amazon.kinesis.retrieval.polling.PollingConfig; - +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.KinesisClientBuilder; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.math.BigInteger; +import java.net.Proxy; import java.net.URI; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; -import java.util.Date; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.NANOSECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.APPROXIMATE_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SEQUENCE_NUMBER; -import static 
org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.PARTITION_KEY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.SHARD_ID; +import java.util.concurrent.TimeUnit; + import static org.apache.nifi.processors.aws.region.RegionUtil.CUSTOM_REGION; import static org.apache.nifi.processors.aws.region.RegionUtil.REGION; @InputRequirement(InputRequirement.Requirement.INPUT_FORBIDDEN) @Tags({"amazon", "aws", "kinesis", "consume", "stream", "record"}) @CapabilityDescription(""" - Consumes data from the specified AWS Kinesis stream and outputs a FlowFile for every processed Record (raw) - or a FlowFile for a batch of processed records if a Record Reader and Record Writer are configured. - The processor may take a few minutes on the first start and several seconds on subsequent starts - to initialize before starting to fetch data. - Uses DynamoDB for check pointing and coordination, and (optional) CloudWatch for metrics. - """) + Consumes records from an Amazon Kinesis Data Stream. Uses \ + DynamoDB-based checkpointing for reliable resumption after restarts. + + Note: when a shard is split or multiple shards are merged, this processor will consume from \ + child and parent shards concurrently. 
It does not wait for parent shards to be fully consumed \ + before reading child shards, so record ordering is not guaranteed across a split or merge \ + boundary.""") @WritesAttributes({ - @WritesAttribute(attribute = ConsumeKinesisAttributes.STREAM_NAME, - description = "The name of the Kinesis Stream from which all Kinesis Records in the FlowFile were read"), - @WritesAttribute(attribute = SHARD_ID, - description = "Shard ID from which all Kinesis Records in the FlowFile were read"), - @WritesAttribute(attribute = PARTITION_KEY, + @WritesAttribute(attribute = "aws.kinesis.stream.name", + description = "The name of the Kinesis Stream from which records were read"), + @WritesAttribute(attribute = "aws.kinesis.shard.id", + description = "Shard ID from which records were read"), + @WritesAttribute(attribute = "aws.kinesis.partition.key", description = "Partition key of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = FIRST_SEQUENCE_NUMBER, - description = "A Sequence Number of the first Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = FIRST_SUB_SEQUENCE_NUMBER, - description = "A SubSequence Number of the first Kinesis Record in the FlowFile. Generated by KPL when aggregating records into a single Kinesis Record"), - @WritesAttribute(attribute = LAST_SEQUENCE_NUMBER, - description = "A Sequence Number of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = LAST_SUB_SEQUENCE_NUMBER, - description = "A SubSequence Number of the last Kinesis Record in the FlowFile. 
Generated by KPL when aggregating records into a single Kinesis Record"), - @WritesAttribute(attribute = APPROXIMATE_ARRIVAL_TIMESTAMP, + @WritesAttribute(attribute = "aws.kinesis.first.sequence.number", + description = "Sequence Number of the first Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.first.subsequence.number", + description = "Sub-Sequence Number of the first Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.last.sequence.number", + description = "Sequence Number of the last Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.last.subsequence.number", + description = "Sub-Sequence Number of the last Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.approximate.arrival.timestamp.ms", description = "Approximate arrival timestamp of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = MIME_TYPE, + @WritesAttribute(attribute = "mime.type", description = "Sets the mime.type attribute to the MIME Type specified by the Record Writer (if configured)"), - @WritesAttribute(attribute = RECORD_COUNT, - description = "Number of records written to the FlowFiles by the Record Writer (if configured)"), - @WritesAttribute(attribute = RECORD_ERROR_MESSAGE, - description = "This attribute provides on failure the error message encountered by the Record Reader or Record Writer (if configured)") + @WritesAttribute(attribute = "record.count", + description = "Number of records written to the FlowFile"), + @WritesAttribute(attribute = "record.error.message", + description = "Error message encountered by the Record Reader or Record Writer (if configured)"), + @WritesAttribute(attribute = "kinesis.millis.behind", + description = "How far behind the stream tail we are, in milliseconds") }) @DefaultSettings(yieldDuration = "100 millis") @SystemResourceConsideration(resource = SystemResource.CPU, description = """ - The processor uses additional CPU 
resources when consuming data from Kinesis. - The consumption is started immediately after this Processor is scheduled. The consumption ends only when the Processor is stopped.""") + The processor uses additional CPU resources when consuming data from Kinesis.""") @SystemResourceConsideration(resource = SystemResource.NETWORK, description = """ - The processor will continually poll for new Records, - requesting up to a maximum number of Records/bytes per call. This can result in sustained network usage.""") + The processor will continually poll for new Records.""") @SystemResourceConsideration(resource = SystemResource.MEMORY, description = """ - ConsumeKinesis buffers Kinesis Records in memory until they can be processed. - The maximum size of the buffer is controlled by the 'Max Bytes to Buffer' property. - In addition, the processor may cache some amount of data for each shard when the processor's buffer is full.""") + ConsumeKinesis buffers Kinesis Records in memory until they can be processed. \ + The maximum size of the buffer is controlled by the 'Max Batch Size' property.""") public class ConsumeKinesis extends AbstractProcessor { - private static final Duration HTTP_CLIENTS_CONNECTION_TIMEOUT = Duration.ofSeconds(30); - private static final Duration HTTP_CLIENTS_READ_TIMEOUT = Duration.ofMinutes(3); - - private static final int KINESIS_HTTP_CLIENT_WINDOW_SIZE_BYTES = 512 * 1024; // 512 KiB - private static final Duration KINESIS_HTTP_HEALTH_CHECK_PERIOD = Duration.ofMinutes(1); - - /** - * How long to wait for a Scheduler initialization to complete in the OnScheduled method. - * If the initialization takes longer than this, the processor will continue initialization checks in the onTrigger method. 
- */ - private static final Duration KINESIS_SCHEDULER_ON_SCHEDULED_INITIALIZATION_TIMEOUT = Duration.ofSeconds(30); - private static final Duration KINESIS_SCHEDULER_GRACEFUL_SHUTDOWN_TIMEOUT = Duration.ofMinutes(3); + static final String ATTR_STREAM_NAME = "aws.kinesis.stream.name"; + static final String ATTR_SHARD_ID = "aws.kinesis.shard.id"; + static final String ATTR_FIRST_SEQUENCE = "aws.kinesis.first.sequence.number"; + static final String ATTR_LAST_SEQUENCE = "aws.kinesis.last.sequence.number"; + static final String ATTR_FIRST_SUBSEQUENCE = "aws.kinesis.first.subsequence.number"; + static final String ATTR_LAST_SUBSEQUENCE = "aws.kinesis.last.subsequence.number"; + static final String ATTR_PARTITION_KEY = "aws.kinesis.partition.key"; + static final String ATTR_ARRIVAL_TIMESTAMP = "aws.kinesis.approximate.arrival.timestamp.ms"; + static final String ATTR_MILLIS_BEHIND = "kinesis.millis.behind"; + + private static final long QUEUE_POLL_TIMEOUT_MILLIS = 100; + private static final Duration API_CALL_TIMEOUT = Duration.ofSeconds(30); + private static final Duration API_CALL_ATTEMPT_TIMEOUT = Duration.ofSeconds(10); + private static final byte[] NEWLINE_DELIMITER = new byte[] {'\n'}; + private static final String WRAPPER_VALUE_FIELD = "value"; static final PropertyDescriptor STREAM_NAME = new PropertyDescriptor.Builder() .name("Stream Name") @@ -188,17 +165,14 @@ public class ConsumeKinesis extends AbstractProcessor { static final PropertyDescriptor APPLICATION_NAME = new PropertyDescriptor.Builder() .name("Application Name") - .description("The name of the Kinesis application. This is used for DynamoDB table naming and worker coordination.") + .description("The name of the Kinesis application. 
Used as the DynamoDB table name for checkpoint storage.") .required(true) .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) .build(); static final PropertyDescriptor AWS_CREDENTIALS_PROVIDER_SERVICE = new PropertyDescriptor.Builder() .name("AWS Credentials Provider Service") - .description(""" - The Controller Service that is used to obtain AWS credentials provider. - Ensure that the credentials provided have access to Kinesis, DynamoDB and (optional) CloudWatch. - """) + .description("The Controller Service used to obtain AWS credentials provider.") .required(true) .identifiesControllerService(AwsCredentialsProviderService.class) .build(); @@ -221,12 +195,7 @@ Ensure that the credentials provided have access to Kinesis, DynamoDB and (optio static final PropertyDescriptor RECORD_READER = new PropertyDescriptor.Builder() .name("Record Reader") - .description(""" - The Record Reader to use for parsing the data received from Kinesis. - - The Record Reader is responsible for providing schemas for the records. If the schemas change frequently, - it might hinder performance of the processor. - """) + .description("The Record Reader to use for parsing the data received from Kinesis.") .required(true) .dependsOn(PROCESSING_STRATEGY, ProcessingStrategy.RECORD) .identifiesControllerService(RecordReaderFactory.class) @@ -252,12 +221,9 @@ Ensure that the credentials provided have access to Kinesis, DynamoDB and (optio static final PropertyDescriptor MESSAGE_DEMARCATOR = new PropertyDescriptor.Builder() .name("Message Demarcator") .description(""" - Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kinesis messages - within a single FlowFile. If not specified, the content of the messages will be concatenated - without any delimiter. - To enter special character such as 'new line' use CTRL+Enter or Shift+Enter depending on the OS. 
- """) - .required(false) + Specifies the string (interpreted as UTF-8) used to separate multiple Kinesis messages \ + within a single FlowFile when Processing Strategy is DEMARCATOR.""") + .required(true) .addValidator(Validator.VALID) .dependsOn(PROCESSING_STRATEGY, ProcessingStrategy.DEMARCATOR) .build(); @@ -272,47 +238,49 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine static final PropertyDescriptor STREAM_POSITION_TIMESTAMP = new PropertyDescriptor.Builder() .name("Stream Position Timestamp") - .description("Timestamp position in stream from which to start reading Kinesis Records. The timestamp must be in ISO 8601 format.") + .description("Timestamp position in stream from which to start reading Kinesis Records. Must be in ISO 8601 format.") .required(true) .addValidator(StandardValidators.ISO8601_INSTANT_VALIDATOR) .dependsOn(INITIAL_STREAM_POSITION, InitialPosition.AT_TIMESTAMP) .build(); - static final PropertyDescriptor MAX_BYTES_TO_BUFFER = new PropertyDescriptor.Builder() - .name("Max Bytes to Buffer") - .description(""" - The maximum size of Kinesis Records that can be buffered in memory before being processed by NiFi. - If the buffer size exceeds the limit, the processor will stop consuming new records until free space is available. - - Using a larger value may increase the throughput, but will do so at the expense of using more memory. - """) + static final PropertyDescriptor MAX_RECORDS_PER_REQUEST = new PropertyDescriptor.Builder() + .name("Max Records Per Request") + .description("The maximum number of records to retrieve per GetRecords call. 
Maximum is 10,000.") .required(true) - .addValidator(StandardValidators.DATA_SIZE_VALIDATOR) - .defaultValue("100 MB") + .defaultValue("1000") + .addValidator(StandardValidators.createLongValidator(1, 10000, true)) .build(); - static final PropertyDescriptor CHECKPOINT_INTERVAL = new PropertyDescriptor.Builder() - .name("Checkpoint Interval") + static final PropertyDescriptor MAX_BATCH_DURATION = new PropertyDescriptor.Builder() + .name("Max Batch Duration") .description(""" - Interval between checkpointing consumed Kinesis records. To checkpoint records each time the Processor is run, set this value to 0 seconds. - - More frequent checkpoint may reduce performance and increase DynamoDB costs, - but less frequent checkpointing may result in duplicates when a Shard lease is lost or NiFi is restarted. - """) + The maximum amount of time to spend consuming records in a single invocation before \ + committing the session and checkpointing.""") .required(true) - .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) .defaultValue("5 sec") + .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) .build(); - static final PropertyDescriptor METRICS_PUBLISHING = new PropertyDescriptor.Builder() - .name("Metrics Publishing") - .description("Specifies where Kinesis usage metrics are published to.") + static final PropertyDescriptor MAX_BATCH_SIZE = new PropertyDescriptor.Builder() + .name("Max Batch Size") + .description(""" + The maximum amount of data to consume in a single invocation before committing the \ + session and checkpointing.""") .required(true) - .allowableValues(MetricsPublishing.class) - .defaultValue(MetricsPublishing.DISABLED) + .defaultValue("10 MB") + .addValidator(StandardValidators.DATA_SIZE_VALIDATOR) .build(); - static final PropertyDescriptor PROXY_CONFIGURATION_SERVICE = ProxyConfiguration.createProxyConfigPropertyDescriptor(ProxySpec.HTTP, ProxySpec.HTTP_AUTH); + static final PropertyDescriptor ENDPOINT_OVERRIDE = new PropertyDescriptor.Builder() 
+ .name("Endpoint Override URL") + .description("An optional endpoint override URL for both the Kinesis and DynamoDB clients.") + .required(false) + .addValidator(StandardValidators.URL_VALIDATOR) + .build(); + + static final PropertyDescriptor PROXY_CONFIGURATION_SERVICE = + ProxyConfiguration.createProxyConfigPropertyDescriptor(ProxySpec.HTTP, ProxySpec.HTTP_AUTH); private static final List PROPERTY_DESCRIPTORS = List.of( STREAM_NAME, @@ -328,10 +296,11 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine MESSAGE_DEMARCATOR, INITIAL_STREAM_POSITION, STREAM_POSITION_TIMESTAMP, - MAX_BYTES_TO_BUFFER, - CHECKPOINT_INTERVAL, - PROXY_CONFIGURATION_SERVICE, - METRICS_PUBLISHING + MAX_RECORDS_PER_REQUEST, + MAX_BATCH_DURATION, + MAX_BATCH_SIZE, + ENDPOINT_OVERRIDE, + PROXY_CONFIGURATION_SERVICE ); static final Relationship REL_SUCCESS = new Relationship.Builder() @@ -347,38 +316,30 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine private static final Set RAW_FILE_RELATIONSHIPS = Set.of(REL_SUCCESS); private static final Set RECORD_FILE_RELATIONSHIPS = Set.of(REL_SUCCESS, REL_PARSE_FAILURE); - private volatile DynamoDbAsyncClient dynamoDbClient; - private volatile CloudWatchAsyncClient cloudWatchClient; - private volatile KinesisAsyncClient kinesisClient; - private volatile Scheduler kinesisScheduler; - + private volatile SdkHttpClient kinesisHttpClient; + private volatile SdkHttpClient dynamoHttpClient; + private volatile KinesisClient kinesisClient; + private volatile DynamoDbClient dynamoDbClient; + private volatile SdkAsyncHttpClient asyncHttpClient; + private volatile KinesisShardManager shardManager; + private volatile KinesisConsumerClient consumerClient; private volatile String streamName; - private volatile RecordBuffer.ForProcessor recordBuffer; - - private volatile @Nullable ReaderRecordProcessor readerRecordProcessor; - private volatile @Nullable byte[] demarcatorValue; + private volatile int 
maxRecordsPerRequest; + private volatile String initialStreamPosition; + private volatile long maxBatchNanos; + private volatile long maxBatchBytes; - private volatile Future initializationResultFuture; - private final AtomicBoolean initialized = new AtomicBoolean(); - - // An instance filed, so that it can be read in getRelationships. - private volatile ProcessingStrategy processingStrategy = ProcessingStrategy.from( - PROCESSING_STRATEGY.getDefaultValue()); + private volatile ProcessingStrategy processingStrategy = ProcessingStrategy.valueOf(PROCESSING_STRATEGY.getDefaultValue()); @Override protected List getSupportedPropertyDescriptors() { return PROPERTY_DESCRIPTORS; } - @Override - public void migrateProperties(final PropertyConfiguration config) { - ProxyServiceMigration.renameProxyConfigurationServiceProperty(config); - } - @Override public Set getRelationships() { return switch (processingStrategy) { - case FLOW_FILE, DEMARCATOR -> RAW_FILE_RELATIONSHIPS; + case FLOW_FILE, LINE_DELIMITED, DEMARCATOR -> RAW_FILE_RELATIONSHIPS; case RECORD -> RECORD_FILE_RELATIONSHIPS; }; } @@ -386,506 +347,1090 @@ public Set getRelationships() { @Override public void onPropertyModified(final PropertyDescriptor descriptor, final String oldValue, final String newValue) { if (descriptor.equals(PROCESSING_STRATEGY)) { - processingStrategy = ProcessingStrategy.from(newValue); + processingStrategy = ProcessingStrategy.valueOf(newValue); } } - @OnScheduled - public void setup(final ProcessContext context) { - readerRecordProcessor = switch (processingStrategy) { - case FLOW_FILE, DEMARCATOR -> null; - case RECORD -> createReaderRecordProcessor(context); - }; - demarcatorValue = switch (processingStrategy) { - case FLOW_FILE, RECORD -> null; - case DEMARCATOR -> { - final String demarcatorValue = context.getProperty(MESSAGE_DEMARCATOR).getValue(); - yield demarcatorValue != null ? 
demarcatorValue.getBytes(UTF_8) : new byte[0]; - } - }; + @Override + public void migrateProperties(final PropertyConfiguration config) { + config.renameProperty("Max Bytes to Buffer", "Max Batch Size"); + config.removeProperty("Checkpoint Interval"); + config.removeProperty("Metrics Publishing"); + } + + @Override + public void migrateRelationships(final RelationshipConfiguration config) { + config.renameRelationship("parse failure", "parse.failure"); + } + @OnScheduled + public void onScheduled(final ProcessContext context) { final Region region = RegionUtil.getRegion(context); final AwsCredentialsProvider credentialsProvider = context.getProperty(AWS_CREDENTIALS_PROVIDER_SERVICE) .asControllerService(AwsCredentialsProviderService.class).getAwsCredentialsProvider(); + final String endpointOverride = context.getProperty(ENDPOINT_OVERRIDE).getValue(); - kinesisClient = KinesisAsyncClient.builder() - .region(region) - .credentialsProvider(credentialsProvider) - .endpointOverride(getKinesisEndpointOverride()) - .httpClient(createKinesisHttpClient(context)) + final ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder() + .apiCallTimeout(API_CALL_TIMEOUT) + .apiCallAttemptTimeout(API_CALL_ATTEMPT_TIMEOUT) .build(); - dynamoDbClient = DynamoDbAsyncClient.builder() + final KinesisClientBuilder kinesisBuilder = KinesisClient.builder() .region(region) .credentialsProvider(credentialsProvider) - .endpointOverride(getDynamoDbEndpointOverride()) - .httpClient(createHttpClientBuilder(context).build()) - .build(); + .overrideConfiguration(clientConfig); - cloudWatchClient = CloudWatchAsyncClient.builder() + final DynamoDbClientBuilder dynamoBuilder = DynamoDbClient.builder() .region(region) .credentialsProvider(credentialsProvider) - .endpointOverride(getCloudwatchEndpointOverride()) - .httpClient(createHttpClientBuilder(context).build()) - .build(); - - streamName = context.getProperty(STREAM_NAME).getValue(); - final InitialPositionInStreamExtended 
    /**
     * Builds the synchronous Kinesis and DynamoDB clients, the shard manager (which owns lease
     * coordination and the DynamoDB checkpoint table), and the consumer client (polling or
     * Enhanced Fan-Out). In EFO mode an additional async Kinesis client backed by Netty is
     * created and handed to the consumer client.
     */
    @OnScheduled
    public void onScheduled(final ProcessContext context) {
        final Region region = RegionUtil.getRegion(context);
        final AwsCredentialsProvider credentialsProvider = context.getProperty(AWS_CREDENTIALS_PROVIDER_SERVICE)
            .asControllerService(AwsCredentialsProviderService.class).getAwsCredentialsProvider();
        final String endpointOverride = context.getProperty(ENDPOINT_OVERRIDE).getValue();

        // Bound every AWS API call so a stuck request cannot block onTrigger indefinitely.
        final ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder()
            .apiCallTimeout(API_CALL_TIMEOUT)
            .apiCallAttemptTimeout(API_CALL_ATTEMPT_TIMEOUT)
            .build();

        final KinesisClientBuilder kinesisBuilder = KinesisClient.builder()
            .region(region)
            .credentialsProvider(credentialsProvider)
            .overrideConfiguration(clientConfig);

        final DynamoDbClientBuilder dynamoBuilder = DynamoDbClient.builder()
            .region(region)
            .credentialsProvider(credentialsProvider)
            .overrideConfiguration(clientConfig);

        // A single override URL is applied to both services (useful for LocalStack-style testing).
        if (endpointOverride != null && !endpointOverride.isEmpty()) {
            final URI endpointUri = URI.create(endpointOverride);
            kinesisBuilder.endpointOverride(endpointUri);
            dynamoBuilder.endpointOverride(endpointUri);
        }

        final ProxyConfiguration proxyConfig = ProxyConfiguration.getConfiguration(context);

        // Separate HTTP clients per service so their connection pools cannot starve each other.
        // Kinesis pool is sized to the max concurrent fetches plus headroom for control-plane calls.
        kinesisHttpClient = buildApacheHttpClient(proxyConfig, PollingKinesisClient.MAX_CONCURRENT_FETCHES + 10);
        dynamoHttpClient = buildApacheHttpClient(proxyConfig, 50);
        kinesisBuilder.httpClient(kinesisHttpClient);
        dynamoBuilder.httpClient(dynamoHttpClient);

        kinesisClient = kinesisBuilder.build();
        dynamoDbClient = dynamoBuilder.build();

        // The Application Name doubles as the DynamoDB checkpoint-table name (see property description).
        final String checkpointTableName = context.getProperty(APPLICATION_NAME).getValue();
        streamName = context.getProperty(STREAM_NAME).getValue();
        maxRecordsPerRequest = context.getProperty(MAX_RECORDS_PER_REQUEST).asInteger();
        initialStreamPosition = context.getProperty(INITIAL_STREAM_POSITION).getValue();
        maxBatchNanos = context.getProperty(MAX_BATCH_DURATION).asTimePeriod(TimeUnit.NANOSECONDS);
        maxBatchBytes = context.getProperty(MAX_BATCH_SIZE).asDataSize(DataUnit.B).longValue();

        // createShardManager / createConsumerClient are factory methods defined elsewhere in
        // this class — presumably overridable hooks for testing; confirm against the full file.
        shardManager = createShardManager(kinesisClient, dynamoDbClient, getLogger(), checkpointTableName, streamName);
        shardManager.ensureCheckpointTableExists();

        final boolean efoMode = ConsumerType.ENHANCED_FAN_OUT.equals(context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class));
        consumerClient = createConsumerClient(kinesisClient, getLogger(), efoMode);

        // Only relevant when Initial Stream Position is AT_TIMESTAMP; both client flavors accept it.
        final Instant timestampForPosition = resolveTimestampPosition(context);
        if (timestampForPosition != null) {
            if (consumerClient instanceof PollingKinesisClient polling) {
                polling.setTimestampForInitialPosition(timestampForPosition);
            } else if (consumerClient instanceof EfoKinesisClient efo) {
                efo.setTimestampForInitialPosition(timestampForPosition);
            }
        }

        if (efoMode) {
            // EFO uses SubscribeToShard over HTTP/2, which requires the async (Netty) client.
            // NOTE(review): maxConcurrency of 500 looks like a per-shard-subscription budget — confirm.
            final NettyNioAsyncHttpClient.Builder nettyBuilder = NettyNioAsyncHttpClient.builder()
                .maxConcurrency(500)
                .connectionAcquisitionTimeout(Duration.ofSeconds(60));

            if (Proxy.Type.HTTP.equals(proxyConfig.getProxyType())) {
                final software.amazon.awssdk.http.nio.netty.ProxyConfiguration.Builder nettyProxyBuilder = software.amazon.awssdk.http.nio.netty.ProxyConfiguration.builder()
                    .host(proxyConfig.getProxyServerHost())
                    .port(proxyConfig.getProxyServerPort());

                if (proxyConfig.hasCredential()) {
                    nettyProxyBuilder.username(proxyConfig.getProxyUserName());
                    nettyProxyBuilder.password(proxyConfig.getProxyUserPassword());
                }

                nettyBuilder.proxyConfiguration(nettyProxyBuilder.build());
            }

            // Kept in a field so onStopped() can close it.
            asyncHttpClient = nettyBuilder.build();

            final KinesisAsyncClientBuilder asyncBuilder = KinesisAsyncClient.builder()
                .region(region)
                .credentialsProvider(credentialsProvider)
                .httpClient(asyncHttpClient);

            if (endpointOverride != null && !endpointOverride.isEmpty()) {
                asyncBuilder.endpointOverride(URI.create(endpointOverride));
            }

            // The EFO consumer is registered under the application name.
            final String consumerName = context.getProperty(APPLICATION_NAME).getValue();
            consumerClient.initialize(asyncBuilder.build(), streamName, consumerName);
        }
    }
InitialPosition.AT_TIMESTAMP) { + return Instant.parse(context.getProperty(STREAM_POSITION_TIMESTAMP).getValue()); + } + return null; } - private String generateWorkerId() { - final String processorId = getIdentifier(); - final NodeTypeProvider nodeTypeProvider = getNodeTypeProvider(); - - final String workerId; + /** + * Builds an {@link ApacheHttpClient} with the given connection pool size and optional proxy + * configuration. Each AWS service client (Kinesis, DynamoDB) should receive its own HTTP client + * so their connection pools are isolated and cannot starve each other under high shard counts. + */ + private static SdkHttpClient buildApacheHttpClient(final ProxyConfiguration proxyConfig, final int maxConnections) { + final ApacheHttpClient.Builder builder = ApacheHttpClient.builder() + .maxConnections(maxConnections); - if (nodeTypeProvider.isClustered()) { - // If a node id is not available for some reason, generating a random UUID helps to avoid collisions. - final String nodeId = nodeTypeProvider.getCurrentNode().orElse(UUID.randomUUID().toString()); - workerId = "%s@%s".formatted(processorId, nodeId); - } else { - workerId = processorId; - } + if (Proxy.Type.HTTP.equals(proxyConfig.getProxyType())) { + final URI proxyEndpoint = URI.create(String.format("http://%s:%s", proxyConfig.getProxyServerHost(), proxyConfig.getProxyServerPort())); + final software.amazon.awssdk.http.apache.ProxyConfiguration.Builder proxyBuilder = + software.amazon.awssdk.http.apache.ProxyConfiguration.builder().endpoint(proxyEndpoint); - return workerId; - } + if (proxyConfig.hasCredential()) { + proxyBuilder.username(proxyConfig.getProxyUserName()); + proxyBuilder.password(proxyConfig.getProxyUserPassword()); + } - private static @Nullable MetricsFactory configureMetricsFactory(final ProcessContext context) { - final MetricsPublishing metricsPublishing = context.getProperty(METRICS_PUBLISHING).asAllowableValue(MetricsPublishing.class); - return switch (metricsPublishing) { - 
case DISABLED -> new NullMetricsFactory(); - case LOGS -> new LogMetricsFactory(); - case CLOUDWATCH -> null; // If no metrics factory was provided, CloudWatch metrics factory is used by default. - }; - } + builder.proxyConfiguration(proxyBuilder.build()); + } - private static RetrievalSpecificConfig configureRetrievalSpecificConfig( - final ProcessContext context, - final KinesisAsyncClient kinesisClient, - final String streamName, - final String applicationName) { - final ConsumerType consumerType = context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class); - return switch (consumerType) { - case SHARED_THROUGHPUT -> new PollingConfig(kinesisClient).streamName(streamName); - case ENHANCED_FAN_OUT -> new FanOutConfig(kinesisClient).streamName(streamName).applicationName(applicationName); - }; + return builder.build(); } @OnStopped public void onStopped() { - cleanUpState(); + if (shardManager != null) { + shardManager.releaseAllLeases(); + shardManager.close(); + shardManager = null; + } - initialized.set(false); - initializationResultFuture = null; - } + if (consumerClient != null) { + consumerClient.close(); + consumerClient = null; + } - private void cleanUpState() { - if (kinesisScheduler != null) { - shutdownScheduler(); - kinesisScheduler = null; + if (asyncHttpClient != null) { + asyncHttpClient.close(); + asyncHttpClient = null; } if (kinesisClient != null) { kinesisClient.close(); kinesisClient = null; } + if (dynamoDbClient != null) { dynamoDbClient.close(); dynamoDbClient = null; } - if (cloudWatchClient != null) { - cloudWatchClient.close(); - cloudWatchClient = null; - } - recordBuffer = null; - readerRecordProcessor = null; - demarcatorValue = null; + closeQuietly(kinesisHttpClient); + kinesisHttpClient = null; + closeQuietly(dynamoHttpClient); + dynamoHttpClient = null; } - private void shutdownScheduler() { - if (kinesisScheduler.shutdownComplete()) { + @Override + public void onTrigger(final ProcessContext context, final 
    /**
     * One consumption cycle: refresh shard leases, drain buffered fetch results for the shards
     * this node owns, write them out per the Processing Strategy, and checkpoint asynchronously
     * once the session commit succeeds. On any failure the results are rolled back into the
     * consumer client's buffers so nothing is lost.
     */
    @Override
    public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException {
        // Lease distribution is balanced across the cluster; Math.max guards the standalone case.
        final int clusterMemberCount = Math.max(1, getNodeTypeProvider().getClusterMembers().size());
        shardManager.refreshLeasesIfNecessary(clusterMemberCount);
        final List<Shard> ownedShards = shardManager.getOwnedShards();

        if (ownedShards.isEmpty()) {
            context.yield();
            return;
        }

        final Set<String> ownedShardIds = new HashSet<>();
        for (final Shard shard : ownedShards) {
            ownedShardIds.add(shard.shardId());
        }

        // Drop fetch state for shards whose lease migrated elsewhere, then make sure background
        // fetches are running for everything we do own.
        consumerClient.removeUnownedShards(ownedShardIds);
        consumerClient.startFetches(ownedShards, streamName, maxRecordsPerRequest, initialStreamPosition, shardManager);
        consumerClient.logDiagnostics(ownedShards.size(), shardManager.getCachedShardCount());

        // claimedShards is populated by consumeRecords with every shard it claimed; each claimed
        // shard MUST be released exactly once on every exit path below.
        final Set<String> claimedShards = new HashSet<>();
        final List<ShardFetchResult> consumed = consumeRecords(claimedShards);
        final List<ShardFetchResult> accepted = discardRelinquishedResults(consumed, claimedShards);

        if (accepted.isEmpty()) {
            consumerClient.releaseShards(claimedShards);
            context.yield();
            return;
        }

        final PartitionedBatch batch = partitionByShardAndCheckpoint(accepted);

        final WriteResult output;
        try {
            output = writeResults(session, context, batch.resultsByShard());
        } catch (final Exception e) {
            // Rolls back and releases; safe to return without touching the session further.
            handleWriteFailure(e, accepted, claimedShards, context);
            return;
        }

        if (output.produced().isEmpty() && output.parseFailures().isEmpty()) {
            consumerClient.releaseShards(claimedShards);
            context.yield();
            return;
        }

        session.transfer(output.produced(), REL_SUCCESS);
        if (!output.parseFailures().isEmpty()) {
            session.transfer(output.parseFailures(), REL_PARSE_FAILURE);
            session.adjustCounter("Records Parse Failure", output.parseFailures().size(), false);
        }
        session.adjustCounter("Records Consumed", output.totalRecordCount(), false);
        final long dedupEvents = consumerClient.drainDeduplicatedEventCount();
        if (dedupEvents > 0) {
            session.adjustCounter("EFO Deduplicated Events", dedupEvents, false);
        }

        // Checkpoints are written only AFTER the session commit succeeds, so a NiFi crash before
        // commit re-delivers the records (at-least-once). Shards are released in finally on both
        // paths so a checkpoint/ack failure cannot leave a shard claimed forever.
        session.commitAsync(
            () -> {
                try {
                    shardManager.writeCheckpoints(batch.checkpoints());
                    consumerClient.acknowledgeResults(accepted);
                } finally {
                    consumerClient.releaseShards(claimedShards);
                }
            },
            failure -> {
                try {
                    getLogger().error("Session commit failed; resetting shard iterators for re-consumption", failure);
                    consumerClient.rollbackResults(accepted);
                } finally {
                    consumerClient.releaseShards(claimedShards);
                }
            });
    }
ProcessContext context, final ProcessSession session) throws ProcessException { - if (!initialized.get()) { - if (!initializationResultFuture.isDone()) { - getLogger().debug("Waiting for Kinesis Scheduler to finish initialization"); - context.yield(); - return; + private List discardRelinquishedResults(final List consumedResults, final Set claimedShards) { + final List accepted = new ArrayList<>(); + final List discarded = new ArrayList<>(); + for (final ShardFetchResult result : consumedResults) { + if (shardManager.shouldProcessFetchedResult(result.shardId())) { + accepted.add(result); + } else { + discarded.add(result); } + } - checkInitializationResult(initializationResultFuture.resultNow()); + if (!discarded.isEmpty()) { + getLogger().debug("Discarding {} fetched shard result(s) for relinquished shards", discarded.size()); + consumerClient.rollbackResults(discarded); + for (final ShardFetchResult r : discarded) { + claimedShards.remove(r.shardId()); + } + consumerClient.releaseShards(discarded.stream().map(ShardFetchResult::shardId).toList()); } - final Optional leaseAcquired = recordBuffer.acquireBufferLease(); + return accepted; + } - leaseAcquired.ifPresentOrElse( - lease -> processRecordsFromBuffer(session, lease), - context::yield - ); + private PartitionedBatch partitionByShardAndCheckpoint(final List accepted) { + final Map> resultsByShard = new LinkedHashMap<>(); + for (final ShardFetchResult result : accepted) { + resultsByShard.computeIfAbsent(result.shardId(), k -> new ArrayList<>()).add(result); + } + for (final List shardResults : resultsByShard.values()) { + shardResults.sort(Comparator.comparing(r -> new BigInteger(r.firstSequenceNumber()))); + } + + final Map checkpoints = new HashMap<>(); + for (final ShardFetchResult result : accepted) { + final ShardCheckpoint incoming = new ShardCheckpoint(result.lastSequenceNumber(), result.lastSubSequenceNumber()); + checkpoints.merge(result.shardId(), incoming, ShardCheckpoint::max); + } + + return new 
PartitionedBatch(resultsByShard, checkpoints); } - private void checkInitializationResult(final InitializationResult initializationResult) { - switch (initializationResult) { - case InitializationResult.Success ignored -> { - final boolean wasInitialized = initialized.getAndSet(true); - if (!wasInitialized) { - getLogger().info( - "Started Kinesis Scheduler for stream [{}] with application name [{}] and workerId [{}]", - streamName, kinesisScheduler.applicationName(), kinesisScheduler.leaseManagementConfig().workerIdentifier()); + private List consumeRecords(final Set claimedShards) { + final List results = new ArrayList<>(); + final long startNanos = System.nanoTime(); + long estimatedBytes = 0; + + while (System.nanoTime() < startNanos + maxBatchNanos && estimatedBytes < maxBatchBytes) { + boolean foundAny = false; + + for (final String shardId : consumerClient.getShardIdsWithResults()) { + if (estimatedBytes >= maxBatchBytes) { + break; + } + if (!claimedShards.contains(shardId) && !consumerClient.claimShard(shardId)) { + continue; + } + claimedShards.add(shardId); + + ShardFetchResult result; + while ((result = consumerClient.pollShardResult(shardId)) != null) { + results.add(result); + estimatedBytes += estimateResultBytes(result); + foundAny = true; + if (estimatedBytes >= maxBatchBytes) { + break; + } } } - case InitializationResult.Failure failure -> { - cleanUpState(); - final ProcessException ex = failure.error() - .map(err -> new ProcessException("Initialization failed for stream [%s]".formatted(streamName), err)) - // This branch is active only when a scheduler was shutdown, but no initialization error was provided. - // This behavior isn't typical and wasn't observed. 
- .orElseGet(() -> new ProcessException("Initialization failed for stream [%s]".formatted(streamName))); + if (!foundAny) { + if (!consumerClient.hasPendingFetches()) { + break; + } - throw ex; + try { + consumerClient.awaitResults(QUEUE_POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + break; + } } } + + return results; + } + + private void handleWriteFailure(final Exception cause, final List accepted, + final Set claimedShards, final ProcessContext context) { + getLogger().error("Failed to write consumed Kinesis records", cause); + consumerClient.rollbackResults(accepted); + consumerClient.releaseShards(claimedShards); + context.yield(); } - private void processRecordsFromBuffer(final ProcessSession session, final Lease lease) { + private WriteResult writeResults(final ProcessSession session, final ProcessContext context, + final Map> resultsByShard) { + final List produced = new ArrayList<>(); + final List parseFailures = new ArrayList<>(); + long totalRecordCount = 0; + long totalBytesConsumed = 0; + long maxMillisBehind = -1; + try { - final List records = recordBuffer.consumeRecords(lease); + if (processingStrategy == ProcessingStrategy.FLOW_FILE) { + final BatchAccumulator batch = new BatchAccumulator(); + for (final List shardResults : resultsByShard.values()) { + for (final ShardFetchResult result : shardResults) { + batch.updateMillisBehind(result.millisBehindLatest()); + for (final DeaggregatedRecord record : result.records()) { + batch.addBytes(record.data().length); + } + } + writeFlowFilePerRecord(session, shardResults, streamName, batch, produced); + } + totalRecordCount = batch.getRecordCount(); + totalBytesConsumed = batch.getBytesConsumed(); + maxMillisBehind = batch.getMaxMillisBehind(); + } else { + for (final Map.Entry> entry : resultsByShard.entrySet()) { + final BatchAccumulator batch = new BatchAccumulator(); + batch.setLastShardId(entry.getKey()); + for (final 
ShardFetchResult result : entry.getValue()) { + batch.updateMillisBehind(result.millisBehindLatest()); + batch.updateSequenceRange(result); + for (final DeaggregatedRecord record : result.records()) { + batch.addBytes(record.data().length); + batch.updateRecordRange(record); + } + } - if (records.isEmpty()) { - recordBuffer.returnBufferLease(lease); - return; - } + if (processingStrategy == ProcessingStrategy.LINE_DELIMITED || processingStrategy == ProcessingStrategy.DEMARCATOR) { + final byte[] delimiter; + if (processingStrategy == ProcessingStrategy.LINE_DELIMITED) { + delimiter = NEWLINE_DELIMITER; + } else { + final String demarcatorValue = context.getProperty(MESSAGE_DEMARCATOR).getValue(); + delimiter = demarcatorValue.getBytes(StandardCharsets.UTF_8); + } + writeDelimited(session, entry.getValue(), streamName, batch, delimiter, produced); + } else { + writeRecordOriented(session, context, entry.getValue(), streamName, batch, produced, parseFailures); + } - final String shardId = lease.shardId(); - switch (processingStrategy) { - case FLOW_FILE -> processRecordsAsRaw(session, shardId, records); - case RECORD -> processRecordsWithReader(session, shardId, records); - case DEMARCATOR -> processRecordsAsDemarcated(session, shardId, records); + totalRecordCount += batch.getRecordCount(); + totalBytesConsumed += batch.getBytesConsumed(); + maxMillisBehind = Math.max(maxMillisBehind, batch.getMaxMillisBehind()); + } } + } catch (final Exception e) { + session.remove(produced); + session.remove(parseFailures); + throw e; + } + + return new WriteResult(produced, parseFailures, totalRecordCount, totalBytesConsumed, maxMillisBehind); + } - session.adjustCounter("Records Processed", records.size(), false); + private void writeFlowFilePerRecord(final ProcessSession session, final List results, + final String streamName, final BatchAccumulator batch, final List output) { + for (final ShardFetchResult result : results) { + for (final DeaggregatedRecord record : 
result.records()) { + final byte[] recordBytes = record.data(); + FlowFile flowFile = session.create(); + try { + flowFile = session.write(flowFile, out -> out.write(recordBytes)); + + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put(ATTR_SHARD_ID, result.shardId()); + attributes.put(ATTR_FIRST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_LAST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_PARTITION_KEY, record.partitionKey()); + if (record.approximateArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(record.approximateArrivalTimestamp().toEpochMilli())); + } + attributes.put("record.count", "1"); + if (result.millisBehindLatest() >= 0) { + attributes.put(ATTR_MILLIS_BEHIND, String.valueOf(result.millisBehindLatest())); + } - session.commitAsync( - () -> commitRecords(lease), - __ -> rollbackRecords(lease) - ); - } catch (final RuntimeException e) { - rollbackRecords(lease); - throw e; + flowFile = session.putAllAttributes(flowFile, attributes); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, result.shardId())); + output.add(flowFile); + batch.incrementRecordCount(); + } catch (final Exception e) { + session.remove(flowFile); + throw e; + } + } } } - private void commitRecords(final Lease lease) { + private void writeDelimited(final ProcessSession session, final List results, + final String streamName, final BatchAccumulator batch, final byte[] delimiter, + final List output) { + FlowFile flowFile = session.create(); try { - recordBuffer.commitConsumedRecords(lease); - } finally { - recordBuffer.returnBufferLease(lease); + flowFile = session.write(flowFile, new OutputStreamCallback() { + @Override + public void process(final OutputStream out) throws 
IOException { + boolean first = true; + for (final ShardFetchResult result : results) { + for (final DeaggregatedRecord record : result.records()) { + if (!first) { + out.write(delimiter); + } + out.write(record.data()); + first = false; + batch.incrementRecordCount(); + } + } + } + }); + + flowFile = session.putAllAttributes(flowFile, createFlowFileAttributes(streamName, batch)); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; } } - private void rollbackRecords(final Lease lease) { + /** + * Writes Kinesis records as NiFi records using the configured Record Reader and Record Writer. + * + *

+ * <p>This method may appear unnecessarily complex, but it is intended to address specific requirements:</p>
+ *
+ * <ul>
+ *     <li>Keep records ordered in the same order they are received from Kinesis</li>
+ *     <li>Create as few FlowFiles as necessary, keeping many records together in larger FlowFiles for performance reasons.</li>
+ * </ul>
+ *
+ * <p>Alternative options have been considered, as well:</p>
+ *
+ * <ul>
+ *     <li>Read each Record one at a time with a separate RecordReader. If its schema is different than the previous
+ * record, create a new FlowFile. However, when the stream is filled with JSON and many fields are nullable, this
+ * can look like a different schema for each Record when inference is used, thus creating many tiny FlowFiles.</li>
+ *     <li>Read each Record one at a time with a separate RecordReader. Map the RecordSchema to the existing RecordWriter
+ * for that schema, if one exists, and write to that writer; if none exists, create a new one. This results in better
+ * grouping in many cases, but it results in the output being reordered, as we may write records 1, 2, 3 to writers
+ * A, B, A.</li>
+ *     <li>Create a single InputStream and RecordReader for the entire batch. Create a single Writer for the entire batch.
+ * This way, we infer a single schema for the entire batch that is appropriate for all records. This bundles all records
+ * in the batch into a single FlowFile, which is ideal. However, this approach fails when we are not inferring the schema
+ * and the records do not all have the same schema. In that case, we can fail when attempting to read the records or when
+ * we attempt to write the records due to schema incompatibility.</li>
+ * </ul>
+ *
+ * <p>
+ * Additionally, the existing RecordSchema API does not tell us whether or not a schema was inferred,
+ * so we cannot easily make a decision based on that knowledge. Therefore, we have taken an approach that
+ * attempts to process data using our preferred method, falling back as necessary to other options.
+ * </p>
+ *
+ * <p>
+ * The primary path ({@link #writeRecordBatch}) combines all records into a single InputStream
+ * via {@link KinesisRecordInputStream} and creates one RecordReader. This is optimal for formats
+ * like JSON where the schema is inferred from the data: a single InputStream lets the reader see
+ * all records and produce a unified schema for the writer.
+ * </p>
+ *
+ * <p>
+ * However, this approach fails when records carry incompatible embedded schemas (e.g. Avro
+ * containers with different field sets). The single reader sees only the first schema and cannot
+ * parse subsequent records that differ from it. When this happens, the method falls back to
+ * {@link #writeRecordBatchPerRecord}, which processes each record individually and splits output
+ * across multiple FlowFiles when schemas change.
+ * </p>
+ */ + private void writeRecordOriented(final ProcessSession session, final ProcessContext context, + final List results, final String streamName, + final BatchAccumulator batch, final List output, + final List parseFailureOutput) { + + final List allRecords = new ArrayList<>(); + for (final ShardFetchResult result : results) { + allRecords.addAll(result.records()); + } + try { - recordBuffer.rollbackConsumedRecords(lease); - } finally { - recordBuffer.returnBufferLease(lease); + final RecordReaderFactory readerFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); + final RecordSetWriterFactory writerFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + final OutputStrategy outputStrategy = context.getProperty(OUTPUT_STRATEGY).asAllowableValue(OutputStrategy.class); + writeRecordBatch(session, readerFactory, writerFactory, outputStrategy, + allRecords, streamName, batch, output); + } catch (final Exception e) { + getLogger().debug("Combined-stream record processing failed; falling back to per-record processing", e); + batch.resetRecordCount(); + final RecordBatchResult result = writeRecordBatchPerRecord(session, context, allRecords, streamName, batch); + output.addAll(result.output()); + parseFailureOutput.addAll(result.parseFailures()); } } - private void processRecordsAsRaw(final ProcessSession session, final String shardId, final List records) { - for (final KinesisClientRecord record : records) { - FlowFile flowFile = session.create(); - flowFile = session.putAllAttributes(flowFile, ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, record, record)); + private void writeRecordBatch(final ProcessSession session, final RecordReaderFactory readerFactory, + final RecordSetWriterFactory writerFactory, final OutputStrategy outputStrategy, + final List records, + final String streamName, final BatchAccumulator batch, final List output) { - flowFile = 
session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - channel.write(record.data()); + FlowFile flowFile = session.create(); + try { + flowFile = session.write(flowFile, new OutputStreamCallback() { + @Override + public void process(final OutputStream out) throws IOException { + try (final InputStream kinesisInput = new KinesisRecordInputStream(records); + final RecordReader reader = readerFactory.createRecordReader(Map.of(), kinesisInput, -1, getLogger())) { + + RecordSchema writeSchema = reader.getSchema(); + if (outputStrategy == OutputStrategy.INJECT_METADATA) { + final List fields = new ArrayList<>(writeSchema.getFields()); + fields.add(KinesisRecordMetadata.FIELD_METADATA); + writeSchema = new SimpleRecordSchema(fields); + } else if (outputStrategy == OutputStrategy.USE_WRAPPER) { + writeSchema = new SimpleRecordSchema(List.of( + KinesisRecordMetadata.FIELD_METADATA, + new RecordField(WRAPPER_VALUE_FIELD, RecordFieldType.RECORD.getRecordDataType(writeSchema)))); + } + + try (final RecordSetWriter writer = writerFactory.createWriter(getLogger(), writeSchema, out, Map.of())) { + writer.beginRecordSet(); + + int recordIndex = 0; + org.apache.nifi.serialization.record.Record nifiRecord; + while ((nifiRecord = reader.nextRecord()) != null) { + if (recordIndex < records.size()) { + final DeaggregatedRecord record = records.get(recordIndex); + nifiRecord = decorateRecord(nifiRecord, record, record.shardId(), streamName, outputStrategy, writeSchema); + recordIndex++; + } + + writer.write(nifiRecord); + batch.incrementRecordCount(); + } + + writer.finishRecordSet(); + } + } catch (final MalformedRecordException | SchemaNotFoundException e) { + throw new IOException(e); + } } }); - session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - session.transfer(flowFile, REL_SUCCESS); + flowFile = session.putAllAttributes(flowFile, 
createFlowFileAttributes(streamName, batch)); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; } } - private void processRecordsWithReader(final ProcessSession session, final String shardId, final List records) { - final ReaderRecordProcessor recordProcessor = readerRecordProcessor; - if (recordProcessor == null) { - throw new IllegalStateException("RecordProcessor has not been initialized"); - } + /** + * Fallback path that processes each Kinesis record individually, splitting output across multiple + * FlowFiles when the record schema changes between consecutive records. + * + *

+ * <p>This is invoked when the combined-stream approach ({@link #writeRecordBatch}) fails, which
+ * typically happens when the batch contains records with incompatible embedded schemas (e.g. Avro
+ * containers whose field sets differ). Rather than grouping or buffering records up front, this
+ * method makes a single pass: for each record it creates a RecordReader, compares the schema to
+ * the current writer's schema, and either continues writing to the same FlowFile or finalizes the
+ * current FlowFile and starts a new one. This preserves record ordering without demultiplexing.</p>
+ *
+ * <p>Records that cannot be parsed (empty data, malformed content, missing schema) are collected
+ * and routed to the parse-failure relationship at the end.</p>
+ * + * @param session the current process session + * @param context the current process context (used to resolve Record Reader, Record Writer, and Output Strategy) + * @param records the Kinesis records to process, in order + * @param streamName the Kinesis stream name, used for FlowFile attributes + * @param batch accumulator for batch-level attributes and record counting + * @return a {@link RecordBatchResult} containing the successfully written FlowFiles and any parse-failure FlowFiles + */ + private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session, final ProcessContext context, + final List records, + final String streamName, final BatchAccumulator batch) { - final ProcessingResult result = recordProcessor.processRecords(session, streamName, shardId, records); + final RecordReaderFactory readerFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); + final RecordSetWriterFactory writerFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + final OutputStrategy outputStrategy = context.getProperty(OUTPUT_STRATEGY).asAllowableValue(OutputStrategy.class); - session.transfer(result.successFlowFiles(), REL_SUCCESS); - session.transfer(result.parseFailureFlowFiles(), REL_PARSE_FAILURE); - } + final List output = new ArrayList<>(); + final List parseFailureOutput = new ArrayList<>(); + final List unparseable = new ArrayList<>(); + FlowFile currentFlowFile = null; + OutputStream currentOut = null; + RecordSetWriter currentWriter = null; + RecordSchema currentReadSchema = null; + RecordSchema currentWriteSchema = null; + int currentRecordCount = 0; - private void processRecordsAsDemarcated(final ProcessSession session, final String shardId, final List records) { - final byte[] demarcator = demarcatorValue; - if (demarcator == null) { - throw new IllegalStateException("Demarcator has not been initialized"); - } + try { + for (final DeaggregatedRecord record : records) 
{ - FlowFile flowFile = session.create(); + if (record.data().length == 0) { + unparseable.add(record); + continue; + } - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, records.getFirst(), records.getLast()); - attributes.put(RECORD_COUNT, String.valueOf(records.size())); - flowFile = session.putAllAttributes(flowFile, attributes); + RecordSchema readSchema = null; + final List parsedRecords = new ArrayList<>(); + RecordReader reader = null; + try { + reader = readerFactory.createRecordReader(Map.of(), new ByteArrayInputStream(record.data()), record.data().length, getLogger()); + readSchema = reader.getSchema(); + org.apache.nifi.serialization.record.Record nifiRecord; + while ((nifiRecord = reader.nextRecord()) != null) { + parsedRecords.add(nifiRecord); + } + } catch (final MalformedRecordException | SchemaNotFoundException | IOException e) { + getLogger().debug("Kinesis record seq {} classified as unparseable: {}", + record.sequenceNumber(), e.getMessage()); + unparseable.add(record); + continue; + } finally { + closeQuietly(reader); + } - flowFile = session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - boolean writtenData = false; - for (final KinesisClientRecord record : records) { - if (writtenData) { - out.write(demarcator); + if (parsedRecords.isEmpty()) { + unparseable.add(record); + continue; + } + + if (currentWriter == null || !readSchema.equals(currentReadSchema)) { + if (currentWriter != null) { + currentWriter.finishRecordSet(); + currentWriter.close(); + currentOut.close(); + final Map attrs = createFlowFileAttributes(streamName, batch); + attrs.put("record.count", String.valueOf(currentRecordCount)); + currentFlowFile = session.putAllAttributes(currentFlowFile, attrs); + session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(currentFlowFile); + currentFlowFile = null; } - 
channel.write(record.data()); - writtenData = true; + + currentReadSchema = readSchema; + currentWriteSchema = buildWriteSchema(readSchema, outputStrategy); + currentFlowFile = session.create(); + currentOut = session.write(currentFlowFile); + currentWriter = writerFactory.createWriter(getLogger(), currentWriteSchema, currentOut, Map.of()); + currentWriter.beginRecordSet(); + currentRecordCount = 0; } + + for (final org.apache.nifi.serialization.record.Record parsed : parsedRecords) { + // TODO: We need to use decorateRecord above also. + final org.apache.nifi.serialization.record.Record decorated = + decorateRecord(parsed, record, record.shardId(), streamName, outputStrategy, currentWriteSchema); + currentWriter.write(decorated); + currentRecordCount++; + batch.incrementRecordCount(); + } + } + + if (currentWriter != null) { + currentWriter.finishRecordSet(); + currentWriter.close(); + currentOut.close(); + final Map attrs = createFlowFileAttributes(streamName, batch); + attrs.put("record.count", String.valueOf(currentRecordCount)); + currentFlowFile = session.putAllAttributes(currentFlowFile, attrs); + session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(currentFlowFile); + currentFlowFile = null; + } + } catch (final Exception e) { + closeQuietly(currentWriter); + closeQuietly(currentOut); + if (currentFlowFile != null) { + session.remove(currentFlowFile); + } + if (e instanceof RuntimeException re) { + throw re; } - }); + throw new ProcessException(e); + } - session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); + if (!unparseable.isEmpty()) { + getLogger().warn("Encountered {} unparseable record(s) in shard {}; routing to parse failure", + unparseable.size(), batch.getLastShardId()); + writeParseFailures(session, unparseable, streamName, batch, parseFailureOutput); + } - session.transfer(flowFile, REL_SUCCESS); + return new 
RecordBatchResult(output, parseFailureOutput); } /** - * An adapter between Kinesis Consumer Library and {@link RecordBuffer}. + * Adjusts a read schema to the write schema required by the configured OutputStrategy. For + * {@code INJECT_METADATA} the metadata field is appended; for {@code USE_WRAPPER} a two-field + * wrapper schema is created; for {@code USE_VALUE} the read schema is returned unchanged. */ - private static class ConsumeKinesisRecordProcessor implements ShardRecordProcessor { + private static RecordSchema buildWriteSchema(final RecordSchema readSchema, final OutputStrategy outputStrategy) { + return switch (outputStrategy) { + case INJECT_METADATA -> { + final List fields = new ArrayList<>(readSchema.getFields()); + fields.add(KinesisRecordMetadata.FIELD_METADATA); + yield new SimpleRecordSchema(fields); + } + case USE_WRAPPER -> { + yield new SimpleRecordSchema(List.of( + KinesisRecordMetadata.FIELD_METADATA, + new RecordField(WRAPPER_VALUE_FIELD, RecordFieldType.RECORD.getRecordDataType(readSchema)))); + } + case USE_VALUE -> readSchema; + }; + } + + /** + * Attaches Kinesis metadata to a NiFi record according to the configured OutputStrategy. 
+ */ + private static org.apache.nifi.serialization.record.Record decorateRecord( + final org.apache.nifi.serialization.record.Record nifiRecord, + final DeaggregatedRecord kinesisRecord, final String shardId, + final String streamName, final OutputStrategy outputStrategy, + final RecordSchema writeSchema) { + return switch (outputStrategy) { + case INJECT_METADATA -> { + final Map values = new HashMap<>(nifiRecord.toMap()); + values.put(KinesisRecordMetadata.METADATA, + KinesisRecordMetadata.composeMetadataObject(kinesisRecord, streamName, shardId)); + yield new MapRecord(writeSchema, values); + } + case USE_WRAPPER -> { + final Map wrapperValues = new HashMap<>(2, 1.0f); + wrapperValues.put(KinesisRecordMetadata.METADATA, + KinesisRecordMetadata.composeMetadataObject(kinesisRecord, streamName, shardId)); + wrapperValues.put(WRAPPER_VALUE_FIELD, nifiRecord); + yield new MapRecord(writeSchema, wrapperValues); + } + case USE_VALUE -> nifiRecord; + }; + } - private final RecordBuffer.ForKinesisClientLibrary recordBuffer; - private volatile @Nullable ShardBufferId bufferId; + private void writeParseFailures(final ProcessSession session, final List unparseable, + final String streamName, final BatchAccumulator batch, final List parseFailureOutput) { - ConsumeKinesisRecordProcessor(final MemoryBoundRecordBuffer recordBuffer) { - this.recordBuffer = recordBuffer; + for (final DeaggregatedRecord record : unparseable) { + FlowFile flowFile = session.create(); + try { + final byte[] rawBytes = record.data(); + flowFile = session.write(flowFile, out -> out.write(rawBytes)); + + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put(ATTR_FIRST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_LAST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + 
attributes.put(ATTR_PARTITION_KEY, record.partitionKey()); + if (record.approximateArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(record.approximateArrivalTimestamp().toEpochMilli())); + } + attributes.put("record.count", "1"); + if (batch.getLastShardId() != null) { + attributes.put(ATTR_SHARD_ID, batch.getLastShardId()); + } + flowFile = session.putAllAttributes(flowFile, attributes); + parseFailureOutput.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; + } } + } - @Override - public void initialize(final InitializationInput initializationInput) { - bufferId = recordBuffer.createBuffer(initializationInput.shardId()); + private static long estimateResultBytes(final ShardFetchResult result) { + long bytes = 0; + for (final DeaggregatedRecord record : result.records()) { + bytes += record.data().length; + } + return bytes; + } + + private static Map createFlowFileAttributes(final String streamName, final BatchAccumulator batch) { + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put("record.count", String.valueOf(batch.getRecordCount())); + + if (batch.getMaxMillisBehind() >= 0) { + attributes.put(ATTR_MILLIS_BEHIND, String.valueOf(batch.getMaxMillisBehind())); + } + if (batch.getLastShardId() != null) { + attributes.put(ATTR_SHARD_ID, batch.getLastShardId()); + } + if (batch.getMinSequenceNumber() != null) { + attributes.put(ATTR_FIRST_SEQUENCE, batch.getMinSequenceNumber()); + } + if (batch.getMaxSequenceNumber() != null) { + attributes.put(ATTR_LAST_SEQUENCE, batch.getMaxSequenceNumber()); + } + if (batch.getMinSubSequenceNumber() != Long.MAX_VALUE) { + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(batch.getMinSubSequenceNumber())); + } + if (batch.getMaxSubSequenceNumber() != Long.MIN_VALUE) { + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(batch.getMaxSubSequenceNumber())); + } + if (batch.getLastPartitionKey() != null) 
{ + attributes.put(ATTR_PARTITION_KEY, batch.getLastPartitionKey()); + } + if (batch.getEarliestArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(batch.getEarliestArrivalTimestamp().toEpochMilli())); + } + + return attributes; + } + + private static void closeQuietly(final AutoCloseable closeable) { + if (closeable != null) { + try { + closeable.close(); + } catch (final Exception ignored) { + } + } + } + + private static String buildTransitUri(final String streamName, final String shardId) { + if (shardId == null || shardId.isEmpty()) { + return "kinesis://" + streamName; + } + return "kinesis://" + streamName + "/" + shardId; + } + + // Exposed for testing to allow injection of mock Shard Manager + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return new KinesisShardManager(kinesisClient, dynamoDbClient, logger, checkpointTableName, streamName); + } + + // Exposed for testing to allow injection of a mock client + protected KinesisConsumerClient createConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger, final boolean efoMode) { + if (efoMode) { + return new EfoKinesisClient(kinesisClient, logger); + } + return new PollingKinesisClient(kinesisClient, logger); + } + + private record RecordBatchResult(List output, List parseFailures) { + } + + private static final class KinesisRecordInputStream extends InputStream { + private final List chunks; + private int chunkIndex; + private int positionInChunk; + private int markChunkIndex = -1; + private int markPositionInChunk; + + KinesisRecordInputStream(final List records) { + this.chunks = new ArrayList<>(records.size()); + for (final DeaggregatedRecord record : records) { + final byte[] data = record.data(); + if (data.length > 0) { + chunks.add(data); + } + } } @Override - public void 
processRecords(final ProcessRecordsInput processRecordsInput) { - if (bufferId == null) { - throw new IllegalStateException("Buffer ID not found: Record Processor not initialized"); + public int read() { + while (chunkIndex < chunks.size()) { + final byte[] current = chunks.get(chunkIndex); + if (positionInChunk < current.length) { + return current[positionInChunk++] & 0xFF; + } + chunkIndex++; + positionInChunk = 0; } - recordBuffer.addRecords(bufferId, processRecordsInput.records(), processRecordsInput.checkpointer()); + return -1; } @Override - public void leaseLost(final LeaseLostInput leaseLostInput) { - if (bufferId != null) { - recordBuffer.consumerLeaseLost(bufferId); + public int read(final byte[] buffer, final int offset, final int length) { + if (chunkIndex >= chunks.size()) { + return -1; + } + if (length == 0) { + return 0; + } + + int totalRead = 0; + while (totalRead < length && chunkIndex < chunks.size()) { + final byte[] current = chunks.get(chunkIndex); + final int remaining = current.length - positionInChunk; + if (remaining <= 0) { + chunkIndex++; + positionInChunk = 0; + continue; + } + + final int toRead = Math.min(length - totalRead, remaining); + System.arraycopy(current, positionInChunk, buffer, offset + totalRead, toRead); + positionInChunk += toRead; + totalRead += toRead; } + + return totalRead == 0 ? 
-1 : totalRead; } @Override - public void shardEnded(final ShardEndedInput shardEndedInput) { - if (bufferId != null) { - recordBuffer.checkpointEndedShard(bufferId, shardEndedInput.checkpointer()); + public int available() { + if (chunkIndex >= chunks.size()) { + return 0; } + return chunks.get(chunkIndex).length - positionInChunk; + } + + @Override + public boolean markSupported() { + return true; } @Override - public void shutdownRequested(final ShutdownRequestedInput shutdownRequestedInput) { - if (bufferId != null) { - recordBuffer.shutdownShardConsumption(bufferId, shutdownRequestedInput.checkpointer()); + public void mark(final int readLimit) { + markChunkIndex = chunkIndex; + markPositionInChunk = positionInChunk; + } + + @Override + public void reset() throws IOException { + if (markChunkIndex < 0) { + throw new IOException("Stream not marked"); } + chunkIndex = markChunkIndex; + positionInChunk = markPositionInChunk; } } - private static final class InitializationStateChangeListener implements WorkerStateChangeListener { + private record PartitionedBatch(Map> resultsByShard, Map checkpoints) { + } - private final ComponentLog logger; + private record WriteResult(List produced, List parseFailures, + long totalRecordCount, long totalBytesConsumed, long maxMillisBehind) { + } - private final CompletableFuture resultFuture = new CompletableFuture<>(); + private static final class BatchAccumulator { + private long bytesConsumed; + private long recordCount; + private long maxMillisBehind = -1; + private String minSequenceNumber; + private String maxSequenceNumber; + private long minSubSequenceNumber = Long.MAX_VALUE; + private long maxSubSequenceNumber = Long.MIN_VALUE; + private String lastPartitionKey; + private Instant earliestArrivalTimestamp; + private String lastShardId; + + long getBytesConsumed() { + return bytesConsumed; + } - private volatile @Nullable Throwable initializationFailure; + long getRecordCount() { + return recordCount; + } - 
InitializationStateChangeListener(final ComponentLog logger) { - this.logger = logger; + long getMaxMillisBehind() { + return maxMillisBehind; } - @Override - public void onWorkerStateChange(final WorkerState newState) { - logger.info("Worker state changed to [{}]", newState); + String getMinSequenceNumber() { + return minSequenceNumber; + } - if (newState == WorkerState.STARTED) { - resultFuture.complete(new InitializationResult.Success()); - } else if (newState == WorkerState.SHUT_DOWN) { - resultFuture.complete(new InitializationResult.Failure(Optional.ofNullable(initializationFailure))); - } + String getMaxSequenceNumber() { + return maxSequenceNumber; } - @Override - public void onAllInitializationAttemptsFailed(final Throwable e) { - // This method is called before the SHUT_DOWN_STARTED phase. - // Memorizing the error until the Scheduler is SHUT_DOWN. - initializationFailure = e; + long getMinSubSequenceNumber() { + return minSubSequenceNumber; } - Future result() { - return resultFuture; + long getMaxSubSequenceNumber() { + return maxSubSequenceNumber; + } + + String getLastPartitionKey() { + return lastPartitionKey; + } + + Instant getEarliestArrivalTimestamp() { + return earliestArrivalTimestamp; + } + + String getLastShardId() { + return lastShardId; + } + + void setLastShardId(final String shardId) { + lastShardId = shardId; } - } - private sealed interface InitializationResult { - record Success() implements InitializationResult { + void addBytes(final long bytes) { + bytesConsumed += bytes; } - record Failure(Optional error) implements InitializationResult { + void incrementRecordCount() { + recordCount++; + } + + void resetRecordCount() { + recordCount = 0; + } + + void updateMillisBehind(final long millisBehindLatest) { + maxMillisBehind = Math.max(maxMillisBehind, millisBehindLatest); + } + + void updateSequenceRange(final ShardFetchResult result) { + final String firstSeq = result.firstSequenceNumber(); + final String lastSeq = 
result.lastSequenceNumber(); + if (minSequenceNumber == null || new BigInteger(firstSeq).compareTo(new BigInteger(minSequenceNumber)) < 0) { + minSequenceNumber = firstSeq; + } + if (maxSequenceNumber == null || new BigInteger(lastSeq).compareTo(new BigInteger(maxSequenceNumber)) > 0) { + maxSequenceNumber = lastSeq; + } + } + + void updateRecordRange(final DeaggregatedRecord record) { + final long subSeq = record.subSequenceNumber(); + if (subSeq < minSubSequenceNumber) { + minSubSequenceNumber = subSeq; + } + if (subSeq > maxSubSequenceNumber) { + maxSubSequenceNumber = subSeq; + } + lastPartitionKey = record.partitionKey(); + final Instant arrival = record.approximateArrivalTimestamp(); + if (arrival != null && (earliestArrivalTimestamp == null || arrival.isBefore(earliestArrivalTimestamp))) { + earliestArrivalTimestamp = arrival; + } } } @@ -919,6 +1464,7 @@ public String getDescription() { enum ProcessingStrategy implements DescribedValue { FLOW_FILE("Write one FlowFile for each consumed Kinesis Record"), + LINE_DELIMITED("Write one FlowFile containing multiple consumed Kinesis Records separated by line delimiters"), RECORD("Write one FlowFile containing multiple consumed Kinesis Records processed with Record Reader and Record Writer"), DEMARCATOR("Write one FlowFile containing multiple consumed Kinesis Records separated by a configurable demarcator"); @@ -943,15 +1489,11 @@ public String getDescription() { return description; } - private static ProcessingStrategy from(final String name) { - // As long as getValue() returns name(), using valueOf is fine. 
- return ProcessingStrategy.valueOf(name); - } } enum InitialPosition implements DescribedValue { - TRIM_HORIZON("Trim Horizon", "Start reading at the last untrimmed record in the shard in the system, which is the oldest data record in the shard."), - LATEST("Latest", "Start reading just after the most recent record in the shard, so that you always read the most recent data in the shard."), + TRIM_HORIZON("Trim Horizon", "Start reading at the last untrimmed record in the shard."), + LATEST("Latest", "Start reading just after the most recent record in the shard."), AT_TIMESTAMP("At Timestamp", "Start reading at the record with the specified timestamp."); private final String displayName; @@ -978,39 +1520,10 @@ public String getDescription() { } } - enum MetricsPublishing implements DescribedValue { - DISABLED("Disabled", "No metrics are published"), - LOGS("Logs", "Metrics are published to application logs"), - CLOUDWATCH("CloudWatch", "Metrics are published to Amazon CloudWatch"); - - private final String displayName; - private final String description; - - MetricsPublishing(final String displayName, final String description) { - this.displayName = displayName; - this.description = description; - } - - @Override - public String getValue() { - return name(); - } - - @Override - public String getDisplayName() { - return displayName; - } - - @Override - public String getDescription() { - return description; - } - } - enum OutputStrategy implements DescribedValue { USE_VALUE("Use Content as Value", "Write only the Kinesis Record value to the FlowFile record."), USE_WRAPPER("Use Wrapper", "Write the Kinesis Record value and metadata into the FlowFile record."), - INJECT_METADATA("Inject Metadata", "Write the Kinesis Record value to the FlowFile record and add a sub-record to it with metadata."); + INJECT_METADATA("Inject Metadata", "Write the Kinesis Record value to the FlowFile record and add a sub-record with metadata."); private final String displayName; private 
final String description; @@ -1035,19 +1548,4 @@ public String getDescription() { return description; } } - - // Visible for tests only. - @Nullable URI getKinesisEndpointOverride() { - return null; - } - - // Visible for tests only. - @Nullable URI getDynamoDbEndpointOverride() { - return null; - } - - // Visible for tests only. - @Nullable URI getCloudwatchEndpointOverride() { - return null; - } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java deleted file mode 100644 index acfc91b74e29..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.HashMap; -import java.util.Map; - -final class ConsumeKinesisAttributes { - - private static final String PREFIX = "aws.kinesis."; - - // AWS Kinesis attributes. - static final String STREAM_NAME = PREFIX + "stream.name"; - static final String SHARD_ID = PREFIX + "shard.id"; - static final String FIRST_SEQUENCE_NUMBER = PREFIX + "first.sequence.number"; - static final String FIRST_SUB_SEQUENCE_NUMBER = PREFIX + "first.subsequence.number"; - static final String LAST_SEQUENCE_NUMBER = PREFIX + "last.sequence.number"; - static final String LAST_SUB_SEQUENCE_NUMBER = PREFIX + "last.subsequence.number"; - - static final String PARTITION_KEY = PREFIX + "partition.key"; - static final String APPROXIMATE_ARRIVAL_TIMESTAMP = PREFIX + "approximate.arrival.timestamp.ms"; - - // Record attributes. - static final String MIME_TYPE = "mime.type"; - static final String RECORD_COUNT = "record.count"; - static final String RECORD_ERROR_MESSAGE = "record.error.message"; - - /** - * Creates a map of FlowFile attributes from the provided Kinesis records. - * - * @param streamName the name of the Kinesis stream the FileFile records came from. - * @param shardId the shard ID the FlowFile records came from. - * @param firstRecord the first Kinesis record in the FlowFile. - * @param lastRecord the last Kinesis record in the FlowFile. - * @return a mutable map with kinesis attributes. 
- */ - static Map fromKinesisRecords( - final String streamName, - final String shardId, - final KinesisClientRecord firstRecord, - final KinesisClientRecord lastRecord) { - final Map attributes = new HashMap<>(8); - - attributes.put(STREAM_NAME, streamName); - attributes.put(SHARD_ID, shardId); - - attributes.put(FIRST_SEQUENCE_NUMBER, firstRecord.sequenceNumber()); - attributes.put(FIRST_SUB_SEQUENCE_NUMBER, String.valueOf(firstRecord.subSequenceNumber())); - - attributes.put(LAST_SEQUENCE_NUMBER, lastRecord.sequenceNumber()); - attributes.put(LAST_SUB_SEQUENCE_NUMBER, String.valueOf(lastRecord.subSequenceNumber())); - - attributes.put(PARTITION_KEY, lastRecord.partitionKey()); - - if (lastRecord.approximateArrivalTimestamp() != null) { - attributes.put(APPROXIMATE_ARRIVAL_TIMESTAMP, String.valueOf(lastRecord.approximateArrivalTimestamp().toEpochMilli())); - } - - return attributes; - } - - private ConsumeKinesisAttributes() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java new file mode 100644 index 000000000000..95290725881e --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import java.time.Instant; + +/** + * A single user record extracted from a Kinesis record. For non-aggregated records, + * the fields map directly from the Kinesis API {@code Record}. For KPL-aggregated + * records, each sub-record within the aggregate gets its own instance with a unique + * {@code subSequenceNumber}. + * + * @param shardId the shard from which this record was fetched + * @param sequenceNumber the Kinesis sequence number of the enclosing record + * @param subSequenceNumber zero for non-aggregated records; index within the aggregate for KPL records + * @param partitionKey the partition key (from the enclosing record or the KPL sub-record) + * @param data the user payload bytes + * @param approximateArrivalTimestamp approximate time the enclosing record arrived at Kinesis + */ +record DeaggregatedRecord( + String shardId, + String sequenceNumber, + long subSequenceNumber, + String partitionKey, + byte[] data, + Instant approximateArrivalTimestamp) { +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java new file mode 100644 index 000000000000..32b6132662e4 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java @@ -0,0 +1,596 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ConsumerStatus; +import software.amazon.awssdk.services.kinesis.model.DeregisterStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.LimitExceededException; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.RegisterStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.RegisterStreamConsumerResponse; +import 
software.amazon.awssdk.services.kinesis.model.ResourceInUseException; +import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; + +import java.io.IOException; +import java.math.BigInteger; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +/** + * Enhanced Fan-Out Kinesis consumer that uses SubscribeToShard with dedicated throughput + * per shard via HTTP/2. Uses Reactive Streams demand-driven backpressure to control the + * rate of event delivery. 
+ */ +final class EfoKinesisClient extends KinesisConsumerClient { + + private static final long SUBSCRIBE_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(5); + private static final long CONSUMER_REGISTRATION_POLL_MILLIS = 1_000; + private static final int CONSUMER_REGISTRATION_MAX_ATTEMPTS = 60; + private static final int MAX_SUBSCRIPTIONS_PER_TRIGGER = 10; + private static final int MAX_QUEUED_RESULTS = 200; + + private final Map shardConsumers = new ConcurrentHashMap<>(); + private volatile KinesisAsyncClient kinesisAsyncClient; + private volatile String consumerArn; + private volatile Instant timestampForInitialPosition; + + EfoKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger) { + super(kinesisClient, logger); + } + + void setTimestampForInitialPosition(final Instant timestamp) { + this.timestampForInitialPosition = timestamp; + } + + @Override + void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { + this.kinesisAsyncClient = asyncClient; + registerEfoConsumer(streamName, consumerName); + } + + @Override + void startFetches(final List shards, final String streamName, final int batchSize, final String initialStreamPosition, final KinesisShardManager shardManager) { + if (totalQueuedResults() >= MAX_QUEUED_RESULTS) { + return; + } + + int subscriptionsCreated = 0; + final long now = System.nanoTime(); + + for (final Shard shard : shards) { + if (subscriptionsCreated >= MAX_SUBSCRIPTIONS_PER_TRIGGER) { + break; + } + + final String shardId = shard.shardId(); + final ShardConsumer existing = shardConsumers.get(shardId); + + if (existing == null) { + final String lastSeq = shardManager.readCheckpoint(shardId); + final StartingPosition startingPosition = buildStartingPosition(lastSeq, initialStreamPosition); + logger.info("Creating EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), lastSeq); + final ShardConsumer sc = new ShardConsumer(shardId, 
EfoKinesisClient.this::enqueueResult, logger); + final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, sc); + if (prior == null) { + try { + sc.subscribe(kinesisAsyncClient, consumerArn, startingPosition); + } catch (final Exception e) { + shardConsumers.remove(shardId, sc); + throw e; + } + subscriptionsCreated++; + } + } else if (existing.isSubscriptionExpired()) { + final long lastAttempt = existing.getLastSubscribeAttemptNanos(); + if (lastAttempt > 0 && now < lastAttempt + SUBSCRIBE_BACKOFF_NANOS) { + continue; + } + + final String resumeSeq = maxSequenceNumber( + existing.getLastQueuedSequenceNumber(), + existing.getLastAcknowledgedSequenceNumber(), + shardManager.readCheckpoint(shardId)); + final StartingPosition startingPosition = buildStartingPosition(resumeSeq, initialStreamPosition); + logger.debug("Renewing expired EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), resumeSeq); + existing.subscribe(kinesisAsyncClient, consumerArn, startingPosition); + subscriptionsCreated++; + } + } + } + + @Override + boolean hasPendingFetches() { + return !shardConsumers.isEmpty(); + } + + @Override + long drainDeduplicatedEventCount() { + long total = 0; + for (final ShardConsumer sc : shardConsumers.values()) { + total += sc.drainDeduplicatedEventCount(); + } + return total; + } + + @Override + void acknowledgeResults(final List results) { + final Map consumersToRequest = new HashMap<>(); + + for (final ShardFetchResult result : results) { + final ShardConsumer consumer = shardConsumers.get(result.shardId()); + if (consumer != null) { + consumer.acknowledgeSequenceNumber(result.lastSequenceNumber()); + consumer.markAcknowledged(); + consumersToRequest.putIfAbsent(result.shardId(), consumer); + } + } + + for (final ShardConsumer consumer : consumersToRequest.values()) { + consumer.requestNext(); + } + } + + @Override + void rollbackResults(final List results) { + for (final ShardFetchResult result : results) { + final 
ShardConsumer sc = shardConsumers.remove(result.shardId()); + if (sc != null) { + sc.cancel(); + } + } + } + + @Override + void removeUnownedShards(final Set ownedShards) { + shardConsumers.entrySet().removeIf(entry -> { + if (!ownedShards.contains(entry.getKey())) { + entry.getValue().cancel(); + return true; + } + return false; + }); + } + + @Override + void logDiagnostics(final int ownedCount, final int cachedShardCount) { + if (!shouldLogDiagnostics()) { + return; + } + + final long now = System.nanoTime(); + int activeSubscriptions = 0; + int expiredSubscriptions = 0; + int backedOff = 0; + for (final ShardConsumer sc : shardConsumers.values()) { + if (sc.isSubscriptionExpired()) { + expiredSubscriptions++; + final long lastAttempt = sc.getLastSubscribeAttemptNanos(); + if (lastAttempt > 0 && now < lastAttempt + SUBSCRIBE_BACKOFF_NANOS) { + backedOff++; + } + } else { + activeSubscriptions++; + } + } + + final int queueDepth = totalQueuedResults(); + logger.debug("Kinesis EFO diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, activeSubscriptions={}, expiredSubscriptions={}, backedOff={}", + cachedShardCount, ownedCount, queueDepth, MAX_QUEUED_RESULTS, shardConsumers.size(), activeSubscriptions, expiredSubscriptions, backedOff); + } + + @Override + void close() { + for (final ShardConsumer sc : shardConsumers.values()) { + sc.cancel(); + } + shardConsumers.clear(); + + deregisterEfoConsumer(); + + if (kinesisAsyncClient != null) { + kinesisAsyncClient.close(); + kinesisAsyncClient = null; + } + + super.close(); + } + + private void deregisterEfoConsumer() { + final String arn = consumerArn; + consumerArn = null; + if (arn == null) { + return; + } + + try { + kinesisClient.deregisterStreamConsumer(DeregisterStreamConsumerRequest.builder() + .consumerARN(arn) + .build()); + logger.info("Deregistered EFO consumer [{}]", arn); + } catch (final Exception e) { + logger.warn("Failed to deregister EFO consumer [{}]; manual cleanup 
may be required", arn, e); + } + } + + private void registerEfoConsumer(final String streamName, final String consumerName) { + final DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder().streamName(streamName).build(); + final DescribeStreamResponse describeResponse = kinesisClient.describeStream(describeStreamRequest); + final String arn = describeResponse.streamDescription().streamARN(); + + try { + final DescribeStreamConsumerRequest describeConsumerReq = DescribeStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerStatus(); + if (status == ConsumerStatus.ACTIVE) { + consumerArn = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerARN(); + logger.info("EFO consumer [{}] already registered and ACTIVE", consumerName); + return; + } + } catch (final ResourceNotFoundException ignored) { + } + + try { + final RegisterStreamConsumerRequest registerRequest = RegisterStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + final RegisterStreamConsumerResponse registerResponse = kinesisClient.registerStreamConsumer(registerRequest); + consumerArn = registerResponse.consumer().consumerARN(); + logger.info("Registered EFO consumer [{}], waiting for ACTIVE status", consumerName); + } catch (final ResourceInUseException e) { + final DescribeStreamConsumerRequest fallbackRequest = DescribeStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + consumerArn = kinesisClient.describeStreamConsumer(fallbackRequest).consumerDescription().consumerARN(); + logger.info("EFO consumer [{}] already being registered", consumerName); + } + + waitForConsumerActive(arn, consumerName); + } + + private void waitForConsumerActive(final String theStreamArn, final String consumerName) { + final 
DescribeStreamConsumerRequest describeConsumerRequest = DescribeStreamConsumerRequest.builder() + .streamARN(theStreamArn) + .consumerName(consumerName) + .build(); + + for (int i = 0; i < CONSUMER_REGISTRATION_MAX_ATTEMPTS; i++) { + final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerRequest).consumerDescription().consumerStatus(); + if (status == ConsumerStatus.ACTIVE) { + logger.info("EFO consumer [{}] is now ACTIVE", consumerName); + return; + } + + try { + Thread.sleep(CONSUMER_REGISTRATION_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for EFO consumer registration", e); + } + } + + throw new ProcessException("EFO consumer [%s] did not become ACTIVE within %d seconds".formatted(consumerName, CONSUMER_REGISTRATION_MAX_ATTEMPTS)); + } + + private static String maxSequenceNumber(final String... candidates) { + BigInteger max = null; + String maxStr = null; + for (final String candidate : candidates) { + if (candidate != null) { + final BigInteger value = new BigInteger(candidate); + if (max == null || value.compareTo(max) > 0) { + max = value; + maxStr = candidate; + } + } + } + return maxStr; + } + + private StartingPosition buildStartingPosition(final String sequenceNumber, final String initialStreamPosition) { + if (sequenceNumber != null) { + return StartingPosition.builder() + .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER) + .sequenceNumber(sequenceNumber) + .build(); + } + final ShardIteratorType iteratorType = ShardIteratorType.fromValue(initialStreamPosition); + final StartingPosition.Builder builder = StartingPosition.builder().type(iteratorType); + if (iteratorType == ShardIteratorType.AT_TIMESTAMP && timestampForInitialPosition != null) { + builder.timestamp(timestampForInitialPosition); + } + return builder.build(); + } + + static final class ShardConsumer { + private final String shardId; + private final Consumer 
resultSink; + private final ComponentLog consumerLogger; + private final AtomicBoolean subscribing = new AtomicBoolean(false); + private final AtomicInteger subscriptionGeneration = new AtomicInteger(); + private final AtomicLong deduplicatedEvents = new AtomicLong(); + private volatile Subscription subscription; + private volatile CompletableFuture subscriptionFuture; + private volatile long lastSubscribeAttemptNanos; + private volatile String lastQueuedSequenceNumber; + private volatile String lastOnNextMaxSequence; + private final AtomicReference lastAcknowledgedSequenceNumber = new AtomicReference<>(); + private final AtomicInteger queuedResultCount = new AtomicInteger(); + + ShardConsumer(final String shardId, final Consumer resultSink, final ComponentLog consumerLogger) { + this.shardId = shardId; + this.resultSink = resultSink; + this.consumerLogger = consumerLogger; + } + + void subscribe(final KinesisAsyncClient asyncClient, final String theConsumerArn, final StartingPosition startingPosition) { + if (!subscribing.compareAndSet(false, true)) { + return; + } + + final int generation = subscriptionGeneration.incrementAndGet(); + + try { + final SubscribeToShardRequest request = SubscribeToShardRequest.builder() + .consumerARN(theConsumerArn) + .shardId(shardId) + .startingPosition(startingPosition) + .build(); + + final SubscribeToShardResponseHandler handler = SubscribeToShardResponseHandler.builder() + .subscriber(() -> new DemandDrivenSubscriber(generation)) + .onError(t -> { + logSubscriptionError(t); + endSubscriptionIfCurrent(generation); + }) + .build(); + + lastSubscribeAttemptNanos = System.nanoTime(); + subscriptionFuture = asyncClient.subscribeToShard(request, handler); + } catch (final Exception e) { + subscribing.set(false); + throw e; + } + } + + void requestNext() { + final Subscription sub = subscription; + if (sub != null) { + sub.request(1); + } + } + + boolean isSubscriptionExpired() { + final CompletableFuture future = subscriptionFuture; 
+ return future == null || future.isDone(); + } + + long getLastSubscribeAttemptNanos() { + return lastSubscribeAttemptNanos; + } + + String getLastAcknowledgedSequenceNumber() { + return lastAcknowledgedSequenceNumber.get(); + } + + String getLastQueuedSequenceNumber() { + return lastQueuedSequenceNumber; + } + + long drainDeduplicatedEventCount() { + return deduplicatedEvents.getAndSet(0); + } + + void markAcknowledged() { + queuedResultCount.updateAndGet(value -> value > 0 ? value - 1 : 0); + } + + void acknowledgeSequenceNumber(final String sequenceNumber) { + if (sequenceNumber == null) { + return; + } + + final BigInteger incoming = new BigInteger(sequenceNumber); + String existing; + while (true) { + existing = lastAcknowledgedSequenceNumber.get(); + if (existing != null && incoming.compareTo(new BigInteger(existing)) <= 0) { + return; + } + if (lastAcknowledgedSequenceNumber.compareAndSet(existing, sequenceNumber)) { + return; + } + } + } + + void cancel() { + final CompletableFuture future = subscriptionFuture; + if (future != null) { + future.cancel(true); + } + subscription = null; + } + + private void logSubscriptionError(final Throwable t) { + if (isCancellation(t)) { + consumerLogger.debug("EFO subscription cancelled for shard {}", shardId); + } else if (isRetryableSubscriptionError(t)) { + consumerLogger.info("EFO subscription temporarily rejected for shard {}; will retry after backoff", shardId); + } else if (isRetryableStreamDisconnect(t)) { + consumerLogger.info("EFO subscription disconnected for shard {}; will retry", shardId); + } else { + consumerLogger.error("EFO subscription error for shard {}", shardId, t); + } + } + + private static boolean isCancellation(final Throwable t) { + if (t instanceof CancellationException) { + return true; + } + if (t instanceof IOException && "Request cancelled".equals(t.getMessage())) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isCancellation(cause); + } 
+ + private static boolean isRetryableSubscriptionError(final Throwable t) { + if (t instanceof LimitExceededException || t instanceof ResourceInUseException) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isRetryableSubscriptionError(cause); + } + + private static boolean isRetryableStreamDisconnect(final Throwable t) { + if (t instanceof IOException || t instanceof SdkClientException) { + return true; + } + if (t instanceof AwsServiceException ase && ase.statusCode() >= 500) { + return true; + } + final String className = t.getClass().getName(); + if (className.startsWith("io.netty.")) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isRetryableStreamDisconnect(cause); + } + + private void endSubscriptionIfCurrent(final int generation) { + if (subscriptionGeneration.get() == generation) { + subscription = null; + subscribing.set(false); + } + } + + private List deduplicateRecords(final List records) { + final String prevMax = lastOnNextMaxSequence; + if (prevMax == null) { + return records; + } + + final BigInteger threshold = new BigInteger(prevMax); + int firstNewIndex = records.size(); + for (int i = 0; i < records.size(); i++) { + if (new BigInteger(records.get(i).sequenceNumber()).compareTo(threshold) > 0) { + firstNewIndex = i; + break; + } + } + + if (firstNewIndex == 0) { + return records; + } + + final int kept = records.size() - firstNewIndex; + deduplicatedEvents.incrementAndGet(); + if (kept == 0) { + consumerLogger.debug("Skipped re-delivered EFO event for shard {} ({} records already seen)", shardId, records.size()); + } else { + consumerLogger.debug("Filtered {} duplicate record(s) from EFO event for shard {} (kept {})", firstNewIndex, shardId, kept); + } + + return records.subList(firstNewIndex, records.size()); + } + + private class DemandDrivenSubscriber implements Subscriber { + private final int generation; + + DemandDrivenSubscriber(final 
int generation) { + this.generation = generation; + } + + @Override + public void onSubscribe(final Subscription sub) { + subscription = sub; + sub.request(1); + } + + @Override + public void onNext(final SubscribeToShardEventStream eventStream) { + if (!(eventStream instanceof SubscribeToShardEvent event)) { + requestNext(); + return; + } + + if (event.records().isEmpty()) { + requestNext(); + return; + } + + final long millisBehind = event.millisBehindLatest() != null ? event.millisBehindLatest() : -1; + final List records = deduplicateRecords(event.records()); + if (records.isEmpty()) { + requestNext(); + return; + } + + final ShardFetchResult result = createFetchResult(shardId, records, millisBehind); + lastOnNextMaxSequence = result.lastSequenceNumber(); + lastQueuedSequenceNumber = result.lastSequenceNumber(); + queuedResultCount.incrementAndGet(); + resultSink.accept(result); + } + + @Override + public void onError(final Throwable t) { + logSubscriptionError(t); + endSubscriptionIfCurrent(generation); + } + + @Override + public void onComplete() { + consumerLogger.debug("EFO subscription completed normally for shard {}", shardId); + endSubscriptionIfCurrent(generation); + } + } + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java new file mode 100644 index 000000000000..b0b98dc32476 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base for Kinesis consumer clients. Provides per-shard result queues and + * shard-claim infrastructure used by both polling (shared-throughput) and EFO + * (enhanced-fan-out) implementations. + * + *

Results are stored in per-shard FIFO queues rather than a single shared queue. + * This guarantees that results for the same shard are always consumed in enqueue order, + * preventing out-of-order delivery when concurrent tasks cannot claim the same shard. + */ +abstract class KinesisConsumerClient { + + private static final long DIAGNOSTIC_INTERVAL_NANOS = TimeUnit.SECONDS.toNanos(30); + + protected final KinesisClient kinesisClient; + protected final ComponentLog logger; + private final Map> shardQueues = new ConcurrentHashMap<>(); + private final Semaphore resultNotification = new Semaphore(0); + protected final Set shardsInFlight = ConcurrentHashMap.newKeySet(); + + private volatile long lastDiagnosticLogNanos; + + KinesisConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger) { + this.kinesisClient = kinesisClient; + this.logger = logger; + } + + void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { + } + + abstract void startFetches(List shards, String streamName, int batchSize, + String initialStreamPosition, KinesisShardManager shardManager); + + abstract boolean hasPendingFetches(); + + abstract void acknowledgeResults(List results); + + abstract void rollbackResults(List results); + + abstract void removeUnownedShards(Set ownedShards); + + abstract void logDiagnostics(int ownedCount, int cachedShardCount); + + void close() { + shardQueues.clear(); + resultNotification.drainPermits(); + shardsInFlight.clear(); + } + + void enqueueResult(final ShardFetchResult result) { + shardQueues.computeIfAbsent(result.shardId(), k -> new ConcurrentLinkedQueue<>()).add(result); + resultNotification.release(); + } + + ShardFetchResult pollShardResult(final String shardId) { + final Queue queue = shardQueues.get(shardId); + return queue != null ? 
queue.poll() : null; + } + + ShardFetchResult pollAnyResult(final long timeout, final TimeUnit unit) throws InterruptedException { + final long deadlineNanos = System.nanoTime() + unit.toNanos(timeout); + while (System.nanoTime() < deadlineNanos) { + for (final Queue queue : shardQueues.values()) { + final ShardFetchResult result = queue.poll(); + if (result != null) { + return result; + } + } + final long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadlineNanos - System.nanoTime()); + if (remainingMs <= 0) { + break; + } + resultNotification.drainPermits(); + resultNotification.tryAcquire(Math.min(remainingMs, 100), TimeUnit.MILLISECONDS); + } + return null; + } + + Set getShardIdsWithResults() { + final Set ids = new HashSet<>(); + for (final Map.Entry> entry : shardQueues.entrySet()) { + if (!entry.getValue().isEmpty()) { + ids.add(entry.getKey()); + } + } + return ids; + } + + boolean awaitResults(final long timeout, final TimeUnit unit) throws InterruptedException { + resultNotification.drainPermits(); + return resultNotification.tryAcquire(timeout, unit); + } + + int totalQueuedResults() { + int total = 0; + for (final Queue queue : shardQueues.values()) { + total += queue.size(); + } + return total; + } + + boolean hasQueuedResults() { + for (final Queue queue : shardQueues.values()) { + if (!queue.isEmpty()) { + return true; + } + } + return false; + } + + boolean claimShard(final String shardId) { + return shardsInFlight.add(shardId); + } + + void releaseShards(final Collection shardIds) { + shardsInFlight.removeAll(shardIds); + } + + protected static ShardFetchResult createFetchResult(final String shardId, final List records, final long millisBehindLatest) { + return new ShardFetchResult(shardId, KplDeaggregator.deaggregate(shardId, records), millisBehindLatest); + } + + long drainDeduplicatedEventCount() { + return 0; + } + + protected boolean shouldLogDiagnostics() { + final long now = System.nanoTime(); + if (now < DIAGNOSTIC_INTERVAL_NANOS + 
lastDiagnosticLogNanos) { + return false; + } + lastDiagnosticLogNanos = now; + return true; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java similarity index 70% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java index 9fdc480727f1..9d9243d6dcc0 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.nifi.processors.aws.kinesis.converter; +package org.apache.nifi.processors.aws.kinesis; import org.apache.nifi.serialization.SimpleRecordSchema; import org.apache.nifi.serialization.record.MapRecord; @@ -22,16 +22,15 @@ import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; import java.util.HashMap; import java.util.List; import java.util.Map; -final class KinesisRecordMetadata { +public final class KinesisRecordMetadata { - static final String METADATA = "kinesisMetadata"; - static final String APPROX_ARRIVAL_TIMESTAMP = "approximateArrival"; + public static final String METADATA = "kinesisMetadata"; + public static final String APPROX_ARRIVAL_TIMESTAMP = "approximateArrival"; private static final String STREAM = "stream"; private static final String SHARD_ID = "shardId"; @@ -49,28 +48,28 @@ final class KinesisRecordMetadata { private static final RecordField FIELD_APPROX_ARRIVAL_TIMESTAMP = new RecordField(APPROX_ARRIVAL_TIMESTAMP, RecordFieldType.TIMESTAMP.getDataType()); private static final RecordSchema SCHEMA_METADATA = new SimpleRecordSchema(List.of( - FIELD_STREAM, - FIELD_SHARD_ID, - FIELD_SEQUENCE_NUMBER, - FIELD_SUB_SEQUENCE_NUMBER, - FIELD_SHARDED_SEQUENCE_NUMBER, - FIELD_PARTITION_KEY, - FIELD_APPROX_ARRIVAL_TIMESTAMP)); + FIELD_STREAM, + FIELD_SHARD_ID, + FIELD_SEQUENCE_NUMBER, + FIELD_SUB_SEQUENCE_NUMBER, + FIELD_SHARDED_SEQUENCE_NUMBER, + FIELD_PARTITION_KEY, + FIELD_APPROX_ARRIVAL_TIMESTAMP)); - static final RecordField FIELD_METADATA = new RecordField(METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)); + public static final RecordField FIELD_METADATA = new RecordField(METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)); - static Record composeMetadataObject(final KinesisClientRecord kinesisRecord, final String 
streamName, final String shardId) { + public static Record composeMetadataObject(final DeaggregatedRecord record, final String streamName, final String shardId) { final Map metadata = new HashMap<>(7, 1.0f); metadata.put(STREAM, streamName); metadata.put(SHARD_ID, shardId); - metadata.put(SEQUENCE_NUMBER, kinesisRecord.sequenceNumber()); - metadata.put(SUB_SEQUENCE_NUMBER, kinesisRecord.subSequenceNumber()); - metadata.put(SHARDED_SEQUENCE_NUMBER, "%s%020d".formatted(kinesisRecord.sequenceNumber(), kinesisRecord.subSequenceNumber())); - metadata.put(PARTITION_KEY, kinesisRecord.partitionKey()); + metadata.put(SEQUENCE_NUMBER, record.sequenceNumber()); + metadata.put(SUB_SEQUENCE_NUMBER, record.subSequenceNumber()); + metadata.put(SHARDED_SEQUENCE_NUMBER, "%s%020d".formatted(record.sequenceNumber(), record.subSequenceNumber())); + metadata.put(PARTITION_KEY, record.partitionKey()); - if (kinesisRecord.approximateArrivalTimestamp() != null) { - metadata.put(APPROX_ARRIVAL_TIMESTAMP, kinesisRecord.approximateArrivalTimestamp().toEpochMilli()); + if (record.approximateArrivalTimestamp() != null) { + metadata.put(APPROX_ARRIVAL_TIMESTAMP, record.approximateArrivalTimestamp().toEpochMilli()); } return new MapRecord(SCHEMA_METADATA, metadata); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java new file mode 100644 index 000000000000..322389317982 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; +import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.math.BigInteger; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Coordinates shard 
ownership and checkpoints across clustered processor instances using + * a single DynamoDB table. + * + *

The table stores two record types under the same stream hash key: + *

    + *
  • Shard rows: key {@code (streamName, shardId)} with lease/checkpoint fields
  • + *
  • Node heartbeat rows: key {@code (streamName, "__node__#")}
  • + *
+ * + *

Lease lifecycle used during refresh: + *

    + *
  1. Discover active shard leases and identify currently available shards
  2. + *
  3. Compute fair-share target from active node heartbeats
  4. + *
  5. If this node is over target, mark excess shards for graceful relinquish
  6. + *
  7. Continue renewing leases for owned shards and draining relinquishing shards
  8. + *
  9. After drain deadline, explicitly release relinquishing shards
  10. + *
  11. Acquire available shards until fair-share target is reached
  12. + *
+ * + *

Graceful relinquish is designed to reduce duplicate replay at rebalance boundaries by: + * (a) stopping new fetches immediately (shard removed from {@code ownedShards}), + * (b) briefly retaining lease ownership to allow in-flight work to finish, + * then (c) explicitly releasing the lease for fast handoff. + * + *

Shard split/merge: this implementation does not enforce parent-before-child + * ordering. When a shard is split or shards are merged, child shards become eligible for + * consumption immediately alongside any still-active parent shards. Callers that require strict + * ordering across split/merge boundaries would need to defer child shard assignment until the + * parent shard's {@code SHARD_END} has been reached and checkpointed. + */ +final class KinesisShardManager { + + private static final long DEFAULT_SHARD_CACHE_MILLIS = 60_000; + private static final long DEFAULT_LEASE_DURATION_MILLIS = 30_000; + private static final long DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS = 10_000; + private static final long DEFAULT_NODE_HEARTBEAT_EXPIRATION_MILLIS = + DEFAULT_LEASE_DURATION_MILLIS + DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS; + + private final KinesisClient kinesisClient; + private final DynamoDbClient dynamoDbClient; + private final ComponentLog logger; + private final String nodeId; + private final String checkpointTableName; + private final String streamName; + private final long shardCacheMillis; + private final long leaseDurationMillis; + private final long leaseRefreshIntervalMillis; + private final long nodeHeartbeatExpirationMillis; + private final long relinquishDrainMillis; + + private volatile ShardCache shardCache = new ShardCache(List.of(), Instant.EPOCH); + private final Set ownedShards = ConcurrentHashMap.newKeySet(); + private final Map pendingRelinquishDeadlines = new ConcurrentHashMap<>(); + private final AtomicBoolean leaseRefreshInProgress = new AtomicBoolean(false); + private final Map highestWrittenCheckpoints = new ConcurrentHashMap<>(); + private volatile Instant lastLeaseRefresh = Instant.EPOCH; + private volatile String activeCheckpointTableName; + + KinesisShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, final ComponentLog logger, + final String checkpointTableName, final String streamName) { + 
this(kinesisClient, dynamoDbClient, logger, checkpointTableName, streamName, + DEFAULT_SHARD_CACHE_MILLIS, DEFAULT_LEASE_DURATION_MILLIS, + DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS, DEFAULT_NODE_HEARTBEAT_EXPIRATION_MILLIS); + } + + KinesisShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, final ComponentLog logger, + final String checkpointTableName, final String streamName, final long shardCacheMillis, + final long leaseDurationMillis, final long leaseRefreshIntervalMillis, final long nodeHeartbeatExpirationMillis) { + this.kinesisClient = kinesisClient; + this.dynamoDbClient = dynamoDbClient; + this.logger = logger; + this.nodeId = UUID.randomUUID().toString(); + this.checkpointTableName = checkpointTableName; + this.streamName = streamName; + this.shardCacheMillis = shardCacheMillis; + this.leaseDurationMillis = leaseDurationMillis; + this.leaseRefreshIntervalMillis = leaseRefreshIntervalMillis; + this.nodeHeartbeatExpirationMillis = nodeHeartbeatExpirationMillis; + this.relinquishDrainMillis = Math.max(2_000L, leaseRefreshIntervalMillis); + this.activeCheckpointTableName = checkpointTableName; + } + + void ensureCheckpointTableExists() { + final CheckpointTableUtils.TableSchema currentSchema = + CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName); + logger.debug("Checkpoint table [{}] detected as {}", checkpointTableName, currentSchema); + + final LegacyCheckpointMigrator migrator = + new LegacyCheckpointMigrator(dynamoDbClient, checkpointTableName, streamName, nodeId, logger); + + switch (currentSchema) { + case NOT_FOUND -> { + final String orphanedMigration = migrator.findMigrationTable(); + if (orphanedMigration == null) { + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + } else { + logger.info("Found orphaned migration table [{}]; renaming to [{}]", + orphanedMigration, 
checkpointTableName); + migrator.renameMigrationTable(orphanedMigration); + } + } + case NEW -> { + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + migrator.cleanupLingeringMigration(); + } + case LEGACY -> { + migrator.migrateAndRename(); + } + default -> throw new ProcessException( + "Unsupported DynamoDB schema for checkpoint table [%s]".formatted(checkpointTableName)); + } + + activeCheckpointTableName = checkpointTableName; + logger.info("Using checkpoint table [{}] for stream [{}]", activeCheckpointTableName, streamName); + } + + List getShards() { + final ShardCache current = shardCache; + if (!current.shards().isEmpty() && Instant.now().toEpochMilli() < current.refreshTime().toEpochMilli() + shardCacheMillis) { + return current.shards(); + } + return refreshShards(); + } + + int getCachedShardCount() { + return shardCache.shards().size(); + } + + Set getOwnedShardIds() { + return Collections.unmodifiableSet(ownedShards); + } + + List getOwnedShards() { + final List result = new ArrayList<>(); + for (final Shard shard : getShards()) { + if (ownedShards.contains(shard.shardId())) { + result.add(shard); + } + } + return result; + } + + boolean shouldProcessFetchedResult(final String shardId) { + return ownedShards.contains(shardId) || pendingRelinquishDeadlines.containsKey(shardId); + } + + void refreshLeasesIfNecessary(final int clusterMemberCount) { + if (Instant.now().toEpochMilli() < leaseRefreshIntervalMillis + lastLeaseRefresh.toEpochMilli()) { + return; + } + + if (!leaseRefreshInProgress.compareAndSet(false, true)) { + return; + } + + try { + final List allShards = getShards(); + final Set currentShardIds = new HashSet<>(); + for (final Shard shard : allShards) { + currentShardIds.add(shard.shardId()); + } + + final long now = Instant.now().toEpochMilli(); + updateNodeHeartbeat(now); + final Map> ownerToShards = new HashMap<>(); + final List availableShardIds = new ArrayList<>(); + + final Map> 
leaseItemsByShardId = queryAllLeaseItems(); + for (final String shardId : currentShardIds) { + final Map item = leaseItemsByShardId.get(shardId); + if (item != null && item.containsKey("leaseOwner")) { + final String owner = item.get("leaseOwner").s(); + final AttributeValue expiryAttr = item.get("leaseExpiry"); + final long expiry = expiryAttr != null ? Long.parseLong(expiryAttr.n()) : 0; + if (expiry >= now) { + ownerToShards.computeIfAbsent(owner, k -> new ArrayList<>()).add(shardId); + } else { + availableShardIds.add(shardId); + } + } else { + availableShardIds.add(shardId); + } + } + + final Set stillOwned = new HashSet<>(ownerToShards.getOrDefault(nodeId, List.of())); + ownedShards.retainAll(stillOwned); + pendingRelinquishDeadlines.keySet().removeIf(shardId -> !currentShardIds.contains(shardId) || !stillOwned.contains(shardId)); + + final int heartbeatNodes = countActiveNodes(now); + final int totalOwners = Math.max(heartbeatNodes, Math.max(1, clusterMemberCount)); + final int targetCount = (allShards.size() + totalOwners - 1) / totalOwners; + + final List currentlyOwned = new ArrayList<>(ownerToShards.getOrDefault(nodeId, List.of())); + final int excessCount = Math.max(0, currentlyOwned.size() - targetCount); + if (excessCount > 0) { + for (int index = 0; index < excessCount && index < currentlyOwned.size(); index++) { + final String shardToRelinquish = currentlyOwned.get(index); + if (!pendingRelinquishDeadlines.containsKey(shardToRelinquish)) { + pendingRelinquishDeadlines.put(shardToRelinquish, now + relinquishDrainMillis); + ownedShards.remove(shardToRelinquish); + logger.info("Starting graceful relinquish for shard {}", shardToRelinquish); + } + } + } + + for (final String shardId : ownedShards) { + tryAcquireLease(shardId); + } + + for (final Map.Entry pendingRelinquish : new ArrayList<>(pendingRelinquishDeadlines.entrySet())) { + final String shardId = pendingRelinquish.getKey(); + final long relinquishDeadline = pendingRelinquish.getValue(); + if 
(!stillOwned.contains(shardId)) { + pendingRelinquishDeadlines.remove(shardId); + continue; + } + + if (now < relinquishDeadline) { + tryAcquireLease(shardId); + } else { + try { + releaseLease(shardId); + pendingRelinquishDeadlines.remove(shardId); + logger.info("Completed graceful relinquish for shard {}", shardId); + } catch (final Exception e) { + logger.warn("Failed to complete graceful relinquish for shard {}", shardId, e); + } + } + } + + for (final String shardId : availableShardIds) { + if (ownedShards.size() >= targetCount) { + break; + } + if (pendingRelinquishDeadlines.containsKey(shardId)) { + continue; + } + if (tryAcquireLease(shardId)) { + ownedShards.add(shardId); + } + } + + ownedShards.removeIf(id -> !currentShardIds.contains(id)); + lastLeaseRefresh = Instant.now(); + } finally { + leaseRefreshInProgress.set(false); + } + } + + String readCheckpoint(final String shardId) { + final GetItemRequest getItemRequest = GetItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .consistentRead(true) + .build(); + final GetItemResponse response = dynamoDbClient.getItem(getItemRequest); + + if (response.hasItem() && response.item().containsKey("sequenceNumber")) { + final String value = response.item().get("sequenceNumber").s(); + if (isValidSequenceNumber(value)) { + logger.debug("Read checkpoint for shard {}: {}", shardId, value); + return value; + } + logger.warn("Ignoring non-numeric checkpoint [{}] for shard {} in table [{}]", value, shardId, activeCheckpointTableName); + } else { + logger.debug("No checkpoint found for shard {} in table [{}]", shardId, activeCheckpointTableName); + } + return null; + } + + void writeCheckpoints(final Map checkpoints) { + for (final Map.Entry entry : checkpoints.entrySet()) { + writeCheckpoint(entry.getKey(), entry.getValue()); + } + } + + void releaseAllLeases() { + final Set shardsToRelease = new HashSet<>(ownedShards); + 
shardsToRelease.addAll(pendingRelinquishDeadlines.keySet()); + for (final String shardId : shardsToRelease) { + try { + releaseLease(shardId); + } catch (final Exception e) { + logger.warn("Failed to release lease for shard {}", shardId, e); + } + } + + try { + final UpdateItemRequest heartbeatReleaseRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(nodeHeartbeatKey()) + .updateExpression("REMOVE nodeHeartbeat, lastUpdateTimestamp") + .build(); + dynamoDbClient.updateItem(heartbeatReleaseRequest); + } catch (final Exception e) { + logger.debug("Failed to clear node heartbeat record for node {}", nodeId, e); + } + } + + void close() { + ownedShards.clear(); + pendingRelinquishDeadlines.clear(); + highestWrittenCheckpoints.clear(); + leaseRefreshInProgress.set(false); + lastLeaseRefresh = Instant.EPOCH; + shardCache = new ShardCache(List.of(), Instant.EPOCH); + } + + private synchronized List refreshShards() { + final ShardCache current = shardCache; + if (!current.shards().isEmpty() && Instant.now().toEpochMilli() < current.refreshTime().toEpochMilli() + shardCacheMillis) { + return current.shards(); + } + + final List allShards = new ArrayList<>(); + String nextToken = null; + do { + final ListShardsRequest request = nextToken != null + ? 
ListShardsRequest.builder().nextToken(nextToken).build() + : ListShardsRequest.builder().streamName(streamName).build(); + + final ListShardsResponse response = kinesisClient.listShards(request); + allShards.addAll(response.shards()); + nextToken = response.nextToken(); + } while (nextToken != null); + + logger.debug("ListShards returned {} shards for stream {}", allShards.size(), streamName); + shardCache = new ShardCache(allShards, Instant.now()); + return allShards; + } + + private static boolean isValidSequenceNumber(final String value) { + if (value == null || value.isEmpty()) { + return false; + } + for (int idx = 0; idx < value.length(); idx++) { + if (!Character.isDigit(value.charAt(idx))) { + return false; + } + } + return true; + } + + private boolean tryAcquireLease(final String shardId) { + final long now = Instant.now().toEpochMilli(); + final long expiry = now + leaseDurationMillis; + final AttributeValue ownerVal = AttributeValue.builder().s(nodeId).build(); + final AttributeValue expiryVal = AttributeValue.builder().n(String.valueOf(expiry)).build(); + final AttributeValue nowVal = AttributeValue.builder().n(String.valueOf(now)).build(); + + try { + final UpdateItemRequest leaseRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("SET leaseOwner = :owner, leaseExpiry = :exp, lastUpdateTimestamp = :ts") + .conditionExpression("attribute_not_exists(leaseOwner) OR leaseOwner = :owner OR leaseExpiry < :now") + .expressionAttributeValues(Map.of( + ":owner", ownerVal, + ":exp", expiryVal, + ":now", nowVal, + ":ts", nowVal)) + .build(); + dynamoDbClient.updateItem(leaseRequest); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } catch (final Exception e) { + logger.warn("Failed to acquire lease for shard {}", shardId, e); + return false; + } + } + + private void updateNodeHeartbeat(final long now) { + final UpdateItemRequest heartbeatRequest = 
UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(nodeHeartbeatKey()) + .updateExpression("SET nodeHeartbeat = :heartbeat, lastUpdateTimestamp = :ts") + .expressionAttributeValues(Map.of( + ":heartbeat", AttributeValue.builder().n(String.valueOf(now)).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(heartbeatRequest); + } + + private Map> queryAllLeaseItems() { + final Map> itemsByShardId = new HashMap<>(); + final QueryRequest.Builder queryBuilder = QueryRequest.builder() + .tableName(activeCheckpointTableName) + .consistentRead(true) + .keyConditionExpression("streamName = :streamName") + .expressionAttributeValues(Map.of( + ":streamName", AttributeValue.builder().s(streamName).build())); + + Map exclusiveStartKey = null; + do { + final QueryRequest queryRequest = exclusiveStartKey == null + ? queryBuilder.build() + : queryBuilder.exclusiveStartKey(exclusiveStartKey).build(); + final QueryResponse queryResponse = dynamoDbClient.query(queryRequest); + for (final Map item : queryResponse.items()) { + final AttributeValue shardIdAttr = item.get("shardId"); + if (shardIdAttr == null) { + continue; + } + final String shardId = shardIdAttr.s(); + if (shardId.startsWith(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX) + || CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID.equals(shardId)) { + continue; + } + itemsByShardId.put(shardId, item); + } + + exclusiveStartKey = queryResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + return itemsByShardId; + } + + private int countActiveNodes(final long now) { + final QueryRequest.Builder queryRequestBuilder = QueryRequest.builder() + .tableName(activeCheckpointTableName) + .consistentRead(true) + .keyConditionExpression("streamName = :streamName AND begins_with(shardId, :nodePrefix)") + .expressionAttributeValues(Map.of( + ":streamName", AttributeValue.builder().s(streamName).build(), + 
":nodePrefix", AttributeValue.builder().s(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX).build())); + + Map exclusiveStartKey = null; + int activeNodes = 0; + do { + final QueryRequest queryRequest = exclusiveStartKey == null + ? queryRequestBuilder.build() + : queryRequestBuilder.exclusiveStartKey(exclusiveStartKey).build(); + final QueryResponse queryResponse = dynamoDbClient.query(queryRequest); + for (final Map item : queryResponse.items()) { + final AttributeValue heartbeatValue = item.get("nodeHeartbeat"); + if (heartbeatValue == null) { + continue; + } + + final long heartbeatMillis = Long.parseLong(heartbeatValue.n()); + if (now <= heartbeatMillis + nodeHeartbeatExpirationMillis) { + activeNodes++; + } + } + + exclusiveStartKey = queryResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + return Math.max(1, activeNodes); + } + + private void writeCheckpoint(final String shardId, final ShardCheckpoint checkpoint) { + final BigInteger incomingSeq = new BigInteger(checkpoint.sequenceNumber()); + + final ShardCheckpoint written = highestWrittenCheckpoints.compute(shardId, (key, existing) -> { + if (existing != null && checkpoint.max(existing) == existing) { + return existing; + } + + try { + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest checkpointRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("SET sequenceNumber = :seq, subSequenceNumber = :subSeq," + + " lastUpdateTimestamp = :ts, leaseExpiry = :exp") + .conditionExpression("leaseOwner = :owner") + .expressionAttributeValues(Map.of( + ":seq", AttributeValue.builder().s(checkpoint.sequenceNumber()).build(), + ":subSeq", AttributeValue.builder().n(String.valueOf(checkpoint.subSequenceNumber())).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build(), + ":exp", AttributeValue.builder().n(String.valueOf(now + leaseDurationMillis)).build(), + 
":owner", AttributeValue.builder().s(nodeId).build())) + .build(); + dynamoDbClient.updateItem(checkpointRequest); + logger.debug("Checkpointed shard {} at sequence {}/{}", shardId, checkpoint.sequenceNumber(), checkpoint.subSequenceNumber()); + return checkpoint; + } catch (final ConditionalCheckFailedException e) { + logger.warn("Lost lease on shard {} during checkpoint; another node may have taken it", shardId); + } catch (final Exception e) { + logger.error("Failed to write checkpoint for shard {}", shardId, e); + } + return existing; + }); + + if (written != null && incomingSeq.compareTo(new BigInteger(written.sequenceNumber())) < 0) { + logger.debug("Skipped checkpoint regression for shard {} (highest: {}, attempted: {})", shardId, written.sequenceNumber(), checkpoint.sequenceNumber()); + } + } + + private Map checkpointKey(final String shardId) { + return Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(shardId).build()); + } + + private void releaseLease(final String shardId) { + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("REMOVE leaseOwner, leaseExpiry") + .conditionExpression("leaseOwner = :owner") + .expressionAttributeValues(Map.of(":owner", AttributeValue.builder().s(nodeId).build())) + .build(); + dynamoDbClient.updateItem(request); + } + + private Map nodeHeartbeatKey() { + return checkpointKey(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX + nodeId); + } + + private record ShardCache(List shards, Instant refreshTime) { } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java new file mode 100644 index 000000000000..0e9918f129c7 --- /dev/null +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.WireFormat; +import software.amazon.awssdk.services.kinesis.model.Record; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Deaggregates KPL (Kinesis Producer Library) aggregated records into individual user records. + * + *

KPL aggregation packs multiple user records into a single Kinesis record using a protobuf + * envelope with a 4-byte magic header and a 16-byte MD5 trailer. Non-aggregated records + * pass through unchanged as a single {@link DeaggregatedRecord} with {@code subSequenceNumber=0}. + * + *

If a record has the magic header but fails MD5 verification or protobuf parsing, it falls + * back to passthrough to avoid data loss. + * + *

We could make direct use of the KPL's protobuf definition and generated classes, + * but doing so requires bringing in 20+ transitive dependencies. Since the protobuf format is + * simple and well-documented, we implement a minimal custom parser using the protobuf wire format + * instead. Additionally, we have integration tests to verify compatibility. + * + * @see KPL Aggregation Format + */ +final class KplDeaggregator { + + static final byte[] KPL_MAGIC = {(byte) 0xF3, (byte) 0x89, (byte) 0x9A, (byte) 0xC2}; + private static final int MD5_DIGEST_LENGTH = 16; + private static final int MIN_AGGREGATED_LENGTH = KPL_MAGIC.length + MD5_DIGEST_LENGTH + 1; + + private static final int FIELD_PARTITION_KEY_TABLE = 1; + private static final int FIELD_EXPLICIT_HASH_KEY_TABLE = 2; + private static final int FIELD_RECORDS = 3; + + private static final int RECORD_FIELD_PARTITION_KEY_INDEX = 1; + private static final int RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX = 2; + private static final int RECORD_FIELD_DATA = 3; + + private KplDeaggregator() { + } + + /** + * Deaggregates a list of Kinesis records, expanding any KPL-aggregated records into + * their constituent sub-records. 
+ * + * @param shardId the shard these records were fetched from + * @param records raw Kinesis records from the API + * @return list of deaggregated records preserving original order + */ + static List deaggregate(final String shardId, final List records) { + final List result = new ArrayList<>(); + for (final Record record : records) { + deaggregateRecord(shardId, record, result); + } + return result; + } + + private static void deaggregateRecord(final String shardId, final Record record, final List out) { + final byte[] data = record.data().asByteArrayUnsafe(); + + if (!isAggregated(data)) { + out.add(passthrough(shardId, record, data)); + return; + } + + final byte[] protobufBytes = Arrays.copyOfRange(data, KPL_MAGIC.length, data.length - MD5_DIGEST_LENGTH); + final byte[] trailingMd5 = Arrays.copyOfRange(data, data.length - MD5_DIGEST_LENGTH, data.length); + + if (!verifyMd5(protobufBytes, trailingMd5)) { + out.add(passthrough(shardId, record, data)); + return; + } + + try { + parseAggregatedRecord(shardId, record, protobufBytes, out); + } catch (final Exception e) { + out.add(passthrough(shardId, record, data)); + } + } + + static boolean isAggregated(final byte[] data) { + if (data.length < MIN_AGGREGATED_LENGTH) { + return false; + } + return data[0] == KPL_MAGIC[0] + && data[1] == KPL_MAGIC[1] + && data[2] == KPL_MAGIC[2] + && data[3] == KPL_MAGIC[3]; + } + + private static boolean verifyMd5(final byte[] protobufBytes, final byte[] expectedMd5) { + try { + final byte[] computed = MessageDigest.getInstance("MD5").digest(protobufBytes); + return Arrays.equals(computed, expectedMd5); + } catch (final NoSuchAlgorithmException e) { + return false; + } + } + + private static void parseAggregatedRecord(final String shardId, final Record kinesisRecord, final byte[] protobufBytes, + final List out) throws Exception { + final List partitionKeyTable = new ArrayList<>(); + final List subRecordDataList = new ArrayList<>(); + final List subRecordPkIndexList = new 
ArrayList<>(); + + final CodedInputStream input = CodedInputStream.newInstance(protobufBytes); + while (!input.isAtEnd()) { + final int tag = input.readTag(); + final int fieldNumber = WireFormat.getTagFieldNumber(tag); + switch (fieldNumber) { + case FIELD_PARTITION_KEY_TABLE: + partitionKeyTable.add(input.readString()); + break; + case FIELD_EXPLICIT_HASH_KEY_TABLE: + input.readString(); + break; + case FIELD_RECORDS: + final int length = input.readRawVarint32(); + final int oldLimit = input.pushLimit(length); + int pkIndex = 0; + byte[] subData = new byte[0]; + while (!input.isAtEnd()) { + final int innerTag = input.readTag(); + final int innerField = WireFormat.getTagFieldNumber(innerTag); + switch (innerField) { + case RECORD_FIELD_PARTITION_KEY_INDEX: + pkIndex = (int) input.readUInt64(); + break; + case RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX: + input.readUInt64(); + break; + case RECORD_FIELD_DATA: + subData = input.readByteArray(); + break; + default: + input.skipField(innerTag); + break; + } + } + input.popLimit(oldLimit); + subRecordDataList.add(subData); + subRecordPkIndexList.add(pkIndex); + break; + default: + input.skipField(tag); + break; + } + } + + final String sequenceNumber = kinesisRecord.sequenceNumber(); + final Instant arrival = kinesisRecord.approximateArrivalTimestamp(); + final String fallbackPartitionKey = kinesisRecord.partitionKey(); + + for (int i = 0; i < subRecordDataList.size(); i++) { + final int pkIdx = subRecordPkIndexList.get(i); + final String partitionKey = pkIdx < partitionKeyTable.size() + ? 
partitionKeyTable.get(pkIdx) + : fallbackPartitionKey; + out.add(new DeaggregatedRecord(shardId, sequenceNumber, i, partitionKey, subRecordDataList.get(i), arrival)); + } + } + + private static DeaggregatedRecord passthrough(final String shardId, final Record record, final byte[] data) { + return new DeaggregatedRecord( + shardId, + record.sequenceNumber(), + 0, + record.partitionKey(), + data, + record.approximateArrivalTimestamp()); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java new file mode 100644 index 000000000000..69ed6065584a --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; + +import java.time.Instant; +import java.util.Map; + +/** + * Handles one-time migration of checkpoint data from legacy KCL-format DynamoDB tables + * to the new composite-key schema used by {@link KinesisShardManager}, including the + * table rename lifecycle. Uses distributed locks in the target table to coordinate + * migration and rename operations across clustered nodes. 
+ */ +final class LegacyCheckpointMigrator { + + private static final String MIGRATION_TABLE_SUFFIX = "_migration"; + private static final String LEGACY_LEASE_KEY_ATTRIBUTE = "leaseKey"; + private static final String LEGACY_CHECKPOINT_ATTRIBUTE = "checkpoint"; + private static final String MIGRATION_STATUS_ATTRIBUTE = "migrationStatus"; + private static final String MIGRATION_STATUS_IN_PROGRESS = "IN_PROGRESS"; + private static final String MIGRATION_STATUS_COMPLETE = "COMPLETE"; + private static final long MIGRATION_LOCK_STALE_MILLIS = 600_000; + private static final long MIGRATION_WAIT_MILLIS = 2_000; + private static final int MIGRATION_WAIT_MAX_ATTEMPTS = 180; + private static final long RENAME_LOCK_STALE_MILLIS = 120_000; + private static final long RENAME_POLL_MILLIS = 1_000; + private static final int RENAME_POLL_MAX_ATTEMPTS = 60; + + private final DynamoDbClient dynamoDbClient; + private final String checkpointTableName; + private final String streamName; + private final String nodeId; + private final ComponentLog logger; + + LegacyCheckpointMigrator(final DynamoDbClient dynamoDbClient, final String checkpointTableName, + final String streamName, final String nodeId, final ComponentLog logger) { + this.dynamoDbClient = dynamoDbClient; + this.checkpointTableName = checkpointTableName; + this.streamName = streamName; + this.nodeId = nodeId; + this.logger = logger; + } + + String findMigrationTable() { + final String migrationTableName = checkpointTableName + MIGRATION_TABLE_SUFFIX; + if (CheckpointTableUtils.getTableSchema(dynamoDbClient, migrationTableName) + == CheckpointTableUtils.TableSchema.NEW) { + return migrationTableName; + } + return null; + } + + void cleanupLingeringMigration() { + final String lingeringMigration = findMigrationTable(); + if (lingeringMigration == null) { + return; + } + logger.info("Cleaning up lingering migration table [{}]", lingeringMigration); + CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, 
lingeringMigration, checkpointTableName); + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, lingeringMigration); + } + + void migrateAndRename() { + final String existingMigration = findMigrationTable(); + final String migrationTableName; + + if (existingMigration == null) { + migrationTableName = checkpointTableName + MIGRATION_TABLE_SUFFIX; + logger.info("Legacy checkpoint table detected; migrating via [{}]", migrationTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, migrationTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, migrationTableName); + migrateCheckpoints(checkpointTableName, migrationTableName); + } else { + migrationTableName = existingMigration; + logger.info("Found existing migration table [{}]; completing rename to [{}]", + migrationTableName, checkpointTableName); + } + + renameMigrationTable(migrationTableName); + } + + void renameMigrationTable(final String migrationTableName) { + if (acquireRenameLock(migrationTableName)) { + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableDeleted(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, migrationTableName, checkpointTableName); + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, migrationTableName); + } else { + waitForTableRenamed(migrationTableName); + } + } + + private boolean acquireRenameLock(final String migrationTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final Map key = migrationMarkerKey(); + final Map values = Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build()); + final UpdateItemRequest request = 
UpdateItemRequest.builder() + .tableName(migrationTableName) + .key(key) + .updateExpression("SET renameOwner = :owner, renameStartedAt = :now") + .conditionExpression("attribute_not_exists(renameOwner)") + .expressionAttributeValues(values) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Acquired rename lock for migration table [{}]", migrationTableName); + return true; + } catch (final ConditionalCheckFailedException e) { + logger.debug("Rename lock already held for migration table [{}]", migrationTableName); + return false; + } catch (final ResourceNotFoundException e) { + logger.debug("Migration table [{}] already deleted; rename must be complete", migrationTableName); + return false; + } + } + + private boolean forceAcquireStaleRenameLock(final String migrationTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final long staleThreshold = now - RENAME_LOCK_STALE_MILLIS; + final Map key = migrationMarkerKey(); + final Map values = Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build(), + ":staleThreshold", AttributeValue.builder().n(String.valueOf(staleThreshold)).build()); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(migrationTableName) + .key(key) + .updateExpression("SET renameOwner = :owner, renameStartedAt = :now") + .conditionExpression("attribute_exists(renameOwner) AND renameStartedAt < :staleThreshold") + .expressionAttributeValues(values) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Force-acquired stale rename lock for migration table [{}]", migrationTableName); + return true; + } catch (final ConditionalCheckFailedException | ResourceNotFoundException e) { + return false; + } + } + + private Map migrationMarkerKey() { + return Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", 
AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build()); + } + + private void waitForTableRenamed(final String migrationTableName) { + for (int i = 0; i < RENAME_POLL_MAX_ATTEMPTS; i++) { + if (CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName) + == CheckpointTableUtils.TableSchema.NEW) { + logger.info("Migration table rename complete; table [{}] is now available", checkpointTableName); + return; + } + + try { + Thread.sleep(RENAME_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for migration table rename to complete", e); + } + } + + logger.warn("Timed out waiting for migration table rename; attempting stale lock takeover"); + if (forceAcquireStaleRenameLock(migrationTableName)) { + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableDeleted(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, migrationTableName, checkpointTableName); + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, migrationTableName); + } else if (CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName) + == CheckpointTableUtils.TableSchema.NEW) { + logger.info("Migration table rename completed during takeover attempt"); + } else { + throw new ProcessException( + "Unable to complete migration table rename for [%s]".formatted(checkpointTableName)); + } + } + + private void migrateCheckpoints(final String sourceTableName, final String targetTableName) { + if (isMigrationComplete(targetTableName)) { + logger.debug("Legacy checkpoint migration already complete for stream [{}]", streamName); + return; + } + + if 
(acquireMigrationLock(targetTableName)) { + performMigration(sourceTableName, targetTableName); + } else { + waitForMigrationComplete(targetTableName); + if (!isMigrationComplete(targetTableName)) { + logger.warn("Migration wait timed out with stale lock; attempting takeover for stream [{}]", streamName); + if (forceAcquireStaleMigrationLock(targetTableName)) { + performMigration(sourceTableName, targetTableName); + } else { + throw new ProcessException( + "Unable to acquire migration lock for stream [%s]".formatted(streamName)); + } + } + } + } + + private void performMigration(final String sourceTableName, final String targetTableName) { + try { + migrateLegacyCheckpoints(sourceTableName, targetTableName); + markMigrationComplete(targetTableName); + } catch (final Exception e) { + clearMigrationLock(targetTableName); + throw new ProcessException("Failed to migrate legacy checkpoints to [%s]".formatted(targetTableName), e); + } + } + + private boolean acquireMigrationLock(final String targetTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final PutItemRequest request = PutItemRequest.builder() + .tableName(targetTableName) + .item(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build(), + MIGRATION_STATUS_ATTRIBUTE, AttributeValue.builder().s(MIGRATION_STATUS_IN_PROGRESS).build(), + "migrationOwner", AttributeValue.builder().s(nodeId).build(), + "migrationStartedAt", AttributeValue.builder().n(String.valueOf(now)).build(), + "lastUpdateTimestamp", AttributeValue.builder().n(String.valueOf(now)).build())) + .conditionExpression("attribute_not_exists(streamName)") + .build(); + dynamoDbClient.putItem(request); + logger.info("Acquired checkpoint migration lock for stream [{}]", streamName); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } + } + + private boolean 
forceAcquireStaleMigrationLock(final String targetTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final long staleThreshold = now - MIGRATION_LOCK_STALE_MILLIS; + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("SET migrationOwner = :owner, migrationStartedAt = :now, lastUpdateTimestamp = :now") + .conditionExpression("migrationStatus = :inProgress AND migrationStartedAt < :staleThreshold") + .expressionAttributeValues(Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build(), + ":inProgress", AttributeValue.builder().s(MIGRATION_STATUS_IN_PROGRESS).build(), + ":staleThreshold", AttributeValue.builder().n(String.valueOf(staleThreshold)).build())) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Force-acquired stale migration lock for stream [{}]", streamName); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } + } + + private void migrateLegacyCheckpoints(final String sourceTableName, final String targetTableName) { + logger.info("Starting legacy checkpoint migration from [{}] to [{}] for stream [{}]", + sourceTableName, targetTableName, streamName); + + Map exclusiveStartKey = null; + int scanned = 0; + int migrated = 0; + int skippedMissingAttr = 0; + int skippedNonNumeric = 0; + do { + final ScanRequest scanRequest = ScanRequest.builder() + .tableName(sourceTableName) + .exclusiveStartKey(exclusiveStartKey) + .build(); + final ScanResponse scanResponse = dynamoDbClient.scan(scanRequest); + + for (final Map item : scanResponse.items()) { + scanned++; + final AttributeValue leaseKeyAttr = item.get(LEGACY_LEASE_KEY_ATTRIBUTE); + final AttributeValue checkpointAttr = 
item.get(LEGACY_CHECKPOINT_ATTRIBUTE); + if (leaseKeyAttr == null || checkpointAttr == null) { + skippedMissingAttr++; + logger.warn("Skipping legacy item missing leaseKey or checkpoint: keys={}", item.keySet()); + continue; + } + + final String shardId = extractShardId(leaseKeyAttr.s()); + if (shardId == null || shardId.isEmpty()) { + skippedMissingAttr++; + continue; + } + + final String checkpoint = checkpointAttr.s(); + if (!isValidSequenceNumber(checkpoint)) { + skippedNonNumeric++; + logger.warn("Skipping non-numeric legacy checkpoint [{}] for shard {}", checkpoint, shardId); + continue; + } + + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(shardId).build())) + .updateExpression("SET sequenceNumber = :seq, lastUpdateTimestamp = :ts") + .expressionAttributeValues(Map.of( + ":seq", AttributeValue.builder().s(checkpoint).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(request); + migrated++; + } + + exclusiveStartKey = scanResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + logger.info("Legacy checkpoint migration complete for stream [{}]: scanned={}, migrated={}, skippedNonNumeric={}, skippedMissingAttr={}", + streamName, scanned, migrated, skippedNonNumeric, skippedMissingAttr); + } + + private static String extractShardId(final String legacyLeaseKey) { + if (legacyLeaseKey == null || legacyLeaseKey.isEmpty()) { + return null; + } + final int separatorIndex = legacyLeaseKey.lastIndexOf(':'); + if (separatorIndex >= 0 && separatorIndex + 1 < legacyLeaseKey.length()) { + return legacyLeaseKey.substring(separatorIndex + 1); + } + return legacyLeaseKey; + } + + private static boolean isValidSequenceNumber(final String value) { + if 
(value == null || value.isEmpty()) { + return false; + } + for (int idx = 0; idx < value.length(); idx++) { + if (!Character.isDigit(value.charAt(idx))) { + return false; + } + } + return true; + } + + private void markMigrationComplete(final String targetTableName) { + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("SET migrationStatus = :status, migrationCompletedAt = :doneAt, lastUpdateTimestamp = :ts REMOVE migrationOwner") + .expressionAttributeValues(Map.of( + ":status", AttributeValue.builder().s(MIGRATION_STATUS_COMPLETE).build(), + ":doneAt", AttributeValue.builder().n(String.valueOf(now)).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(request); + } + + private void clearMigrationLock(final String targetTableName) { + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("REMOVE migrationStatus, migrationOwner, migrationStartedAt") + .build(); + dynamoDbClient.updateItem(request); + } + + private boolean isMigrationComplete(final String targetTableName) { + final GetItemResponse response = dynamoDbClient.getItem(GetItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .build()); + if (!response.hasItem()) { + return false; + } + final AttributeValue status = 
response.item().get(MIGRATION_STATUS_ATTRIBUTE); + return status != null && MIGRATION_STATUS_COMPLETE.equals(status.s()); + } + + private void waitForMigrationComplete(final String targetTableName) { + for (int attempt = 0; attempt < MIGRATION_WAIT_MAX_ATTEMPTS; attempt++) { + if (isMigrationComplete(targetTableName)) { + logger.info("Observed completed checkpoint migration for stream [{}]", streamName); + return; + } + + try { + Thread.sleep(MIGRATION_WAIT_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for checkpoint migration to complete", e); + } + } + + logger.warn("Timed out waiting for checkpoint migration to complete for stream [{}]; will check for stale lock", + streamName); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java deleted file mode 100644 index c4d1522a350d..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java +++ /dev/null @@ -1,725 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import jakarta.annotation.Nullable; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferLease; -import software.amazon.kinesis.exceptions.InvalidStateException; -import software.amazon.kinesis.exceptions.KinesisClientLibDependencyException; -import software.amazon.kinesis.exceptions.ShutdownException; -import software.amazon.kinesis.exceptions.ThrottlingException; -import software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Queue; -import java.util.Random; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - -import static java.util.Collections.emptyList; - -/** - * A record buffer which limits the maximum memory usage across all shard buffers. 
- * If the memory limit is reached, adding new records will block until enough memory is freed. - */ -final class MemoryBoundRecordBuffer implements RecordBuffer.ForKinesisClientLibrary, RecordBuffer.ForProcessor { - - private final ComponentLog logger; - - private final long checkpointIntervalMillis; - private final BlockingMemoryTracker memoryTracker; - - private final AtomicLong bufferIdCounter = new AtomicLong(0); - - /** - * All shard buffers stored by their ids. - *

- * When a buffer is invalidated, it is removed from this map, but its id may still be present in the buffersToLease queue. - * Since the buffer can be invalidated concurrently, it's possible for some buffer operations to be called - * after the buffer was removed from this map. In that case the operations should take no effect. - */ - private final ConcurrentMap shardBuffers = new ConcurrentHashMap<>(); - - /** - * A queue with ids shard buffers available for leasing. - *

- * Note: when a buffer is invalidated its id is NOT removed from the queue immediately. - */ - private final Queue buffersToLease = new ConcurrentLinkedQueue<>(); - - MemoryBoundRecordBuffer(final ComponentLog logger, final long maxMemoryBytes, final Duration checkpointInterval) { - this.logger = logger; - this.memoryTracker = new BlockingMemoryTracker(logger, maxMemoryBytes); - this.checkpointIntervalMillis = checkpointInterval.toMillis(); - } - - @Override - public ShardBufferId createBuffer(final String shardId) { - final ShardBufferId id = new ShardBufferId(shardId, bufferIdCounter.getAndIncrement()); - - logger.debug("Creating new buffer for shard {} with id {}", shardId, id); - - shardBuffers.put(id, new ShardBuffer(id, logger, checkpointIntervalMillis)); - buffersToLease.add(id); - return id; - } - - @Override - public void addRecords(final ShardBufferId bufferId, final List records, final RecordProcessorCheckpointer checkpointer) { - if (records.isEmpty()) { - return; - } - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot add records with sequence and subsequence numbers: {}.{} - {}.{}", - bufferId, - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber()); - return; - } - - final RecordBatch recordBatch = new RecordBatch(records, checkpointer, calculateMemoryUsage(records)); - memoryTracker.reserveMemory(recordBatch, bufferId); - final boolean addedRecords = buffer.offer(recordBatch); - - if (addedRecords) { - logger.debug("Successfully added records with sequence and subsequence numbers: {}.{} - {}.{} to buffer with id {}", - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber(), - bufferId); - } else { - logger.debug("Buffer with id {} was invalidated. 
Cannot add records with sequence and subsequence numbers: {}.{} - {}.{}", - bufferId, - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber()); - // If the buffer was invalidated, we should free memory reserved for these records. - memoryTracker.freeMemory(List.of(recordBatch), bufferId); - } - } - - @Override - public void checkpointEndedShard(final ShardBufferId bufferId, final RecordProcessorCheckpointer checkpointer) { - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot checkpoint the ended shard", bufferId); - return; - } - - logger.debug("Finishing consumption for buffer {}. Checkpointing the ended shard", bufferId); - buffer.checkpointEndedShard(checkpointer); - - logger.debug("Removing buffer with id {} after successful ended shard checkpoint", bufferId); - shardBuffers.remove(bufferId); - } - - @Override - public void shutdownShardConsumption(final ShardBufferId bufferId, final RecordProcessorCheckpointer checkpointer) { - final ShardBuffer buffer = shardBuffers.remove(bufferId); - - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot shutdown shard consumption", bufferId); - } else { - logger.debug("Shutting down the buffer {}. Checkpointing last consumed record", bufferId); - final Collection invalidatedBatches = buffer.shutdownBuffer(checkpointer); - memoryTracker.freeMemory(invalidatedBatches, bufferId); - } - } - - @Override - public void consumerLeaseLost(final ShardBufferId bufferId) { - final ShardBuffer buffer = shardBuffers.remove(bufferId); - - if (buffer == null) { - logger.debug("Buffer with id {} not found. 
Ignoring lease lost event", bufferId); - } else { - logger.debug("Lease lost for buffer {}: Invalidating", bufferId); - final Collection invalidatedBatches = buffer.invalidate(); - memoryTracker.freeMemory(invalidatedBatches, bufferId); - } - } - - @Override - public Optional acquireBufferLease() { - final Set seenBuffers = new HashSet<>(); - - while (true) { - final ShardBufferId bufferId = buffersToLease.poll(); - if (bufferId == null) { - // The queue is empty or all buffers were seen already. Nothing to consume. - return Optional.empty(); - } - - if (seenBuffers.contains(bufferId)) { - // If the same buffer is seen again, there is a high chance we iterated through most of the buffers and didn't find any that isn't empty. - // To avoid burning CPU we return empty here, even if some buffer received records in the meantime. It will be picked up in the next iteration. - buffersToLease.add(bufferId); - return Optional.empty(); - } - - final ShardBuffer buffer = shardBuffers.get(bufferId); - - if (buffer == null) { - // By the time the bufferId is polled, it might have been invalidated. No need to return it to the queue. - logger.debug("Buffer with id {} was removed while polling for lease. Continuing to poll", bufferId); - } else if (buffer.isEmpty()) { - seenBuffers.add(bufferId); - buffersToLease.add(bufferId); - logger.debug("Buffer with id {} is empty. Continuing to poll", bufferId); - } else { - logger.debug("Acquired lease for buffer {}", bufferId); - return Optional.of(new Lease(bufferId)); - } - } - } - - @Override - public List consumeRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to consume records from a buffer that was already returned to the pool. Ignoring"); - return emptyList(); - } - - final ShardBufferId bufferId = lease.bufferId(); - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. 
Cannot consume records", bufferId); - return emptyList(); - } - - return buffer.consumeRecords(); - } - - @Override - public void commitConsumedRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to commit records from a buffer that was already returned to the pool. Ignoring"); - return; - } - - final ShardBufferId bufferId = lease.bufferId(); - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot commit consumed records", bufferId); - return; - } - - logger.debug("Committing consumed records for buffer {}", bufferId); - final List consumedBatches = buffer.commitConsumedRecords(); - memoryTracker.freeMemory(consumedBatches, bufferId); - } - - @Override - public void rollbackConsumedRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to rollback records from a buffer that was already returned to the pool. Ignoring"); - return; - } - - final ShardBufferId bufferId = lease.bufferId(); - final ShardBuffer buffer = shardBuffers.get(bufferId); - - if (buffer != null) { - logger.debug("Rolling back consumed records for buffer {}", bufferId); - buffer.rollbackConsumedRecords(); - } - } - - @Override - public void returnBufferLease(final Lease lease) { - if (lease.returnToPool()) { - final ShardBufferId bufferId = lease.bufferId(); - buffersToLease.add(bufferId); - logger.debug("The buffer {} is available for lease again", bufferId); - } else { - logger.warn("Attempting to return a buffer that was already returned to the pool. 
Ignoring"); - } - } - - static final class Lease implements ShardBufferLease { - - private final ShardBufferId bufferId; - private final AtomicBoolean returnedToPool = new AtomicBoolean(false); - - private Lease(final ShardBufferId bufferId) { - this.bufferId = bufferId; - } - - @Override - public String shardId() { - return bufferId.shardId(); - } - - private ShardBufferId bufferId() { - return bufferId; - } - - private boolean isReturnedToPool() { - return returnedToPool.get(); - } - - /** - * Marks the lease as returned to the pool. - * @return true if the lease was not returned before, false otherwise. - */ - private boolean returnToPool() { - final boolean wasReturned = returnedToPool.getAndSet(true); - return !wasReturned; - } - } - - /** - * A memory tracker which blocks a thread when the memory usage exceeds the allowed maximum. - *

- * In order to make progress, the memory consumption may exceed the limit, but any new records will not be accepted. - * This is done to support the case when a single record batch is larger than the allowed memory limit. - */ - private static class BlockingMemoryTracker { - - private static final long AWAIT_MILLIS = 100; - - private final ComponentLog logger; - - private final long maxMemoryBytes; - - private final AtomicLong consumedMemoryBytes = new AtomicLong(0); - /** - * Whenever memory is freed a latch opens. Then replaced with a new one. - */ - private final AtomicReference memoryAvailableLatch = new AtomicReference<>(new CountDownLatch(1)); - - BlockingMemoryTracker(final ComponentLog logger, final long maxMemoryBytes) { - this.logger = logger; - this.maxMemoryBytes = maxMemoryBytes; - } - - void reserveMemory(final RecordBatch recordBatch, final ShardBufferId bufferId) { - final long consumedBytes = recordBatch.batchSizeBytes(); - - if (consumedBytes == 0) { - logger.debug("The batch for buffer {} is empty. No need to reserve memory", bufferId); - return; - } - - while (true) { - final long currentlyConsumedBytes = consumedMemoryBytes.get(); - - if (currentlyConsumedBytes >= maxMemoryBytes) { - // Not enough memory available, need to wait. - try { - memoryAvailableLatch.get().await(AWAIT_MILLIS, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for available memory in RecordBuffer", e); - } - } else { - final long newConsumedBytes = currentlyConsumedBytes + consumedBytes; - if (consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, newConsumedBytes)) { - logger.debug("Reserved {} bytes for {} records for buffer {}. 
Total consumed memory: {} bytes", - consumedBytes, recordBatch.size(), bufferId, newConsumedBytes); - break; - } - // If we're here, the compare and set operation failed, as another thread has modified the gauge in meantime. - // Retrying the operation. - } - } - } - - void freeMemory(final Collection consumedBatches, final ShardBufferId bufferId) { - if (consumedBatches.isEmpty()) { - logger.debug("No batches were consumed from buffer {}. No need to free memory", bufferId); - return; - } - - long freedBytes = 0; - for (final RecordBatch batch : consumedBatches) { - freedBytes += batch.batchSizeBytes(); - } - - while (true) { - final long currentlyConsumedBytes = consumedMemoryBytes.get(); - if (currentlyConsumedBytes < freedBytes) { - throw new IllegalStateException("Attempting to free more memory than currently used"); - } - - final long newTotal = currentlyConsumedBytes - freedBytes; - if (consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, newTotal)) { - logger.debug("Freed {} bytes for {} batches from buffer {}. Total consumed memory: {} bytes", - freedBytes, consumedBatches.size(), bufferId, newTotal); - - final CountDownLatch oldLatch = memoryAvailableLatch.getAndSet(new CountDownLatch(1)); - oldLatch.countDown(); // Release any waiting threads for free memory. - break; - } - // If we're here, the compare and set operation failed, as another thread has modified the gauge in meantime. - // Retrying the operation. 
- } - } - } - - private record RecordBatch(List records, - RecordProcessorCheckpointer checkpointer, - long batchSizeBytes) { - int size() { - return records.size(); - } - } - - private long calculateMemoryUsage(final Collection records) { - long totalBytes = 0; - for (final KinesisClientRecord record : records) { - final ByteBuffer data = record.data(); - if (data != null) { - totalBytes += data.capacity(); - } - } - return totalBytes; - } - - /** - * ShardBuffer stores all record batches for a single shard in two queues: - * - IN_PROGRESS: record batches that have been consumed but not yet checkpointed. - * - PENDING: record batches that have been added but not yet consumed. - *

- * When consuming records all PENDING batches are moved to IN_PROGRESS. - * After a successful checkpoint all IN_PROGRESS batches are cleared. - * After a rollback all IN_PROGRESS batches are kept, allowing to retry consumption. - *

- * Each batch preserves the original grouping of records as provided by Kinesis - * along with their associated checkpointer, ensuring atomicity. - */ - private static class ShardBuffer { - - private static final long AWAIT_MILLIS = 100; - - // Retry configuration. - private static final int MAX_RETRY_ATTEMPTS = 5; - private static final long BASE_RETRY_DELAY_MILLIS = 100; - private static final long MAX_RETRY_DELAY_MILLIS = 10_000; - private static final Random RANDOM = new Random(); - - private final ShardBufferId bufferId; - private final ComponentLog logger; - - private final long checkpointIntervalMillis; - private volatile long nextCheckpointTimeMillis; - - /** - * A last record checkpointer and sequence number that was ignored due to the checkpoint interval. - * If null, the last checkpoint was successful or no checkpoint was attempted yet. - */ - private volatile @Nullable LastIgnoredCheckpoint lastIgnoredCheckpoint; - - /** - * Queues for managing record batches with their checkpointers in different states. - */ - private final Queue inProgressBatches = new ConcurrentLinkedQueue<>(); - private final Queue pendingBatches = new ConcurrentLinkedQueue<>(); - /** - * Counter for tracking the number of batches in the buffer. Can be larger than the number of batches in the queues. - */ - private final AtomicInteger batchesCount = new AtomicInteger(0); - - /** - * A countdown latch that is used to signal when the buffer becomes empty. Used when ShardBuffer should be closed. 
- */ - private volatile @Nullable CountDownLatch emptyBufferLatch = null; - private final AtomicBoolean invalidated = new AtomicBoolean(false); - - ShardBuffer(final ShardBufferId bufferId, final ComponentLog logger, final long checkpointIntervalMillis) { - this.bufferId = bufferId; - this.logger = logger; - this.checkpointIntervalMillis = checkpointIntervalMillis; - this.nextCheckpointTimeMillis = System.currentTimeMillis() + checkpointIntervalMillis; - } - - /** - * @param recordBatch record batch with records to add. - * @return true if the records were added successfully, false if a buffer was invalidated. - */ - boolean offer(final RecordBatch recordBatch) { - if (invalidated.get()) { - return false; - } - - // Batches count must be always equal to or larger than the number of batches in the queues. - // Thus, the ordering of the operations. - batchesCount.incrementAndGet(); - pendingBatches.offer(recordBatch); - - return true; - } - - List consumeRecords() { - if (invalidated.get()) { - return emptyList(); - } - - RecordBatch pendingBatch; - while ((pendingBatch = pendingBatches.poll()) != null) { - inProgressBatches.offer(pendingBatch); - } - - final List recordsToConsume = new ArrayList<>(); - for (final RecordBatch batch : inProgressBatches) { - recordsToConsume.addAll(batch.records()); - } - - return recordsToConsume; - } - - List commitConsumedRecords() { - if (invalidated.get()) { - return emptyList(); - } - - final List checkpointedBatches = new ArrayList<>(); - RecordBatch batch; - while ((batch = inProgressBatches.poll()) != null) { - checkpointedBatches.add(batch); - } - - if (checkpointedBatches.isEmpty()) { - // The buffer could be invalidated in the meantime, or no records were consumed. - return emptyList(); - } - - // Batches count must always be equal to or larger than the number of batches in the queues. - // To achieve so, the count is decreased only after the queue has been emptied. 
- batchesCount.addAndGet(-checkpointedBatches.size()); - - final RecordProcessorCheckpointer lastBatchCheckpointer = checkpointedBatches.getLast().checkpointer(); - final KinesisClientRecord lastRecord = checkpointedBatches.getLast().records().getLast(); - - if (System.currentTimeMillis() >= nextCheckpointTimeMillis) { - checkpointSequenceNumber(lastBatchCheckpointer, lastRecord.sequenceNumber(), lastRecord.subSequenceNumber()); - nextCheckpointTimeMillis = System.currentTimeMillis() + checkpointIntervalMillis; - lastIgnoredCheckpoint = null; - } else { - // Saving the checkpointer for later, in case shutdown happens before the next checkpoint. - lastIgnoredCheckpoint = new LastIgnoredCheckpoint(lastBatchCheckpointer, lastRecord.sequenceNumber(), lastRecord.subSequenceNumber()); - } - - final CountDownLatch localEmptyBufferLatch = this.emptyBufferLatch; - if (localEmptyBufferLatch != null && isEmpty()) { - // If the latch is not null, it means we are waiting for the buffer to become empty. - localEmptyBufferLatch.countDown(); - } - - return checkpointedBatches; - } - - void rollbackConsumedRecords() { - if (invalidated.get()) { - return; - } - - for (final RecordBatch recordBatch : inProgressBatches) { - for (final KinesisClientRecord record : recordBatch.records()) { - record.data().rewind(); - } - } - } - - void checkpointEndedShard(final RecordProcessorCheckpointer checkpointer) { - while (true) { - if (invalidated.get()) { - return; - } - - // Setting the latch first, so that if the buffer is being emptied concurrently in commitConsumedRecords - // the latch is guaranteed to be visible to the commitConsumedRecords, and, therefore, opened. - // This will eliminate unnecessary downtimes when waiting for a latch to be opened. - final CountDownLatch localEmptyBufferLatch = new CountDownLatch(1); - this.emptyBufferLatch = localEmptyBufferLatch; - - if (batchesCount.get() == 0) { - // Buffer is empty, perform final checkpoint. 
- checkpointLastReceivedRecord(checkpointer); - return; - } - - // Wait for the records to be consumed first. - try { - localEmptyBufferLatch.await(AWAIT_MILLIS, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for records to be consumed", e); - } - } - } - - Collection shutdownBuffer(final RecordProcessorCheckpointer checkpointer) { - if (invalidated.getAndSet(true)) { - return emptyList(); - } - - if (batchesCount.get() == 0) { - checkpointLastReceivedRecord(checkpointer); - return emptyList(); - } - - // If there are still records in the buffer, checkpointing with the latest provided checkpointer is not safe. - // But, if the records were committed without checkpointing in the past, we can checkpoint them now. - final LastIgnoredCheckpoint ignoredCheckpoint = this.lastIgnoredCheckpoint; - if (ignoredCheckpoint != null) { - checkpointSequenceNumber( - ignoredCheckpoint.checkpointer(), - ignoredCheckpoint.sequenceNumber(), - ignoredCheckpoint.subSequenceNumber() - ); - } - - return drainInvalidatedBatches(); - } - - Collection invalidate() { - if (invalidated.getAndSet(true)) { - return emptyList(); - } - - return drainInvalidatedBatches(); - } - - private Collection drainInvalidatedBatches() { - if (!invalidated.get()) { - throw new IllegalStateException("Unable to drain invalidated batches for valid shard buffer: " + bufferId); - } - - final List batches = new ArrayList<>(); - RecordBatch batch; - // If both consumeRecords and drainInvalidatedBatches are called concurrently, invalidation must always consume all batches. - // Since consumeRecords moves batches from pending to in_progress, during invalidation pending batches should be drained first. 
- while ((batch = pendingBatches.poll()) != null) { - batches.add(batch); - } - while ((batch = inProgressBatches.poll()) != null) { - batches.add(batch); - } - - // No need to adjust batchesCount after invalidation. - - return batches; - } - - /** - * Checks if the buffer has any records. Can produce false negatives. - * - * @return whether there are any records in the buffer. - */ - boolean isEmpty() { - return invalidated.get() || batchesCount.get() == 0; - } - - private void checkpointLastReceivedRecord(final RecordProcessorCheckpointer checkpointer) { - logger.debug("Performing checkpoint for buffer with id {}. Checkpointing the last received record", bufferId); - - checkpointSafely(checkpointer::checkpoint); - } - - private void checkpointSequenceNumber(final RecordProcessorCheckpointer checkpointer, final String sequenceNumber, final long subSequenceNumber) { - logger.debug("Performing checkpoint for buffer with id {}. Sequence number: [{}], sub sequence number: [{}]", - bufferId, sequenceNumber, subSequenceNumber); - - checkpointSafely(() -> checkpointer.checkpoint(sequenceNumber, subSequenceNumber)); - } - - /** - * Performs checkpointing using exponential backoff and jitter, if needed. - * - * @param checkpointAction the action which performs the checkpointing. 
- */ - private void checkpointSafely(final CheckpointAction checkpointAction) { - for (int attempt = 1; attempt <= MAX_RETRY_ATTEMPTS; attempt++) { - try { - checkpointAction.doCheckpoint(); - if (attempt > 1) { - logger.debug("Checkpoint succeeded on attempt {}", attempt); - } - return; - } catch (final ThrottlingException | InvalidStateException | KinesisClientLibDependencyException e) { - if (attempt == MAX_RETRY_ATTEMPTS) { - logger.error("Failed to checkpoint after {} attempts, giving up", MAX_RETRY_ATTEMPTS, e); - return; - } - - final long delayMillis = calculateRetryDelay(attempt); - - logger.debug("Checkpoint failed on attempt {} with {}, retrying in {} ms", - attempt, e.getMessage(), delayMillis); - - try { - Thread.sleep(delayMillis); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - logger.warn("Thread interrupted while waiting to retry checkpoint. Exiting retry loop", ie); - return; - } - } catch (final ShutdownException e) { - logger.warn("Failed to checkpoint records due to shutdown. Ignoring checkpoint", e); - return; - } catch (final RuntimeException e) { - logger.warn("Failed to checkpoint records due to an error. Ignoring checkpoint", e); - return; - } - } - } - - private long calculateRetryDelay(final int attempt) { - final long desiredBaseDelayMillis = BASE_RETRY_DELAY_MILLIS * (1L << (attempt - 1)); - final long baseDelayMillis = Math.min(desiredBaseDelayMillis, MAX_RETRY_DELAY_MILLIS); - final long jitterMillis = RANDOM.nextLong(baseDelayMillis / 4); // Up to 25% jitter. - return baseDelayMillis + jitterMillis; - } - - private interface CheckpointAction { - - /** - * Throws the same set of exceptions as {@link RecordProcessorCheckpointer#checkpoint()} and {@link RecordProcessorCheckpointer#checkpoint(String, long)}. 
- */ - void doCheckpoint() throws KinesisClientLibDependencyException, InvalidStateException, ThrottlingException, ShutdownException, IllegalArgumentException; - } - - private record LastIgnoredCheckpoint(RecordProcessorCheckpointer checkpointer, String sequenceNumber, long subSequenceNumber) { - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java new file mode 100644 index 000000000000..6c735c176df8 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; +import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; + +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Shared-throughput Kinesis consumer that runs a continuous background fetch loop per shard. + * Each owned shard gets its own virtual thread that repeatedly calls GetRecords and enqueues + * results for the processor, mirroring the producer-consumer architecture of the KCL Scheduler. + * This keeps data flowing between onTrigger invocations rather than fetching on-demand. + * + *

Concurrency is bounded by a semaphore with {@value #MAX_CONCURRENT_FETCHES} permits so + * that at most that many GetRecords HTTP calls are in flight at any moment, preventing + * connection-pool exhaustion. Additionally, when the shared result queue exceeds + * {@value #MAX_QUEUED_RESULTS} entries the fetch loop sleeps until the processor drains results. + */ +final class PollingKinesisClient extends KinesisConsumerClient { + + private static final long DEFAULT_EMPTY_SHARD_BACKOFF_NANOS = TimeUnit.MILLISECONDS.toNanos(500); + private static final long DEFAULT_ERROR_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(2); + static final int MAX_QUEUED_RESULTS = 500; + static final int MAX_CONCURRENT_FETCHES = 50; + + private final ExecutorService fetchExecutor = Executors.newVirtualThreadPerTaskExecutor(); + private final Map pollingShardStates = new ConcurrentHashMap<>(); + private final Semaphore fetchPermits = new Semaphore(MAX_CONCURRENT_FETCHES, true); + private final long emptyShardBackoffNanos; + private final long errorBackoffNanos; + private volatile Instant timestampForInitialPosition; + + PollingKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger) { + this(kinesisClient, logger, DEFAULT_EMPTY_SHARD_BACKOFF_NANOS, DEFAULT_ERROR_BACKOFF_NANOS); + } + + PollingKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger, + final long emptyShardBackoffNanos, final long errorBackoffNanos) { + super(kinesisClient, logger); + this.emptyShardBackoffNanos = emptyShardBackoffNanos; + this.errorBackoffNanos = errorBackoffNanos; + } + + void setTimestampForInitialPosition(final Instant timestamp) { + this.timestampForInitialPosition = timestamp; + } + + @Override + void startFetches(final List shards, final String streamName, final int batchSize, + final String initialStreamPosition, final KinesisShardManager shardManager) { + if (fetchExecutor.isShutdown()) { + return; + } + + for (final Shard shard : shards) { + final String shardId = 
shard.shardId(); + final PollingShardState existing = pollingShardStates.get(shardId); + if (existing == null) { + final PollingShardState state = new PollingShardState(); + if (pollingShardStates.putIfAbsent(shardId, state) == null && state.tryStartLoop()) { + launchFetchLoop(state, shardId, streamName, batchSize, initialStreamPosition, shardManager); + } + } else if (!existing.isExhausted() && !existing.isStopped() && !existing.isLoopRunning() + && existing.tryStartLoop()) { + logger.warn("Restarting dead fetch loop for shard {}", shardId); + launchFetchLoop(existing, shardId, streamName, batchSize, initialStreamPosition, shardManager); + } + } + } + + @Override + boolean hasPendingFetches() { + if (hasQueuedResults()) { + return true; + } + for (final PollingShardState state : pollingShardStates.values()) { + if (!state.isExhausted() && !state.isStopped()) { + return true; + } + } + return false; + } + + @Override + void acknowledgeResults(final List results) { + } + + @Override + void rollbackResults(final List results) { + for (final ShardFetchResult result : results) { + final PollingShardState state = pollingShardStates.get(result.shardId()); + if (state != null) { + state.requestReset(); + } + } + } + + @Override + void removeUnownedShards(final Set ownedShards) { + pollingShardStates.entrySet().removeIf(entry -> { + if (!ownedShards.contains(entry.getKey())) { + entry.getValue().stop(); + return true; + } + return false; + }); + } + + @Override + void logDiagnostics(final int ownedCount, final int cachedShardCount) { + if (!shouldLogDiagnostics()) { + return; + } + + int active = 0; + int exhausted = 0; + int stopped = 0; + int dead = 0; + for (final PollingShardState state : pollingShardStates.values()) { + if (state.isExhausted()) { + exhausted++; + } else if (state.isStopped()) { + stopped++; + } else if (!state.isLoopRunning()) { + dead++; + } else { + active++; + } + } + + logger.debug("Kinesis polling diagnostics: discoveredShards={}, ownedShards={}, 
queueDepth={}, " + + "fetchLoops={}, active={}, exhausted={}, stopped={}, dead={}, concurrentFetches={}", + cachedShardCount, ownedCount, totalQueuedResults(), pollingShardStates.size(), + active, exhausted, stopped, dead, MAX_CONCURRENT_FETCHES - fetchPermits.availablePermits()); + } + + @Override + void close() { + for (final PollingShardState state : pollingShardStates.values()) { + state.stop(); + } + pollingShardStates.clear(); + fetchExecutor.shutdownNow(); + super.close(); + } + + private void launchFetchLoop(final PollingShardState state, final String shardId, + final String streamName, final int batchSize, final String initialStreamPosition, + final KinesisShardManager shardManager) { + final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + try { + fetchExecutor.submit(() -> { + Thread.currentThread().setContextClassLoader(contextClassLoader); + try { + runFetchLoop(state, shardId, streamName, batchSize, initialStreamPosition, shardManager); + } catch (final Throwable t) { + if (!state.isStopped()) { + logger.error("Fetch loop for shard {} terminated unexpectedly", shardId, t); + } + } finally { + state.markLoopStopped(); + } + }); + } catch (final RejectedExecutionException e) { + state.markLoopStopped(); + logger.debug("Executor shut down; cannot start fetch loop for shard {}", shardId); + } + } + + private void runFetchLoop(final PollingShardState state, final String shardId, + final String streamName, final int batchSize, final String initialStreamPosition, + final KinesisShardManager shardManager) { + + state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager)); + + while (!Thread.currentThread().isInterrupted() && !state.isStopped()) { + try { + if (state.isExhausted()) { + return; + } + + if (state.isResetRequested()) { + state.clearReset(); + state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager)); + } + + if (state.getIterator() 
== null) { + state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager)); + if (state.getIterator() == null) { + sleepNanos(errorBackoffNanos); + continue; + } + } + + if (totalQueuedResults() >= MAX_QUEUED_RESULTS) { + sleepNanos(emptyShardBackoffNanos); + continue; + } + + try { + fetchPermits.acquire(); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + + final GetRecordsResponse response; + try { + response = fetchRecords(shardId, state, batchSize); + } finally { + fetchPermits.release(); + } + if (response == null) { + continue; + } + + final List records = response.records(); + if (!records.isEmpty()) { + final long millisBehind = response.millisBehindLatest() != null ? response.millisBehindLatest() : -1; + enqueueResult(createFetchResult(shardId, records, millisBehind)); + } + + state.setIterator(response.nextShardIterator()); + if (state.getIterator() == null) { + state.markExhausted(); + return; + } + + if (records.isEmpty()) { + sleepNanos(emptyShardBackoffNanos); + } + } catch (final Exception e) { + if (!state.isStopped()) { + logger.error("Unexpected error in fetch loop for shard {}; will retry", shardId, e); + state.setIterator(null); + sleepNanos(errorBackoffNanos); + } + } + } + } + + private GetRecordsResponse fetchRecords(final String shardId, final PollingShardState state, final int batchSize) { + final GetRecordsRequest request = GetRecordsRequest.builder() + .shardIterator(state.getIterator()) + .limit(batchSize) + .build(); + + try { + return kinesisClient.getRecords(request); + } catch (final software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException + | software.amazon.awssdk.services.kinesis.model.LimitExceededException e) { + logger.debug("GetRecords throttled for shard {}; will retry after backoff", shardId); + sleepNanos(errorBackoffNanos); + return null; + } catch (final 
software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException e) { + logger.info("Shard iterator expired for shard {}; will re-acquire", shardId); + state.setIterator(null); + sleepNanos(errorBackoffNanos); + return null; + } catch (final SdkClientException e) { + if (!state.isStopped()) { + logger.warn("GetRecords timed out for shard {}; will retry with existing iterator", shardId); + sleepNanos(errorBackoffNanos); + } + return null; + } catch (final Exception e) { + if (!state.isStopped()) { + logger.error("GetRecords failed for shard {}", shardId, e); + state.setIterator(null); + sleepNanos(errorBackoffNanos); + } + return null; + } + } + + private static void sleepNanos(final long nanos) { + try { + TimeUnit.NANOSECONDS.sleep(nanos); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private String getShardIterator(final PollingShardState state, final String streamName, + final String shardId, final String initialStreamPosition, final KinesisShardManager shardManager) { + final String lastSequenceNumber; + try { + lastSequenceNumber = shardManager.readCheckpoint(shardId); + } catch (final Exception e) { + if (!state.isStopped()) { + logger.warn("Failed to read checkpoint for shard {}; will retry", shardId, e); + } + return null; + } + + final ShardIteratorType iteratorType; + final String startingSequenceNumber; + final Instant timestamp; + if (lastSequenceNumber != null) { + iteratorType = ShardIteratorType.AFTER_SEQUENCE_NUMBER; + startingSequenceNumber = lastSequenceNumber; + timestamp = null; + } else { + iteratorType = ShardIteratorType.fromValue(initialStreamPosition); + startingSequenceNumber = null; + timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? 
timestampForInitialPosition : null; + } + + logger.debug("Getting shard iterator for shard {} with type={}, startingSeq={}, timestamp={}", + shardId, iteratorType, startingSequenceNumber, timestamp); + + final GetShardIteratorRequest.Builder iteratorRequestBuilder = GetShardIteratorRequest.builder() + .streamName(streamName) + .shardId(shardId) + .shardIteratorType(iteratorType); + + if (startingSequenceNumber != null) { + iteratorRequestBuilder.startingSequenceNumber(startingSequenceNumber); + } + if (timestamp != null) { + iteratorRequestBuilder.timestamp(timestamp); + } + + try { + fetchPermits.acquire(); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + + try { + return kinesisClient.getShardIterator(iteratorRequestBuilder.build()).shardIterator(); + } catch (final Exception e) { + if (!state.isStopped()) { + logger.error("Failed to get shard iterator for shard {} (type={}, seq={})", + shardId, iteratorType, startingSequenceNumber, e); + } + return null; + } finally { + fetchPermits.release(); + } + } + + static final class PollingShardState { + private volatile String currentIterator; + private volatile boolean shardExhausted; + private volatile boolean stopped; + private volatile boolean resetRequested; + private final AtomicBoolean loopRunning = new AtomicBoolean(); + + String getIterator() { + return currentIterator; + } + + void setIterator(final String iterator) { + currentIterator = iterator; + } + + boolean isExhausted() { + return shardExhausted; + } + + void markExhausted() { + shardExhausted = true; + } + + boolean isStopped() { + return stopped; + } + + void stop() { + stopped = true; + } + + boolean isResetRequested() { + return resetRequested; + } + + void requestReset() { + resetRequested = true; + } + + void clearReset() { + resetRequested = false; + } + + boolean tryStartLoop() { + return loopRunning.compareAndSet(false, true); + } + + void markLoopStopped() { + loopRunning.set(false); + } + + 
boolean isLoopRunning() { + return loopRunning.get(); + } + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessor.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessor.java deleted file mode 100644 index fd2aa3be0fb4..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessor.java +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processor.ProcessSession; -import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverter; -import org.apache.nifi.schema.access.SchemaNotFoundException; -import org.apache.nifi.serialization.MalformedRecordException; -import org.apache.nifi.serialization.RecordReader; -import org.apache.nifi.serialization.RecordReaderFactory; -import org.apache.nifi.serialization.RecordSetWriter; -import org.apache.nifi.serialization.RecordSetWriterFactory; -import org.apache.nifi.serialization.WriteResult; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static java.util.Collections.emptyMap; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; - -final class ReaderRecordProcessor { - - private final RecordReaderFactory recordReaderFactory; - private final KinesisRecordConverter recordConverter; - private final RecordSetWriterFactory recordWriterFactory; - private final ComponentLog logger; - - ReaderRecordProcessor( - final RecordReaderFactory recordReaderFactory, - final KinesisRecordConverter recordConverter, - final RecordSetWriterFactory recordWriterFactory, - final ComponentLog logger) { - 
this.recordReaderFactory = recordReaderFactory; - this.recordConverter = recordConverter; - this.recordWriterFactory = recordWriterFactory; - this.logger = logger; - } - - ProcessingResult processRecords( - final ProcessSession session, - final String streamName, - final String shardId, - final List records) { - final List successFlowFiles = new ArrayList<>(); - final List failureFlowFiles = new ArrayList<>(); - - ActiveFlowFile activeFlowFile = null; - - for (final KinesisClientRecord kinesisRecord : records) { - final int dataSize = kinesisRecord.data().remaining(); - final byte[] data = new byte[dataSize]; - kinesisRecord.data().get(data); - - try (final InputStream in = new ByteArrayInputStream(data); - final RecordReader reader = recordReaderFactory.createRecordReader(emptyMap(), in, data.length, logger)) { - - Record record; - while ((record = reader.nextRecord()) != null) { - final Record convertedRecord = recordConverter.convert(record, kinesisRecord, streamName, shardId); - final RecordSchema writeSchema = recordWriterFactory.getSchema(emptyMap(), convertedRecord.getSchema()); - - if (activeFlowFile == null) { - activeFlowFile = ActiveFlowFile.startNewFile(logger, session, recordWriterFactory, writeSchema, streamName, shardId); - } else if (!writeSchema.equals(activeFlowFile.schema())) { - // If the write schema has changed, we need to complete the current FlowFile and start a new one. 
- final FlowFile completedFlowFile = activeFlowFile.complete(); - successFlowFiles.add(completedFlowFile); - - activeFlowFile = ActiveFlowFile.startNewFile(logger, session, recordWriterFactory, writeSchema, streamName, shardId); - } - - activeFlowFile.writeRecord(convertedRecord, kinesisRecord); - } - } catch (final IOException | MalformedRecordException | SchemaNotFoundException e) { - logger.error("Reader or Writer failed to process Kinesis Record with Stream Name [{}] Shard Id [{}] Sequence Number [{}] SubSequence Number [{}]", - streamName, shardId, kinesisRecord.sequenceNumber(), kinesisRecord.subSequenceNumber(), e); - final FlowFile failureFlowFile = createParseFailureFlowFile(session, streamName, shardId, kinesisRecord, e); - failureFlowFiles.add(failureFlowFile); - } - } - - if (activeFlowFile != null) { - final FlowFile completedFlowFile = activeFlowFile.complete(); - successFlowFiles.add(completedFlowFile); - } - - return new ProcessingResult(successFlowFiles, failureFlowFiles); - } - - private static FlowFile createParseFailureFlowFile( - final ProcessSession session, - final String streamName, - final String shardId, - final KinesisClientRecord record, - final Exception e) { - FlowFile flowFile = session.create(); - - record.data().rewind(); - flowFile = session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - channel.write(record.data()); - } - }); - - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, record, record); - - final Throwable cause = e.getCause() != null ? 
e.getCause() : e; - attributes.put(RECORD_ERROR_MESSAGE, cause.toString()); - - flowFile = session.putAllAttributes(flowFile, attributes); - session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - return flowFile; - } - - record ProcessingResult(List successFlowFiles, List parseFailureFlowFiles) { - } - - /** - * A class that manages a single {@link FlowFile} with a static schema that is currently being written to. - * On a schema change the current {@link ActiveFlowFile} should be completed a new instance of this class - * with a new schema should be created. - * - * An {@link ActiveFlowFile} must have at least one record written to it before it can be completed. - */ - private static final class ActiveFlowFile { - - private final ComponentLog logger; - - private final ProcessSession session; - private final FlowFile flowFile; - private final RecordSetWriter writer; - private final RecordSchema schema; - - private final String streamName; - private final String shardId; - - private KinesisClientRecord firstRecord; - private KinesisClientRecord lastRecord; - - private ActiveFlowFile( - final ComponentLog logger, - final ProcessSession session, - final FlowFile flowFile, - final RecordSetWriter writer, - final RecordSchema schema, - final String streamName, - final String shardId) { - this.logger = logger; - this.session = session; - this.flowFile = flowFile; - this.writer = writer; - this.schema = schema; - this.streamName = streamName; - this.shardId = shardId; - } - - static ActiveFlowFile startNewFile( - final ComponentLog logger, - final ProcessSession session, - final RecordSetWriterFactory recordWriterFactory, - final RecordSchema writeSchema, - final String streamName, - final String shardId) throws SchemaNotFoundException { - final FlowFile flowFile = session.create(); - final OutputStream outputStream = session.write(flowFile); - - try { - final RecordSetWriter writer = 
recordWriterFactory.createWriter(logger, writeSchema, outputStream, flowFile); - writer.beginRecordSet(); - - return new ActiveFlowFile(logger, session, flowFile, writer, writeSchema, streamName, shardId); - - } catch (final SchemaNotFoundException e) { - logger.debug("Failed to find writeSchema for Kinesis stream record: {}", e.getMessage()); - try { - outputStream.close(); - } catch (final IOException ioe) { - e.addSuppressed(ioe); - } - throw e; - - } catch (final IOException e) { - final ProcessException processException = new ProcessException("Failed to create a writer for a FlowFile", e); - - logger.debug("Stopping Kinesis records processing. Failed to create a writer for a FlowFile: {}", e.getMessage()); - try { - outputStream.close(); - } catch (final IOException ioe) { - processException.addSuppressed(ioe); - } - throw processException; - } - } - - RecordSchema schema() { - return schema; - } - - void writeRecord(final Record record, final KinesisClientRecord kinesisRecord) { - try { - writer.write(record); - } catch (final IOException e) { - logger.debug("Stopping Kinesis records processing. 
Failed to write to a FlowFile: {}", e.getMessage()); - throw new ProcessException("Failed to write a record into a FlowFile", e); - } - - if (firstRecord == null) { - firstRecord = kinesisRecord; - } - lastRecord = kinesisRecord; - } - - FlowFile complete() { - if (firstRecord == null || lastRecord == null) { - throw new IllegalStateException("Cannot complete an ActiveFlowFile that has no records"); - } - - try { - final WriteResult finalResult = writer.finishRecordSet(); - writer.close(); - - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, firstRecord, lastRecord); - attributes.putAll(finalResult.getAttributes()); - attributes.put(RECORD_COUNT, String.valueOf(finalResult.getRecordCount())); - attributes.put(MIME_TYPE, writer.getMimeType()); - - final FlowFile completedFlowFile = session.putAllAttributes(flowFile, attributes); - session.getProvenanceReporter().receive(completedFlowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - return completedFlowFile; - - } catch (final IOException e) { - final ProcessException processException = new ProcessException("Failed to complete a FlowFile", e); - - logger.debug("Stopping Kinesis records processing. 
Failed to complete a FlowFile: {}", e.getMessage()); - try { - writer.close(); - } catch (final IOException ioe) { - processException.addSuppressed(ioe); - } - - throw processException; - } - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java deleted file mode 100644 index 12fa52c6fe7a..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; -import java.util.Optional; - -/** - * RecordBuffer keeps track of all created Shard buffers, including exclusive read access via leasing. 
- * It acts as the main interface between KCL callbacks and the {@link ConsumeKinesis} processor, - * routing events to appropriate Shard buffers and ensuring thread-safe operations. - */ -interface RecordBuffer { - - /** - * Interface for interactions from the Kinesis Client Library to the Record Buffer. - * Reflects the methods called by {@link software.amazon.kinesis.processor.ShardRecordProcessor}. - */ - interface ForKinesisClientLibrary { - - ShardBufferId createBuffer(String shardId); - - void addRecords(ShardBufferId bufferId, List records, RecordProcessorCheckpointer checkpointer); - - /** - * Called when a shard ends - waits until the buffer is flushed then performs the final checkpoint. - */ - void checkpointEndedShard(ShardBufferId bufferId, RecordProcessorCheckpointer checkpointer); - - /** - * Called when a consumer is shut down. Performs the checkpoint and returns - * without waiting for the buffer to be flushed. - */ - void shutdownShardConsumption(ShardBufferId bufferId, RecordProcessorCheckpointer checkpointer); - - /** - * Called when lease is lost - immediately invalidates the buffer to prevent further operations. - */ - void consumerLeaseLost(ShardBufferId bufferId); - } - - /** - * Interface for interactions from {@link ConsumeKinesis} processor to the Record Buffer. - */ - interface ForProcessor { - - /** - * Acquires an exclusive lease for a buffer that has data available for consumption. - * If no data is available in the buffers, returns an empty Optional. - *

- * After acquiring a lease, the processor can consume records from the buffer. - * After consuming the records the processor must always {@link #returnBufferLease(LEASE)}. - */ - Optional acquireBufferLease(); - - /** - * Consumes records from the buffer associated with the given lease. - * The records have to be committed or rolled back later. - */ - List consumeRecords(LEASE lease); - - void commitConsumedRecords(LEASE lease); - - void rollbackConsumedRecords(LEASE lease); - - /** - * Returns the lease for a buffer back to the pool making it available for consumption again. - * The method can be called multiple times with the same lease, but only the first call will actually take an effect. - */ - void returnBufferLease(LEASE lease); - } - - record ShardBufferId(String shardId, long bufferId) { - } - - interface ShardBufferLease { - String shardId(); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java new file mode 100644 index 000000000000..0a84d411f6ed --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import java.math.BigInteger; + +/** + * Immutable checkpoint position within a Kinesis shard, composed of a sequence number + * and a sub-sequence number. The sub-sequence number is non-zero only for KPL-aggregated + * records and identifies the position within the aggregate. + * + * @param sequenceNumber the Kinesis record sequence number + * @param subSequenceNumber the sub-record index within a KPL aggregate (0 for non-aggregated) + */ +record ShardCheckpoint(String sequenceNumber, long subSequenceNumber) { + + /** + * Returns the higher of two checkpoints. Comparison is first by sequence number, + * then by sub-sequence number within the same aggregate. + */ + ShardCheckpoint max(final ShardCheckpoint other) { + final int comparison = new BigInteger(this.sequenceNumber).compareTo(new BigInteger(other.sequenceNumber)); + if (comparison > 0) { + return this; + } + if (comparison < 0) { + return other; + } + return this.subSequenceNumber >= other.subSequenceNumber ? 
this : other; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java similarity index 68% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java index e7575d6fcf77..343b393f956f 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java @@ -16,12 +16,19 @@ */ package org.apache.nifi.processors.aws.kinesis; -final class ProvenanceTransitUriFormat { +import java.util.List; - static String toTransitUri(final String streamName, final String shardId) { - return "kinesis:stream/" + streamName + "/" + shardId; +record ShardFetchResult(String shardId, List records, long millisBehindLatest) { + + String firstSequenceNumber() { + return records.getFirst().sequenceNumber(); + } + + String lastSequenceNumber() { + return records.getLast().sequenceNumber(); } - private ProvenanceTransitUriFormat() { + long lastSubSequenceNumber() { + return records.getLast().subSequenceNumber(); } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java deleted file mode 100644 index d454577d420b..000000000000 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.FIELD_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.composeMetadataObject; - -public final class InjectMetadataRecordConverter implements KinesisRecordConverter { - - @Override - public Record convert(final Record record, final KinesisClientRecord 
kinesisRecord, final String streamName, final String shardId) { - final List schemaFields = new ArrayList<>(record.getSchema().getFields()); - schemaFields.add(FIELD_METADATA); - final RecordSchema schema = new SimpleRecordSchema(schemaFields); - - final Record metadata = composeMetadataObject(kinesisRecord, streamName, shardId); - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.put(METADATA, metadata); - - return new MapRecord(schema, recordValues); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java deleted file mode 100644 index 83a65fd81f2f..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.record.Record; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -public interface KinesisRecordConverter { - - Record convert(Record record, KinesisClientRecord kinesisRecord, String streamName, String shardId); -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java deleted file mode 100644 index b67ae22072cd..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.record.Record; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -public final class ValueRecordConverter implements KinesisRecordConverter { - - @Override - public Record convert(final Record record, final KinesisClientRecord kinesisRecord, final String streamName, final String shardId) { - return record; - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java deleted file mode 100644 index 693077f4437e..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.FIELD_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.composeMetadataObject; - -public final class WrapperRecordConverter implements KinesisRecordConverter { - - private static final String VALUE = "value"; - - @Override - public Record convert(final Record record, final KinesisClientRecord kinesisRecord, final String streamName, final String shardId) { - final Record metadata = composeMetadataObject(kinesisRecord, streamName, shardId); - - final RecordSchema convertedSchema = new SimpleRecordSchema(List.of( - FIELD_METADATA, - new RecordField(VALUE, RecordFieldType.RECORD.getRecordDataType(record.getSchema()))) - ); - final Map convertedRecord = Map.of( - METADATA, metadata, - VALUE, record - ); - - return new MapRecord(convertedSchema, convertedRecord); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java new file mode 100644 index 000000000000..e5fdff3b30cc --- /dev/null +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.TableDescription; +import software.amazon.awssdk.services.dynamodb.model.TableStatus; + +import 
java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class CheckpointTableUtilsTest { + + @Test + void testCopyCheckpointItemsConvertsNewSchemaItemsForLegacyDestination() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest request = invocation.getArgument(0); + if ("legacy-table".equals(request.tableName())) { + return legacySchemaResponse(); + } + return newSchemaResponse(); + }); + + final Map newSchemaItem = Map.of( + "streamName", AttributeValue.builder().s("my-stream").build(), + "shardId", AttributeValue.builder().s("shardId-0001").build(), + "sequenceNumber", AttributeValue.builder().s("12345").build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(newSchemaItem).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "migration-table", "legacy-table"); + + final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb, times(1)).putItem(putCaptor.capture()); + + final Map copiedItem = putCaptor.getValue().item(); + assertEquals("my-stream:shardId-0001", copiedItem.get("leaseKey").s()); + assertEquals("12345", copiedItem.get("checkpoint").s()); + } + + @Test + void testCopyCheckpointItemsSkipsNodeAndMigrationMarkersForLegacyDestination() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + 
when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenReturn(legacySchemaResponse()); + + final Map nodeItem = Map.of( + "streamName", AttributeValue.builder().s("my-stream").build(), + "shardId", AttributeValue.builder().s("__node__#node-a").build()); + final Map migrationMarkerItem = Map.of( + "streamName", AttributeValue.builder().s("my-stream").build(), + "shardId", AttributeValue.builder().s("__migration__").build()); + final Map shardItem = Map.of( + "streamName", AttributeValue.builder().s("my-stream").build(), + "shardId", AttributeValue.builder().s("shardId-0002").build(), + "sequenceNumber", AttributeValue.builder().s("67890").build()); + + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( + ScanResponse.builder().items(List.of(nodeItem, migrationMarkerItem, shardItem)).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "migration-table", "legacy-table"); + + final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb, times(1)).putItem(putCaptor.capture()); + assertEquals("my-stream:shardId-0002", putCaptor.getValue().item().get("leaseKey").s()); + } + + private static DescribeTableResponse newSchemaResponse() { + final KeySchemaElement hashKey = KeySchemaElement.builder() + .attributeName("streamName") + .keyType(KeyType.HASH) + .build(); + final KeySchemaElement rangeKey = KeySchemaElement.builder() + .attributeName("shardId") + .keyType(KeyType.RANGE) + .build(); + final TableDescription table = TableDescription.builder() + .keySchema(hashKey, rangeKey) + .tableStatus(TableStatus.ACTIVE) + .build(); + return DescribeTableResponse.builder().table(table).build(); + } + + private static DescribeTableResponse legacySchemaResponse() { + final KeySchemaElement hashKey = KeySchemaElement.builder() + .attributeName("leaseKey") + .keyType(KeyType.HASH) + .build(); + final 
TableDescription table = TableDescription.builder() + .keySchema(hashKey) + .tableStatus(TableStatus.ACTIVE) + .build(); + return DescribeTableResponse.builder().table(table).build(); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java index 6a5b866f25da..900da3bc101c 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java @@ -16,764 +16,1048 @@ */ package org.apache.nifi.processors.aws.kinesis; -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.flowfile.attributes.CoreAttributes; +import com.google.protobuf.ByteString; +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.nifi.avro.AvroReader; +import org.apache.nifi.avro.AvroRecordSetWriter; +import org.apache.nifi.controller.AbstractControllerService; import org.apache.nifi.json.JsonRecordSetWriter; import org.apache.nifi.json.JsonTreeReader; -import org.apache.nifi.processor.Processor; -import org.apache.nifi.processor.Relationship; -import org.apache.nifi.processor.exception.FlowFileHandlingException; +import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderControllerService; import org.apache.nifi.processors.aws.region.RegionUtil; import org.apache.nifi.provenance.ProvenanceEventRecord; -import org.apache.nifi.provenance.ProvenanceEventType; -import 
org.apache.nifi.reporting.InitializationException; -import org.apache.nifi.schema.access.SchemaAccessUtils; -import org.apache.nifi.schema.inference.SchemaInferenceUtil; +import org.apache.nifi.serialization.RecordSetWriter; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.RecordSet; import org.apache.nifi.util.MockFlowFile; -import org.apache.nifi.util.MockProcessSession; -import org.apache.nifi.util.SharedSessionState; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; -import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.api.parallel.ExecutionMode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.EnabledIfDockerAvailable; +import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.localstack.LocalStackContainer; import org.testcontainers.utility.DockerImageName; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.kinesis.KinesisClient; -import software.amazon.awssdk.services.kinesis.model.Consumer; -import 
software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; -import software.amazon.awssdk.services.kinesis.model.ListStreamConsumersResponse; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; -import software.amazon.awssdk.services.kinesis.model.ScalingType; -import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.awssdk.services.kinesis.model.CreateStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordRequest; import software.amazon.awssdk.services.kinesis.model.StreamStatus; - -import java.net.URI; -import java.time.Duration; -import java.util.Collection; +import software.amazon.kinesis.retrieval.kpl.Messages; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.Callable; -import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.IntStream; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.MINUTES; -import static java.util.stream.Collectors.groupingBy; -import static java.util.stream.Collectors.mapping; -import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toSet; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_PARSE_FAILURE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_SUCCESS; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.JsonRecordAssert.assertFlowFileRecordPayloads; -import 
static org.junit.jupiter.api.Assertions.assertAll; +import java.util.stream.Collectors; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Timeout.ThreadMode.SEPARATE_THREAD; -/** - * Tests run in parallel to optimize execution time as Kinesis consumer coordination takes a lot. - */ -@Execution(ExecutionMode.CONCURRENT) -@Timeout(value = 5, unit = MINUTES, threadMode = SEPARATE_THREAD) +@Testcontainers +@EnabledIfDockerAvailable class ConsumeKinesisIT { - private static final Logger logger = LoggerFactory.getLogger(ConsumeKinesisIT.class); - private static final DockerImageName LOCALSTACK_IMAGE = DockerImageName.parse("localstack/localstack:4.12.0"); + @Container + private static final LocalStackContainer LOCALSTACK = new LocalStackContainer( + DockerImageName.parse("localstack/localstack:4")); + + private static final String AVRO_SCHEMA_A = """ + { + "type": "record", + "name": "A", + "fields": [ + {"name": "id", "type": "int"}, + {"name": "name", "type": "string"} + ] + }"""; + + private static final String AVRO_SCHEMA_B = """ + { + "type": "record", + "name": "B", + "fields": [ + {"name": "code", "type": "string"}, + {"name": "value", "type": "double"} + ] + }"""; - private static final LocalStackContainer localstack = new LocalStackContainer(LOCALSTACK_IMAGE).withServices("kinesis", "dynamodb", "cloudwatch"); + private TestRunner runner; + private KinesisClient kinesisClient; + private int credentialServiceCounter = 0; - private static KinesisClient kinesisClient; - private static DynamoDbClient dynamoDbClient; + @BeforeEach + void setUp() throws Exception { + kinesisClient = KinesisClient.builder() + .endpointOverride(LOCALSTACK.getEndpoint()) + .credentialsProvider(StaticCredentialsProvider.create( + 
AwsBasicCredentials.create(LOCALSTACK.getAccessKey(), LOCALSTACK.getSecretKey()))) + .region(Region.of(LOCALSTACK.getRegion())) + .build(); - private String streamName; - private String applicationName; - private TestRunner runner; - private TestKinesisStreamClient streamClient; + runner = TestRunners.newTestRunner(FastTimingConsumeKinesis.class); - @BeforeAll - static void oneTimeSetup() { - localstack.start(); + final JsonTreeReader reader = new JsonTreeReader(); + runner.addControllerService("json-reader", reader); + runner.enableControllerService(reader); - final AwsCredentialsProvider credentialsProvider = StaticCredentialsProvider.create( - AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey()) - ); + final JsonRecordSetWriter normalWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", normalWriter); + runner.enableControllerService(normalWriter); - kinesisClient = KinesisClient.builder() - .endpointOverride(localstack.getEndpoint()) - .credentialsProvider(credentialsProvider) - .region(Region.of(localstack.getRegion())) - .build(); + final FailingRecordSetWriterFactory failingWriter = new FailingRecordSetWriterFactory(); + runner.addControllerService("failing-writer", failingWriter); + runner.enableControllerService(failingWriter); - dynamoDbClient = DynamoDbClient.builder() - .endpointOverride(localstack.getEndpoint()) - .credentialsProvider(credentialsProvider) - .region(Region.of(localstack.getRegion())) - .build(); + addCredentialService(runner, "creds"); + runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "test-app-" + System.currentTimeMillis()); + runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "creds"); + runner.setProperty(RegionUtil.REGION, LOCALSTACK.getRegion()); + runner.setProperty(ConsumeKinesis.ENDPOINT_OVERRIDE, LOCALSTACK.getEndpoint().toString()); + runner.setProperty(ConsumeKinesis.MAX_BATCH_DURATION, "200 ms"); } - @AfterAll - static void tearDown() { + @AfterEach 
+ void tearDown() { if (kinesisClient != null) { kinesisClient.close(); } - if (dynamoDbClient != null) { - dynamoDbClient.close(); - } - localstack.stop(); } - @BeforeEach - void setUp() throws InitializationException { - final UUID testId = UUID.randomUUID(); - streamName = "%s-kinesis-stream-%s".formatted(getClass().getSimpleName(), testId); - streamClient = new TestKinesisStreamClient(kinesisClient, streamName); - applicationName = "%s-test-kinesis-app-%s".formatted(getClass().getSimpleName(), testId); - runner = createTestRunner(streamName, applicationName); - } + @Test + void testFlowFilePerRecordStrategy() throws Exception { + final String streamName = "per-record-test"; + final int recordCount = 5; - @AfterEach - void tearDownEach() { - runner.stop(); + createStream(streamName); + publishRecords(streamName, recordCount); - if (streamClient != null) { - try { - streamClient.deleteStream(); - } catch (final Exception e) { - logger.warn("Failed to delete stream {}: {}", streamName, e.getMessage()); - } + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(recordCount, flowFiles.size(), "Expected one FlowFile per Kinesis record"); + + for (final MockFlowFile ff : flowFiles) { + assertEquals("1", ff.getAttribute("record.count")); + assertEquals(streamName, ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); + assertEquals(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE), ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_PARTITION_KEY)); + 
assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SUBSEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SUBSEQUENCE)); + + final String content = ff.getContent(); + assertTrue(content.startsWith("{"), "Expected raw JSON content: " + content); } - // Removing tables generated by KCL. - deleteTable(applicationName); - deleteTable(applicationName + "-CoordinatorState"); - deleteTable(applicationName + "-WorkerMetricStats"); + final Set emittedShardIds = flowFiles.stream() + .map(ff -> ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)) + .collect(Collectors.toSet()); + final List receiveEvents = runner.getProvenanceEvents().stream() + .filter(event -> "RECEIVE".equals(event.getEventType().name())) + .toList(); + assertEquals(recordCount, receiveEvents.size(), "Expected one RECEIVE event per emitted FlowFile"); + for (final ProvenanceEventRecord receiveEvent : receiveEvents) { + final String transitUri = receiveEvent.getTransitUri(); + assertNotNull(transitUri, "RECEIVE event should include a transit URI"); + assertTrue(emittedShardIds.stream().anyMatch(shardId -> transitUri.endsWith("/" + shardId)), + "RECEIVE transit URI should include one of the emitted shard IDs: " + transitUri); + } + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } - private void deleteTable(final String tableName) { - try { - dynamoDbClient.deleteTable(req -> req.tableName(tableName)); - } catch (final Exception e) { - logger.warn("Failed to delete DynamoDB table {}: {}", tableName, e.getMessage()); + @Test + void testRecordOrientedStrategy() throws Exception { + final String streamName = "record-oriented-test"; + final int recordCount = 5; + + createStream(streamName); + publishRecords(streamName, recordCount); + + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + 
runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); + + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeSingleMessageFromSingleShard() { - streamClient.createStream(1); - - final String testMessage = "Hello, Kinesis!"; - streamClient.putRecord("test-partition-key", testMessage); + void testRecordOrientedStrategyWithInjectedMetadata() throws Exception { + final String streamName = "record-oriented-metadata-test"; + final int recordCount = 5; - runProcessorWithInitAndWaitForFiles(runner, 1); + createStream(streamName); + publishRecords(streamName, recordCount); - runner.assertTransferCount(REL_SUCCESS, 1); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final MockFlowFile flowFile = flowFiles.getFirst(); - - flowFile.assertContentEquals(testMessage); - flowFile.assertAttributeEquals("aws.kinesis.partition.key", "test-partition-key"); - assertNotNull(flowFile.getAttribute("aws.kinesis.first.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.last.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.shard.id")); + runner.setProperty(ConsumeKinesis.STREAM_NAME, 
streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, "INJECT_METADATA"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); + + long totalRecords = 0; + for (final MockFlowFile flowFile : flowFiles) { + totalRecords += Long.parseLong(flowFile.getAttribute("record.count")); + + final String content = flowFile.getContent(); + assertTrue(content.contains("\"kinesisMetadata\""), "Expected injected kinesisMetadata object"); + assertTrue(content.contains("\"stream\":\"" + streamName + "\""), "Expected stream in injected metadata"); + assertTrue(content.contains("\"shardId\":\""), "Expected shardId in injected metadata"); + assertTrue(content.contains("\"sequenceNumber\":\""), "Expected sequenceNumber in injected metadata"); + assertTrue(content.contains("\"subSequenceNumber\":0"), "Expected default subSequenceNumber in injected metadata"); + assertTrue(content.contains("\"partitionKey\":\""), "Expected partitionKey in injected metadata"); + } - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFile); + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - // Creates an enhanced fan-out consumer by default. 
- assertEquals( - List.of(applicationName), - streamClient.getEnhancedFanOutConsumerNames(), - "Expected a single enhanced fan-out consumer with an application name"); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeSingleMessageFromSingleShard_withoutEnhancedFanOut() { - runner.setProperty(ConsumeKinesis.CONSUMER_TYPE, ConsumeKinesis.ConsumerType.SHARED_THROUGHPUT); + void testDemarcatorStrategy() throws Exception { + final String streamName = "demarcator-test"; + final int recordCount = 3; + + createStream(streamName); + publishRecords(streamName, recordCount); + + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "LINE_DELIMITED"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); - streamClient.createStream(1); + final StringBuilder allContent = new StringBuilder(); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + allContent.append(ff.getContent()); + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - final String testMessage = "Hello, Kinesis!"; - streamClient.putRecord("test-partition-key", testMessage); + final String[] lines = allContent.toString().split("\n"); + assertEquals(recordCount, lines.length, "Expected one line per Kinesis record"); + for (final String line : lines) { + assertTrue(line.startsWith("{"), "Expected JSON content: " + line); + } - runProcessorWithInitAndWaitForFiles(runner, 1); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, 
counter.longValue()); + } - runner.assertTransferCount(REL_SUCCESS, 1); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final MockFlowFile flowFile = flowFiles.getFirst(); + @Test + void testDemarcatorStrategyWithCustomDelimiter() throws Exception { + final String streamName = "custom-delim-test"; + final int recordCount = 3; + final String delimiter = "|||"; - flowFile.assertContentEquals(testMessage); - flowFile.assertAttributeEquals("aws.kinesis.partition.key", "test-partition-key"); - assertNotNull(flowFile.getAttribute("aws.kinesis.first.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.last.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.shard.id")); + createStream(streamName); + publishRecords(streamName, recordCount); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFile); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "DEMARCATOR"); + runner.setProperty(ConsumeKinesis.MESSAGE_DEMARCATOR, delimiter); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); + + final StringBuilder allContent = new StringBuilder(); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + allContent.append(ff.getContent()); + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - assertTrue( - streamClient.getEnhancedFanOutConsumerNames().isEmpty(), - "No enhanced fan-out consumers should be created for Shared Throughput consumer type"); + final String[] parts = allContent.toString().split("\\|\\|\\|"); + assertEquals(recordCount, parts.length, "Expected records separated by custom delimiter"); + for (final String part : parts) { + assertTrue(part.startsWith("{"), "Expected 
JSON content: " + part); + } + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeManyMessagesFromSingleShardWithOrdering() { - final int messageCount = 10; + void testFailedWriteRollsBackAndRecordsAreReConsumed() throws Exception { + final String streamName = "rollback-test"; + final int recordCount = 5; + + createStream(streamName); + publishRecords(streamName, recordCount); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); - streamClient.createStream(1); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "failing-writer"); + runUntilOutput(runner); - final List messages = IntStream.range(0, messageCount).mapToObj(i -> "Message-" + i).toList(); - streamClient.putRecords("partition-key", messages); + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); - runProcessorWithInitAndWaitForFiles(runner, messageCount); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertTransferCount(REL_SUCCESS, messageCount); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final List flowFilesContent = flowFiles.stream().map(MockFlowFile::getContent).toList(); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile transferred to success"); - assertEquals(messages, flowFilesContent); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "All records should be re-consumed after rollback"); - 
assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeMessagesFromMultipleShardsStream() { - final int shardCount = 5; - final int messagesPerPartitionKey = 5; + void testClusterSimulationDistributesShards() throws Exception { + final String streamName = "cluster-sim-test"; + final int shardCount = 4; + final String appName = "cluster-app-" + System.currentTimeMillis(); + + createStream(streamName, shardCount); - streamClient.createStream(shardCount); + final TestRunner runner1 = createConfiguredRunner(streamName, appName); + final TestRunner runner2 = createConfiguredRunner(streamName, appName); - // Every shard has message with the same payload. - final List shardMessages = IntStream.range(0, messagesPerPartitionKey).mapToObj(String::valueOf).toList(); + runner1.run(1, false, true); + runner2.run(1, false, true); - IntStream.range(0, shardCount).forEach(shard -> streamClient.putRecords(String.valueOf(shard), shardMessages)); + int recordId = 0; + final long deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); - // Run processor and wait for all records - final int totalMessages = shardCount * messagesPerPartitionKey; - runProcessorWithInitAndWaitForFiles(runner, totalMessages); + runner1.run(1, false, false); + runner2.run(1, false, false); - // Verify results - runner.assertTransferCount(REL_SUCCESS, totalMessages); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); + if (hasFlowFiles(runner1) && hasFlowFiles(runner2)) { + break; + } + } - final Map> partitionKey2Messages = flowFiles.stream() - .collect(groupingBy( - f -> f.getAttribute("aws.kinesis.partition.key"), - mapping(MockFlowFile::getContent, toList()) 
- )); + runner1.run(1, true, false); + runner2.run(1, true, false); - assertEquals(shardCount, partitionKey2Messages.size()); + final List flowFiles1 = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List flowFiles2 = runner2.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); - assertAll( - partitionKey2Messages.values().stream() - .map(actual -> () -> assertEquals(shardMessages, actual)) - ); + assertFalse(flowFiles1.isEmpty(), "Runner 1 should have received data"); + assertFalse(flowFiles2.isEmpty(), "Runner 2 should have received data"); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles); + final Set uniqueRecords = new HashSet<>(); + for (final MockFlowFile ff : flowFiles1) { + uniqueRecords.add(ff.getContent()); + } + for (final MockFlowFile ff : flowFiles2) { + uniqueRecords.add(ff.getContent()); + } + assertEquals(flowFiles1.size() + flowFiles2.size(), uniqueRecords.size(), + "No duplicate records should be consumed"); } @Test - @Disabled("Does not work with LocalStack: https://github.com/localstack/localstack/issues/12833. Enable only when using real Kinesis.") - void testResharding_inParallelWithConsumption() { - // Using partition keys with uniformally distributed hashes to ensure the data is distributed across split shards. - final List partitionKeys = List.of( - "pk-0-14", // 035517ff4ca68849589f43842c07362f - "pk-1-14", // 5f045ae51eea9bd124d76041a6a27073 - "pk-2-2", // 85fb9a2b01b009904eb8a6fa13a21d6c - "pk-3-2" // dbf24a6e26910143c60188e2fcb53b4f - ); - - // Data to be produced at each stage - final int partitionRecordsPerStage = 3; - final int totalStages = 5; // initial + 4 resharding stages with data - final int totalRecordsPerPartition = partitionRecordsPerStage * totalStages; - final int totalRecords = partitionKeys.size() * totalRecordsPerPartition; - - // Start resharding and data production operations in background thread. 
- final Thread reshardingThread = new Thread(() -> { - int messageSeq = 0; // For each partition key the message content are sequential numbers. - - streamClient.createStream(1); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(2); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(4); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(3); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(2); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - }); - - reshardingThread.start(); - - runProcessorWithInitAndWaitForFiles(runner, totalRecords); - - runner.assertTransferCount(REL_SUCCESS, totalRecords); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - - final Map> partitionKeyToMessages = flowFiles.stream() - .collect(groupingBy( - f -> f.getAttribute("aws.kinesis.partition.key"), - mapping(MockFlowFile::getContent, toList()) - )); - - final List expectedPartitionMessages = IntStream.range(0, totalRecordsPerPartition).mapToObj(Integer::toString).toList(); - assertAll( - partitionKeyToMessages.entrySet().stream() - .map(actual -> () -> assertEquals( - expectedPartitionMessages, - actual.getValue(), - "Partition messages do not match expected for partition key: " + actual.getKey())) - ); - } - - private void putRecords(final Collection partitionKeys, final int count, final int startIndex) { - IntStream.range(startIndex, startIndex + count).forEach(i -> { - final String message = Integer.toString(i); - partitionKeys.forEach(partitionKey -> streamClient.putRecord(partitionKey, message)); - }); + void testClusterScaleDownAndScaleUpRebalancesShards() throws Exception { + final 
String streamName = "cluster-rebalance-test"; + final int shardCount = 6; + final String appName = "cluster-rebalance-app-" + System.currentTimeMillis(); + + createStream(streamName, shardCount); + + final TestRunner runner1 = createConfiguredRunner(streamName, appName); + final TestRunner runner2 = createConfiguredRunner(streamName, appName); + + runner1.run(1, false, true); + runner2.run(1, false, true); + + int recordId = 0; + long deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + runner2.run(1, false, false); + if (hasFlowFiles(runner1) && hasFlowFiles(runner2)) { + break; + } + } + assertFalse(runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 1 should receive data in stage one"); + assertFalse(runner2.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 2 should receive data in stage one"); + + final int runner1Checkpoint = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).size(); + runner2.run(1, true, false); + + deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + + final Set shards = new HashSet<>(); + for (final MockFlowFile ff : getNewFlowFiles(runner1, runner1Checkpoint)) { + shards.add(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + } + if (shards.size() == shardCount) { + break; + } + } + + final Set stageTwoShards = new HashSet<>(); + for (final MockFlowFile ff : getNewFlowFiles(runner1, runner1Checkpoint)) { + stageTwoShards.add(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + } + assertEquals(shardCount, stageTwoShards.size(), "Single active runner should consume from all shards"); + + final TestRunner runner3 = createConfiguredRunner(streamName, appName); + final TestRunner runner4 = 
createConfiguredRunner(streamName, appName); + final int runner1StageThreeCheckpoint = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).size(); + + runner3.run(1, false, true); + runner4.run(1, false, true); + + deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + runner3.run(1, false, false); + runner4.run(1, false, false); + + if (!getNewFlowFiles(runner1, runner1StageThreeCheckpoint).isEmpty() + && hasFlowFiles(runner3) && hasFlowFiles(runner4)) { + break; + } + } + + runner1.run(1, true, false); + runner3.run(1, true, false); + runner4.run(1, true, false); + + assertFalse(getNewFlowFiles(runner1, runner1StageThreeCheckpoint).isEmpty(), + "Runner 1 should receive data in stage three"); + assertFalse(runner3.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 3 should receive data in stage three"); + assertFalse(runner4.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 4 should receive data in stage three"); } @Test - void testSessionRollback() throws InterruptedException { - streamClient.createStream(1); + void testKplAggregatedRecordsFlowFilePerRecord() throws Exception { + final String streamName = "kpl-per-record-test"; + createStream(streamName); - // Initialize the processor. 
- runner.run(1, false, true); + publishAggregatedRecord(streamName, "agg-pk-1", + List.of("pk-a", "pk-b"), + List.of("{\"id\":1,\"name\":\"Alice\"}", "{\"id\":2,\"name\":\"Bob\"}", "{\"id\":3,\"name\":\"Charlie\"}"), + List.of(0, 1, 0)); - final ConsumeKinesis processor = (ConsumeKinesis) runner.getProcessor(); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - final String firstMessage = "Initial-Rollback-Message"; - streamClient.putRecord("key", firstMessage); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(3, flowFiles.size(), "Each sub-record in the aggregate should produce its own FlowFile"); - // First attempt with a failing session - should rollback. - while (true) { - final MockProcessSession failingSession = createFailingSession(processor); - try { - processor.onTrigger(runner.getProcessContext(), failingSession); - } catch (final FlowFileHandlingException ignored) { - failingSession.assertAllFlowFilesTransferred(REL_SUCCESS, 0); - break; // Expected rollback occurred - } - Thread.sleep(1000); + final Set contents = new HashSet<>(); + for (final MockFlowFile ff : flowFiles) { + assertEquals("1", ff.getAttribute("record.count")); + assertEquals(streamName, ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + contents.add(ff.getContent()); } - // Write another message. 
- final String secondMessage = "Another-Test-Message"; - streamClient.putRecord("key", secondMessage); + assertTrue(contents.contains("{\"id\":1,\"name\":\"Alice\"}")); + assertTrue(contents.contains("{\"id\":2,\"name\":\"Bob\"}")); + assertTrue(contents.contains("{\"id\":3,\"name\":\"Charlie\"}")); + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(3, counter.longValue()); + } + + @Test + void testKplAggregatedRecordsRecordOrientedWithMetadata() throws Exception { + final String streamName = "kpl-metadata-test"; + createStream(streamName); - runProcessorAndWaitForFiles(runner, 2); + publishAggregatedRecord(streamName, "agg-pk-1", + List.of("pk-inner"), + List.of("{\"id\":10,\"name\":\"Alpha\"}", "{\"id\":20,\"name\":\"Beta\"}"), + List.of(0, 0)); - // Verify the messages are written in the correct order. - runner.assertTransferCount(REL_SUCCESS, 2); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - flowFiles.getFirst().assertContentEquals(firstMessage); - flowFiles.getLast().assertContentEquals(secondMessage); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, "INJECT_METADATA"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty()); + + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + final String content = ff.getContent(); + assertTrue(content.contains("\"kinesisMetadata\""), "Expected injected metadata"); + assertTrue(content.contains("\"stream\":\"" + streamName + "\"")); + assertTrue(content.contains("\"partitionKey\":\"pk-inner\""), 
"Expected inner partition key, not the outer one"); + } + assertEquals(2, totalRecords); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles.getFirst(), flowFiles.getLast()); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(2, counter.longValue()); } @Test - void testRecordProcessingWithSchemaChangesAndInvalidRecords() throws InitializationException { - streamClient.createStream(1); + void testKplAggregatedMixedWithPlainRecords() throws Exception { + final String streamName = "kpl-mixed-test"; + createStream(streamName); - final TestRunner recordRunner = createRecordTestRunner(streamName, applicationName); + publishRecords(streamName, 2); - final List testRecords = List.of( - "{\"name\":\"John\",\"age\":30}", // Schema A - "{\"name\":\"Jane\",\"age\":25}", // Schema A - "{invalid json}", - "{\"id\":\"123\",\"value\":\"test\"}" // Schema B - ); + publishAggregatedRecord(streamName, "agg-pk", + List.of("pk-agg"), + List.of("{\"id\":100,\"name\":\"Aggregated-1\"}", "{\"id\":101,\"name\":\"Aggregated-2\"}", "{\"id\":102,\"name\":\"Aggregated-3\"}"), + List.of(0, 0, 0)); - testRecords.forEach(record -> streamClient.putRecord("key", record)); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - runProcessorWithInitAndWaitForFiles(recordRunner, 3); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(5, flowFiles.size(), "2 plain + 3 deaggregated sub-records"); - // Verify successful records. 
- recordRunner.assertTransferCount(REL_SUCCESS, 2); - final List successFlowFiles = recordRunner.getFlowFilesForRelationship(REL_SUCCESS); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(5, counter.longValue()); + } - final MockFlowFile firstFlowFile = successFlowFiles.getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecordPayloads(firstFlowFile, testRecords.getFirst(), testRecords.get(1)); + @Test + void testKplMultipleAggregatedRecords() throws Exception { + final String streamName = "kpl-multi-agg-test"; + createStream(streamName); + + for (int batch = 0; batch < 3; batch++) { + final List payloads = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + final int id = batch * 5 + i; + payloads.add("{\"id\":" + id + ",\"name\":\"rec-" + id + "\"}"); + } + publishAggregatedRecord(streamName, "agg-pk-" + batch, + List.of("pk-" + batch), payloads, payloads.stream().map(p -> 0).toList()); + } - final MockFlowFile secondFlowFile = successFlowFiles.get(1); - assertEquals("1", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecordPayloads(secondFlowFile, testRecords.getLast()); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - // Verify failure record. 
- recordRunner.assertTransferCount(REL_PARSE_FAILURE, 1); - final List parseFailureFlowFiles = recordRunner.getFlowFilesForRelationship(REL_PARSE_FAILURE); - final MockFlowFile parseFailureFlowFile = parseFailureFlowFiles.getFirst(); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(15, flowFiles.size(), "3 aggregated records x 5 sub-records each"); - parseFailureFlowFile.assertContentEquals(testRecords.get(2)); - assertNotNull(parseFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); + final Set contents = new HashSet<>(); + for (final MockFlowFile ff : flowFiles) { + contents.add(ff.getContent()); + } + for (int i = 0; i < 15; i++) { + assertTrue(contents.contains("{\"id\":" + i + ",\"name\":\"rec-" + i + "\"}"), + "Missing deaggregated record with id=" + i); + } - // Verify provenance events. - assertReceiveProvenanceEvents(recordRunner.getProvenanceEvents(), firstFlowFile, secondFlowFile, parseFailureFlowFile); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(15, counter.longValue()); } @Test - void testRecordProcessingWithDemarcator() throws InitializationException { - streamClient.createStream(1); + void testAvroSameSchemaProducesSingleFlowFile() throws Exception { + final String streamName = "avro-same-schema-test"; + createStream(streamName); - final TestRunner demarcatorTestRunner = createDemarcatorTestRunner(streamName, applicationName, System.lineSeparator()); + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); - final List testRecords = List.of( - "{\"name\":\"John\",\"age\":30}", // Schema A - "{\"name\":\"Jane\",\"age\":25}", // Schema A - "{invalid json}", - "{\"id\":\"123\",\"value\":\"test\"}" // Schema B - ); + for (int i = 0; i < 4; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "record-" + i); + publishAvroRecord(streamName, "pk-" + i, schemaA, rec); + } - 
testRecords.forEach(record -> streamClient.putRecord("key", record)); + configureAvroRecordOriented(streamName); + runUntilOutput(runner); - runProcessorWithInitAndWaitForFiles(demarcatorTestRunner, 1); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); - // All records from the same shard are put as is into the same FlowFile. - demarcatorTestRunner.assertTransferCount(REL_SUCCESS, 1); - final List successFlowFiles = demarcatorTestRunner.getFlowFilesForRelationship(REL_SUCCESS); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(4, totalRecords, "All 4 same-schema Avro records should be consumed"); + } - final MockFlowFile flowFile = successFlowFiles.getFirst(); - assertEquals("4", flowFile.getAttribute(RECORD_COUNT)); - flowFile.assertContentEquals(String.join(System.lineSeparator(), testRecords)); + @Test + void testAvroDifferentSchemasSplitFlowFiles() throws Exception { + final String streamName = "avro-diff-schema-test"; + createStream(streamName); + + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); + + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "a-" + i); + publishAvroRecord(streamName, "pk-a-" + i, schemaA, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaB); + rec.put("code", "code-" + i); + rec.put("value", i * 1.5); + publishAvroRecord(streamName, "pk-b-" + i, schemaB, rec); + } - // Verify provenance events. 
- assertReceiveProvenanceEvents(demarcatorTestRunner.getProvenanceEvents(), flowFile); - } + configureAvroRecordOriented(streamName); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(2, flowFiles.size(), "A,A,B,B should produce 2 FlowFiles"); - private static void assertReceiveProvenanceEvents(final List actualEvents, final FlowFile... expectedFlowFiles) { - assertReceiveProvenanceEvents(actualEvents, List.of(expectedFlowFiles)); + assertEquals("2", flowFiles.get(0).getAttribute("record.count")); + assertEquals("2", flowFiles.get(1).getAttribute("record.count")); } - private static void assertReceiveProvenanceEvents(final List actualEvents, final Collection expectedFlowFiles) { - assertEquals(expectedFlowFiles.size(), actualEvents.size(), "Each produced FlowFile must have a provenance event"); + @Test + void testAvroInterleavedSchemasSplitFlowFiles() throws Exception { + final String streamName = "avro-interleaved-test"; + createStream(streamName); + + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); + + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "a1-" + i); + publishAvroRecord(streamName, "pk-a1-" + i, schemaA, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaB); + rec.put("code", "code-" + i); + rec.put("value", i * 2.0); + publishAvroRecord(streamName, "pk-b-" + i, schemaB, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", 100 + i); + rec.put("name", "a2-" + i); + publishAvroRecord(streamName, "pk-a2-" + i, schemaA, rec); + } - assertAll( - actualEvents.stream().map(event -> () -> - assertEquals(ProvenanceEventType.RECEIVE, event.getEventType(), "Unexpected Provenance Event Type")) - ); + 
configureAvroRecordOriented(streamName); + runUntilOutput(runner); - final Set eventFlowFileUuids = actualEvents.stream() - .map(ProvenanceEventRecord::getFlowFileUuid) - .collect(toSet()); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(3, flowFiles.size(), "A,A,B,B,A,A should produce 3 FlowFiles (no demux)"); - assertAll( - expectedFlowFiles.stream() - .map(flowFile -> flowFile.getAttribute(CoreAttributes.UUID.key())) - .map(uuid -> () -> - assertTrue(eventFlowFileUuids.contains(uuid), "Expected Provenance Event for FlowFile UUID: %s was not present".formatted(uuid))) - ); + assertEquals("2", flowFiles.get(0).getAttribute("record.count")); + assertEquals("2", flowFiles.get(1).getAttribute("record.count")); + assertEquals("2", flowFiles.get(2).getAttribute("record.count")); } - private TestRunner createTestRunner(final String streamName, final String applicationName) throws InitializationException { - final TestRunner runner = TestRunners.newTestRunner(TestConsumeKinesis.class); + @Test + void testAvroMixedWithCorruptDataRoutesToParseFailure() throws Exception { + final String streamName = "avro-corrupt-mix-test"; + createStream(streamName); - final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); - runner.addControllerService("credentials", credentialsService); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, localstack.getAccessKey()); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, localstack.getSecretKey()); - runner.enableControllerService(credentialsService); + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); - runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "credentials"); - runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); - 
runner.setProperty(ConsumeKinesis.APPLICATION_NAME, applicationName); - runner.setProperty(RegionUtil.REGION, localstack.getRegion()); - runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, ConsumeKinesis.InitialPosition.TRIM_HORIZON); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.FLOW_FILE); + final GenericRecord recA = new GenericData.Record(schemaA); + recA.put("id", 1); + recA.put("name", "valid-a"); + publishAvroRecord(streamName, "pk-a", schemaA, recA); - runner.setProperty(ConsumeKinesis.METRICS_PUBLISHING, ConsumeKinesis.MetricsPublishing.CLOUDWATCH); + publishCorruptRecord(streamName, "pk-corrupt", "THIS_IS_NOT_AVRO_DATA"); - runner.setProperty(ConsumeKinesis.MAX_BYTES_TO_BUFFER, "10 MB"); + final GenericRecord recB = new GenericData.Record(schemaB); + recB.put("code", "valid-b"); + recB.put("value", 3.14); + publishAvroRecord(streamName, "pk-b", schemaB, recB); - runner.assertValid(); - return runner; - } + configureAvroRecordOriented(streamName); + runUntilOutput(runner); + + final List successFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); - private TestRunner createRecordTestRunner(final String streamName, final String applicationName) throws InitializationException { - final TestRunner runner = createTestRunner(streamName, applicationName); + long totalSuccessRecords = 0; + for (final MockFlowFile ff : successFlowFiles) { + totalSuccessRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(2, totalSuccessRecords, "Both valid Avro records should be in success"); + assertEquals(1, failureFlowFiles.size(), "Corrupt record should be routed to parse failure"); + failureFlowFiles.getFirst().assertContentEquals("THIS_IS_NOT_AVRO_DATA"); + } - final JsonTreeReader jsonReader = new JsonTreeReader(); - runner.addControllerService("json-reader", jsonReader); - 
runner.setProperty(jsonReader, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaInferenceUtil.INFER_SCHEMA.getValue()); - runner.enableControllerService(jsonReader); + @Test + void testParseFailureWithSomeValidJsonRecords() throws Exception { + final String streamName = "json-parse-failure-test"; + createStream(streamName); - final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); - runner.addControllerService("json-writer", jsonWriter); - runner.setProperty(jsonWriter, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaAccessUtils.INHERIT_RECORD_SCHEMA.getValue()); - runner.enableControllerService(jsonWriter); + publishRecord(streamName, 0); + publishCorruptRecord(streamName, "pk-bad-1", "CORRUPT_DATA_1"); + publishRecord(streamName, 1); + publishCorruptRecord(streamName, "pk-bad-2", "CORRUPT_DATA_2"); + publishRecord(streamName, 2); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.RECORD); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertValid(); - return runner; + final List successFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + + long totalSuccessRecords = 0; + for (final MockFlowFile ff : successFlowFiles) { + totalSuccessRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(3, totalSuccessRecords, "All valid JSON records should route to success"); + assertEquals(2, failureFlowFiles.size(), "Both corrupt records should route to parse failure"); } - private TestRunner createDemarcatorTestRunner(final String streamName, final String applicationName, final String demarcator) throws InitializationException { - final 
TestRunner runner = createTestRunner(streamName, applicationName); + @Test + void testAllCorruptRecordsRouteToParseFailure() throws Exception { + final String streamName = "all-corrupt-test"; + createStream(streamName); + + publishCorruptRecord(streamName, "pk-1", "NOT_VALID_DATA_1"); + publishCorruptRecord(streamName, "pk-2", "NOT_VALID_DATA_2"); + publishCorruptRecord(streamName, "pk-3", "NOT_VALID_DATA_3"); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.DEMARCATOR); - runner.setProperty(ConsumeKinesis.MESSAGE_DEMARCATOR, demarcator); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertValid(); - return runner; + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + assertEquals(3, failureFlowFiles.size(), "All corrupt records should route to parse failure"); } - private void runProcessorWithInitAndWaitForFiles(final TestRunner runner, final int expectedFlowFileCount) { - runProcessorAndWaitForFiles(runner, expectedFlowFileCount, true); + private void addCredentialService(final TestRunner testRunner, final String serviceId) throws Exception { + final AWSCredentialsProviderControllerService credsSvc = new AWSCredentialsProviderControllerService(); + testRunner.addControllerService(serviceId, credsSvc); + testRunner.setProperty(credsSvc, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, LOCALSTACK.getAccessKey()); + testRunner.setProperty(credsSvc, AWSCredentialsProviderControllerService.SECRET_KEY, LOCALSTACK.getSecretKey()); + testRunner.enableControllerService(credsSvc); } - private void runProcessorAndWaitForFiles(final TestRunner runner, final int 
expectedFlowFileCount) { - runProcessorAndWaitForFiles(runner, expectedFlowFileCount, false); - } + private TestRunner createConfiguredRunner(final String streamName, final String appName) throws Exception { + final TestRunner configuredRunner = TestRunners.newTestRunner(FastTimingConsumeKinesis.class); + + final String credsId = "creds-" + (credentialServiceCounter++); + addCredentialService(configuredRunner, credsId); - private void runProcessorAndWaitForFiles(final TestRunner runner, final int expectedFlowFileCount, final boolean withInit) { - logger.info("Running processor and waiting for {} files", expectedFlowFileCount); + configuredRunner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + configuredRunner.setProperty(ConsumeKinesis.APPLICATION_NAME, appName); + configuredRunner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, credsId); + configuredRunner.setProperty(RegionUtil.REGION, LOCALSTACK.getRegion()); + configuredRunner.setProperty(ConsumeKinesis.ENDPOINT_OVERRIDE, LOCALSTACK.getEndpoint().toString()); + configuredRunner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + configuredRunner.setProperty(ConsumeKinesis.MAX_BATCH_DURATION, "200 ms"); - if (withInit) { - runner.run(1, false, true); + return configuredRunner; + } + + /** + * Runs the processor with retries until at least one FlowFile appears on any output relationship + * (success or parse failure), or a 10-second deadline is reached. This guards against + * timing-sensitive tests on slow systems where LocalStack may not propagate records within a + * single 200ms batch window. 
+ */ + private void runUntilOutput(final TestRunner testRunner) throws InterruptedException { + final long deadline = System.currentTimeMillis() + 10_000; + testRunner.run(1, false, true); + while (!hasOutput(testRunner) && System.currentTimeMillis() < deadline) { + Thread.sleep(200); + testRunner.run(1, false, false); } + testRunner.run(1, true, false); + } + + private boolean hasOutput(final TestRunner testRunner) { + return !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty() + || !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE).isEmpty(); + } - final Set relationships = runner.getProcessor().getRelationships(); + private boolean hasFlowFiles(final TestRunner testRunner) { + return !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(); + } - while (true) { - runner.run(1, false, false); + private List getNewFlowFiles(final TestRunner testRunner, final int startIndex) { + final List allFlowFiles = testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + return new ArrayList<>(allFlowFiles.subList(startIndex, allFlowFiles.size())); + } - final int currentCount = relationships.stream() - .map(runner::getFlowFilesForRelationship) - .mapToInt(Collection::size) - .sum(); - logger.info("Current files count: {}, expected: {}", currentCount, expectedFlowFileCount); + private void createStream(final String streamName) { + createStream(streamName, 1); + } - if (currentCount >= expectedFlowFileCount) { + private void createStream(final String streamName, final int shardCount) { + final CreateStreamRequest request = CreateStreamRequest.builder() + .streamName(streamName) + .shardCount(shardCount) + .build(); + kinesisClient.createStream(request); + waitForStreamActive(streamName); + } + + private void waitForStreamActive(final String streamName) { + final DescribeStreamRequest request = DescribeStreamRequest.builder().streamName(streamName).build(); + for (int i = 0; i < 300; i++) { + if 
(kinesisClient.describeStream(request).streamDescription().streamStatus() == StreamStatus.ACTIVE) { return; } try { - Thread.sleep(5_000); + Thread.sleep(100); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for files", e); + throw new RuntimeException(e); } } - } - private MockProcessSession createFailingSession(final Processor processor) { - final SharedSessionState sharedState = new SharedSessionState(processor, new AtomicLong()); - - return MockProcessSession.builder(sharedState, processor) - .failCommit() - .build(); + throw new RuntimeException("Stream " + streamName + " did not become ACTIVE within 30 seconds"); } - public static class TestConsumeKinesis extends ConsumeKinesis { - - @Override - URI getKinesisEndpointOverride() { - return localstack.getEndpoint(); + private void publishAggregatedRecord(final String streamName, final String outerPartitionKey, + final List partitionKeys, final List jsonPayloads, final List pkIndices) { + final Messages.AggregatedRecord.Builder builder = Messages.AggregatedRecord.newBuilder(); + for (final String pk : partitionKeys) { + builder.addPartitionKeyTable(pk); } - - @Override - URI getDynamoDbEndpointOverride() { - return localstack.getEndpoint(); + for (int i = 0; i < jsonPayloads.size(); i++) { + builder.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(pkIndices.get(i)) + .setData(ByteString.copyFromUtf8(jsonPayloads.get(i)))); } - @Override - URI getCloudwatchEndpointOverride() { - return localstack.getEndpoint(); + final byte[] protobufBytes = builder.build().toByteArray(); + final byte[] payload; + try { + final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(KplDeaggregator.KPL_MAGIC); + out.write(protobufBytes); + out.write(md5); + payload = out.toByteArray(); + } catch (final Exception e) { + throw new 
RuntimeException(e); } - } - - /** - * Test client wrapper for Kinesis operations with built-in retry logic and error handling. - */ - private static class TestKinesisStreamClient { - private static final Logger logger = LoggerFactory.getLogger(TestKinesisStreamClient.class); - - private static final int MAX_RETRIES = 10; - private static final long INITIAL_RETRY_DELAY_MILLIS = 1_000; - private static final long MAX_RETRY_DELAY_MILLIS = 60 * 1_000; - private static final Duration STREAM_WAIT_TIMEOUT = Duration.ofMinutes(2); - - private final KinesisClient kinesisClient; - private final String streamName; + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(outerPartitionKey) + .data(SdkBytes.fromByteArray(payload)) + .build()); + } - TestKinesisStreamClient(KinesisClient kinesisClient, String streamName) { - this.kinesisClient = kinesisClient; - this.streamName = streamName; + private void publishRecords(final String streamName, final int count) { + for (int i = 0; i < count; i++) { + publishRecord(streamName, i); } + } - void createStream(final int shardCount) { - logger.info("Creating stream: {} with {} shards", streamName, shardCount); + private void publishRecord(final String streamName, final int recordId) { + final String json = """ + {"id": %d, "name": "record-%d"}""".formatted(recordId, recordId); + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey("key-" + recordId) + .data(SdkBytes.fromUtf8String(json)) + .build()); + } - executeWithRetry( - "createStream", - () -> kinesisClient.createStream(req -> req.streamName(streamName).shardCount(shardCount))); + private void publishCorruptRecord(final String streamName, final String partitionKey, final String corruptData) { + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(partitionKey) + .data(SdkBytes.fromUtf8String(corruptData)) + .build()); + } - waitForStreamActive(); - 
logger.info("Stream {} is now active", streamName); + private static byte[] serializeAvroContainer(final Schema schema, final GenericRecord... records) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + final GenericDatumWriter datumWriter = new GenericDatumWriter<>(schema); + try (final DataFileWriter fileWriter = new DataFileWriter<>(datumWriter)) { + fileWriter.create(schema, baos); + for (final GenericRecord record : records) { + fileWriter.append(record); + } + } + return baos.toByteArray(); + } catch (final IOException e) { + throw new RuntimeException(e); } + } - List getEnhancedFanOutConsumerNames() { - final String arn = describeStream().streamARN(); + private void publishAvroRecord(final String streamName, final String partitionKey, final Schema schema, + final GenericRecord record) { + final byte[] avroBytes = serializeAvroContainer(schema, record); + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(partitionKey) + .data(SdkBytes.fromByteArray(avroBytes)) + .build()); + } - final ListStreamConsumersResponse response = executeWithRetry( - "listStreamConsumers", - () -> kinesisClient.listStreamConsumers(req -> req.streamARN(arn)) - ); + private void configureAvroRecordOriented(final String streamName) throws Exception { + final AvroReader avroReader = new AvroReader(); + runner.addControllerService("avro-reader", avroReader); + runner.enableControllerService(avroReader); - return response.consumers().stream() - .map(Consumer::consumerName) - .toList(); - } + final AvroRecordSetWriter avroWriter = new AvroRecordSetWriter(); + runner.addControllerService("avro-writer", avroWriter); + runner.enableControllerService(avroWriter); - private StreamDescription describeStream() { - final DescribeStreamResponse response = executeWithRetry( - "describeStream", - () -> kinesisClient.describeStream(req -> req.streamName(streamName)) - ); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + 
runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "avro-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "avro-writer"); + } - return response.streamDescription(); + static class FailingRecordSetWriterFactory extends AbstractControllerService implements RecordSetWriterFactory { + @Override + public RecordSchema getSchema(final Map variables, final RecordSchema readSchema) { + return readSchema; } - void deleteStream() { - logger.info("Deleting stream: {}", streamName); - - executeWithRetry( - "deleteStream", - () -> kinesisClient.deleteStream(req -> req.streamName(streamName).enforceConsumerDeletion(true))); + @Override + public RecordSetWriter createWriter(final ComponentLog logger, final RecordSchema schema, + final OutputStream out, final Map variables) { + return new FailingRecordSetWriter(); } + } - void putRecord(final String partitionKey, final String data) { - final SdkBytes bytes = SdkBytes.fromString(data, UTF_8); - - executeWithRetry( - "putRecord", - () -> kinesisClient.putRecord(req -> req.streamName(streamName).partitionKey(partitionKey).data(bytes))); + private static class FailingRecordSetWriter implements RecordSetWriter { + @Override + public WriteResult write(final Record record) throws IOException { + throw new IOException("Simulated write failure for rollback testing"); } - void putRecords(final String partitionKey, final List data) { - final List records = data.stream() - .map(it -> PutRecordsRequestEntry.builder() - .data(SdkBytes.fromString(it, UTF_8)) - .partitionKey(partitionKey) - .build()) - .toList(); - - executeWithRetry( - "putRecords", - () -> kinesisClient.putRecords(req -> req.streamName(streamName).records(records))); + @Override + public WriteResult write(final RecordSet recordSet) throws IOException { + throw new IOException("Simulated write failure for rollback testing"); } - /** - * Adjusts a number of shards for the stream. 
- * Note: in order to ensure new shards become active, the method waits for 30 seconds. - */ - void reshardStream(final int targetShardCount) { - logger.info("Resharding stream {} to {} shards", streamName, targetShardCount); + @Override + public void beginRecordSet() { + } - executeWithRetry( - "reshardStream", - () -> kinesisClient.updateShardCount(req -> req.streamName(streamName).targetShardCount(targetShardCount).scalingType(ScalingType.UNIFORM_SCALING))); + @Override + public WriteResult finishRecordSet() { + return WriteResult.of(0, Map.of()); + } - waitForStreamActive(); + @Override + public String getMimeType() { + return "application/json"; + } - try { - // After resharding new messages can still be put into the older shards for some time, so we wait a bit. - Thread.sleep(30_000); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } + @Override + public void flush() { + } - logger.info("Stream {} resharding completed", streamName); - } - - private void waitForStreamActive() { - final long timeoutMillis = System.currentTimeMillis() + STREAM_WAIT_TIMEOUT.toMillis(); - - while (System.currentTimeMillis() < timeoutMillis) { - try { - final StreamStatus status = describeStream().streamStatus(); - if (status == StreamStatus.ACTIVE) { - return; - } - - logger.info("Stream {} status: {}, waiting...", streamName, status); - Thread.sleep(1000); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for stream to be active", e); - } catch (final RuntimeException e) { - logger.warn("Error checking stream status for {}: {}", streamName, e.getMessage()); - try { - Thread.sleep(1000); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for stream to be active", ie); - } - } - } + @Override + public void close() { + } 
+ } - throw new IllegalStateException("Stream " + streamName + " did not become active within timeout"); - } - - private T executeWithRetry(final String operation, final Callable op) { - Exception lastException = null; - - for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { - try { - return op.call(); - } catch (final Exception e) { - lastException = e; - logger.warn("Attempt {} of {} failed for operation {}: {}", - attempt, MAX_RETRIES, operation, e.getMessage()); - - if (attempt < MAX_RETRIES) { - try { - final long delayMillis = INITIAL_RETRY_DELAY_MILLIS * (1 << (attempt - 1)); - Thread.sleep(Math.min(delayMillis, MAX_RETRY_DELAY_MILLIS)); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted during retry delay", ie); - } - } - } - } + public static class FastTimingConsumeKinesis extends ConsumeKinesis { + private static final long TEST_SHARD_CACHE_MILLIS = 500; + private static final long TEST_LEASE_DURATION_MILLIS = 3_000; + private static final long TEST_LEASE_REFRESH_INTERVAL_MILLIS = 1_000; + private static final long TEST_NODE_HEARTBEAT_EXPIRATION_MILLIS = 4_000; - throw new IllegalStateException("Operation " + operation + " failed after " + MAX_RETRIES + " attempts", lastException); + @Override + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return new KinesisShardManager( + kinesisClient, + dynamoDbClient, + logger, + checkpointTableName, + streamName, + TEST_SHARD_CACHE_MILLIS, + TEST_LEASE_DURATION_MILLIS, + TEST_LEASE_REFRESH_INTERVAL_MILLIS, + TEST_NODE_HEARTBEAT_EXPIRATION_MILLIS); } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java index 1024896b2c9c..478648e4b51b 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java @@ -16,85 +16,512 @@ */ package org.apache.nifi.processors.aws.kinesis; +import org.apache.nifi.json.JsonRecordSetWriter; +import org.apache.nifi.json.JsonTreeReader; +import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderControllerService; -import org.apache.nifi.processors.aws.region.RegionUtil; -import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.PropertyMigrationResult; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Shard; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.PROCESSING_STRATEGY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.DEMARCATOR; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.FLOW_FILE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.RECORD; -import static 
org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_PARSE_FAILURE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_SUCCESS; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; class ConsumeKinesisTest { - private TestRunner testRunner; + private TestRunner runner; @BeforeEach - void setUp() { - testRunner = createTestRunner(); + void setUp() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + final JsonTreeReader reader = new JsonTreeReader(); + runner.addControllerService("json-reader", reader); + runner.enableControllerService(reader); + + final JsonRecordSetWriter writer = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", writer); + runner.enableControllerService(writer); + } + + private void setCommonProperties() throws Exception { + final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); + runner.addControllerService("creds", credentialsService); + runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, "AK_STUB"); + runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, "SK_STUB"); + runner.enableControllerService(credentialsService); + + runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "test-app"); + runner.setProperty(ConsumeKinesis.STREAM_NAME, "test-stream"); + runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "creds"); } @Test - void getRelationshipsForFlowFileProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, FLOW_FILE); + void testProcessingStrategyValidation() throws Exception { + setCommonProperties(); + + 
runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runner.assertValid(); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "LINE_DELIMITED"); + runner.assertValid(); - final Set relationships = testRunner.getProcessor().getRelationships(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.assertNotValid(); - assertEquals(Set.of(REL_SUCCESS), relationships); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.assertNotValid(); + + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.assertValid(); } @Test - void getRelationshipsForRecordProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, RECORD); + void testAllValidRecordsRoutedToSuccess() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", "{\"name\":\"Bob\"}"), + testRecord("3", "{\"name\":\"Charlie\"}")); - final Set relationships = testRunner.getProcessor().getRelationships(); + triggerWithRecords(records); - assertEquals(Set.of(REL_SUCCESS, REL_PARSE_FAILURE), relationships); + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "3"); } @Test - void getRelationshipsForDemarcatorProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, DEMARCATOR); + void testSingleInvalidRecordRoutedToParseFailure() throws Exception { + assertInvalidRecordAtPosition("1", "THIS IS NOT JSON", + testRecord("1", "THIS IS NOT JSON"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "{\"name\":\"Charlie\"}")); + assertInvalidRecordAtPosition("2", "CORRUPT DATA HERE", + testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "CORRUPT DATA HERE"), testRecord("3", "{\"name\":\"Charlie\"}")); + 
assertInvalidRecordAtPosition("3", "NOT VALID JSON!!!", + testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "NOT VALID JSON!!!")); + } + + @Test + void testMultipleInvalidRecordsInBatch() throws Exception { + final List records = List.of( + testRecord("1", "BAD FIRST"), + testRecord("2", "{\"name\":\"Bob\"}"), + testRecord("3", "BAD THIRD"), + testRecord("4", "{\"name\":\"Dave\"}"), + testRecord("5", "BAD FIFTH")); + + triggerWithRecords(records); - final Set relationships = testRunner.getProcessor().getRelationships(); + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 3); - assertEquals(Set.of(REL_SUCCESS), relationships); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + + final List failures = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + final List failureSequences = new ArrayList<>(); + for (final MockFlowFile flowFile : failures) { + failureSequences.add(flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + } + assertTrue(failureSequences.contains("1")); + assertTrue(failureSequences.contains("3")); + assertTrue(failureSequences.contains("5")); } - private static TestRunner createTestRunner() { - final TestRunner runner = TestRunners.newTestRunner(ConsumeKinesis.class); + @Test + void testAllInvalidRecordsRoutedToParseFailure() throws Exception { + final List records = List.of( + testRecord("1", "BAD1"), + testRecord("2", "BAD2"), + testRecord("3", "BAD3")); - final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); - try { - runner.addControllerService("credentials", credentialsService); - } catch (final InitializationException e) { - throw new RuntimeException(e); + triggerWithRecords(records); + + 
runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 3); + } + + @Test + void testFlowFilePerRecordDeliversAllRecords() throws Exception { + final List records = List.of( + testRecord("1", "record-one"), + testRecord("2", "record-two"), + testRecord("3", "record-three")); + + triggerWithStrategy(records, "FLOW_FILE", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 3); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + for (final MockFlowFile flowFile : flowFiles) { + flowFile.assertAttributeEquals("record.count", "1"); + flowFile.assertAttributeEquals(ConsumeKinesis.ATTR_STREAM_NAME, "test-stream"); + flowFile.assertAttributeEquals(ConsumeKinesis.ATTR_SHARD_ID, "shardId-000000000001"); + final String firstSequence = flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE); + final String lastSequence = flowFile.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE); + assertEquals(firstSequence, lastSequence); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_PARTITION_KEY)); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SUBSEQUENCE)); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_LAST_SUBSEQUENCE)); } - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, "123"); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, "123"); - runner.enableControllerService(credentialsService); - runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "credentials"); - runner.setProperty(ConsumeKinesis.STREAM_NAME, "stream"); - runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "application"); - runner.setProperty(RegionUtil.REGION, "us-west-2"); - runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, ConsumeKinesis.InitialPosition.TRIM_HORIZON); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, 
ConsumeKinesis.ProcessingStrategy.FLOW_FILE); + flowFiles.get(0).assertContentEquals("record-one"); + flowFiles.get(1).assertContentEquals("record-two"); + flowFiles.get(2).assertContentEquals("record-three"); + } + + @Test + void testDemarcatorDeliversAllRecords() throws Exception { + final List records = List.of( + testRecord("1", "line-one"), + testRecord("2", "line-two"), + testRecord("3", "line-three")); + + triggerWithStrategy(records, "LINE_DELIMITED", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertContentEquals("line-one\nline-two\nline-three"); + success.assertAttributeEquals("record.count", "3"); + } + + @Test + void testMultipleShardsNoDataLoss() throws Exception { + final ShardFetchResult shard1Result = new ShardFetchResult("shard-A", + List.of(testRecord("10", "{\"id\":1}"), testRecord("20", "{\"id\":2}")), 0L); + final ShardFetchResult shard2Result = new ShardFetchResult("shard-B", + List.of(testRecord("30", "{\"id\":3}"), testRecord("40", "{\"id\":4}")), 0L); + + triggerWithResults(List.of(shard1Result, shard2Result), "RECORD"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 2); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final Set shardsSeen = new LinkedHashSet<>(); + long totalRecords = 0; + for (final MockFlowFile flowFile : flowFiles) { + shardsSeen.add(flowFile.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + totalRecords += Long.parseLong(flowFile.getAttribute("record.count")); + } + assertEquals(Set.of("shard-A", "shard-B"), shardsSeen); + assertEquals(4, totalRecords); + } + + @Test + void testRecordMetadataInjectionPreservesRecordCount() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", 
"{\"name\":\"Bob\"}"), + testRecord("3", "{\"name\":\"Charlie\"}")); + + triggerWithOutputStrategy(records, "INJECT_METADATA"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "3"); + + final String content = success.getContent(); + assertTrue(content.contains("kinesisMetadata")); + assertTrue(content.contains("\"stream\"")); + assertTrue(content.contains("\"shardId\"")); + assertTrue(content.contains("\"sequenceNumber\"")); + assertTrue(content.contains("\"partitionKey\"")); + } + + @Test + void testUseWrapperOutputStrategy() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", "{\"name\":\"Bob\"}")); + + triggerWithOutputStrategy(records, "USE_WRAPPER"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + + final String content = success.getContent(); + assertTrue(content.contains("kinesisMetadata")); + assertTrue(content.contains("value")); + assertTrue(content.contains("Alice")); + } + + @Test + void testAtTimestampInitialPositionRequiresTimestamp() throws Exception { + setCommonProperties(); + runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, "AT_TIMESTAMP"); + runner.assertNotValid(); + + runner.setProperty(ConsumeKinesis.STREAM_POSITION_TIMESTAMP, "2025-01-15T00:00:00Z"); + runner.assertValid(); + } + + @Test + void testPropertyMigrationRenamesMaxBytesToBuffer() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + setCommonProperties(); + runner.setProperty("Max Bytes to Buffer", "5 MB"); + 
+ final PropertyMigrationResult result = runner.migrateProperties(); + assertTrue(result.getPropertiesRenamed().containsKey("Max Bytes to Buffer")); + assertEquals("Max Batch Size", result.getPropertiesRenamed().get("Max Bytes to Buffer")); + assertEquals("5 MB", runner.getProcessContext().getProperty(ConsumeKinesis.MAX_BATCH_SIZE).getValue()); + } + + @Test + void testPropertyMigrationRemovesCheckpointInterval() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + setCommonProperties(); + runner.setProperty("Checkpoint Interval", "5 min"); + + final PropertyMigrationResult result = runner.migrateProperties(); + assertTrue(result.getPropertiesRemoved().contains("Checkpoint Interval")); + } + + @Test + void testDynamicRelationships() throws Exception { + setCommonProperties(); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + assertEquals(Set.of("success"), collectRelationshipNames()); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + assertEquals(Set.of("success", "parse.failure"), collectRelationshipNames()); + } + + @Test + void testEmptyRecordDoesNotCauseStuckState() throws Exception { + final DeaggregatedRecord emptyRecord = new DeaggregatedRecord("shardId-000000000001", "2", 0, "pk-2", new byte[0], Instant.now()); + + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + emptyRecord, + testRecord("3", "{\"name\":\"Charlie\"}")); + + triggerWithRecords(records); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); - runner.setProperty(ConsumeKinesis.METRICS_PUBLISHING, ConsumeKinesis.MetricsPublishing.CLOUDWATCH); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + } + + private void triggerWithRecords(final List 
records) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); + final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); - runner.setProperty(ConsumeKinesis.MAX_BYTES_TO_BUFFER, "10 MB"); + runner.run(); + } + + private Set collectRelationshipNames() { + final Set names = new LinkedHashSet<>(); + for (final Relationship relationship : runner.getProcessor().getRelationships()) { + names.add(relationship.getName()); + } + return names; + } + + private void assertInvalidRecordAtPosition(final String expectedFailureSequence, final String expectedFailureContent, + final DeaggregatedRecord... 
records) throws Exception { + triggerWithRecords(List.of(records)); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 1); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + + final MockFlowFile failure = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE).getFirst(); + failure.assertContentEquals(expectedFailureContent); + failure.assertAttributeEquals(ConsumeKinesis.ATTR_FIRST_SEQUENCE, expectedFailureSequence); + } + + private void triggerWithOutputStrategy(final List records, final String outputStrategy) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); + final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, outputStrategy); + + runner.run(); + } + + private void triggerWithStrategy(final List records, final String processingStrategy, + final String shardId) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager(shardId); + final ShardFetchResult fetchResult = new 
ShardFetchResult(shardId, records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, processingStrategy); + + runner.run(); + } + + private void triggerWithResults(final List results, final String processingStrategy) throws Exception { + final Set shardIds = new LinkedHashSet<>(); + for (final ShardFetchResult fetchResult : results) { + shardIds.add(fetchResult.shardId()); + } - return runner; + final KinesisShardManager mockShardManager = buildShardManager(shardIds.toArray(new String[0])); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, results); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, processingStrategy); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + + runner.run(); + } + + private static KinesisShardManager buildShardManager(final String... 
shardIds) { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + final List<Shard> shards = new ArrayList<>(); + for (final String id : shardIds) { + shards.add(Shard.builder().shardId(id).build()); + } + when(mockShardManager.getOwnedShards()).thenReturn(shards); + when(mockShardManager.getCachedShardCount()).thenReturn(shardIds.length); + when(mockShardManager.shouldProcessFetchedResult(anyString())).thenReturn(true); + return mockShardManager; + } + + private static DeaggregatedRecord testRecord(final String sequenceNumber, final String data) { + return new DeaggregatedRecord( + "shardId-000000000001", + sequenceNumber, + 0, + "pk-" + sequenceNumber, + data.getBytes(StandardCharsets.UTF_8), + Instant.now()); + } + + static class TestableConsumeKinesis extends ConsumeKinesis { + private final KinesisShardManager mockShardManager; + private final List<ShardFetchResult> preloadedResults; + + TestableConsumeKinesis(final KinesisShardManager mockShardManager, final ShardFetchResult preloadedResult) { + this(mockShardManager, List.of(preloadedResult)); + } + + TestableConsumeKinesis(final KinesisShardManager mockShardManager, final List<ShardFetchResult> preloadedResults) { + this.mockShardManager = mockShardManager; + this.preloadedResults = preloadedResults; + } + + @Override + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return mockShardManager; + } + + @Override + protected KinesisConsumerClient createConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger, + final boolean efoMode) { + final KinesisConsumerClient client = new StubConsumerClient(mock(KinesisClient.class), logger); + for (final ShardFetchResult result : preloadedResults) { + client.enqueueResult(result); + } + return client; + } + } + + static class StubConsumerClient extends KinesisConsumerClient { + StubConsumerClient(final 
KinesisClient kinesisClient, final ComponentLog logger) { + super(kinesisClient, logger); + } + + @Override + void startFetches(final List<Shard> shards, final String streamName, final int batchSize, + final String initialStreamPosition, final KinesisShardManager shardManager) { + } + + @Override + boolean hasPendingFetches() { + return hasQueuedResults(); + } + + @Override + void acknowledgeResults(final List<ShardFetchResult> results) { + } + + @Override + void rollbackResults(final List<ShardFetchResult> results) { + } + + @Override + void removeUnownedShards(final Set<String> ownedShards) { + } + + @Override + void logDiagnostics(final int ownedCount, final int cachedShardCount) { + } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java deleted file mode 100644 index e0f3a40e47ee..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.util.MockFlowFile; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - -final class JsonRecordAssert { - - private static final ObjectMapper MAPPER = new ObjectMapper(); - - static void assertFlowFileRecords(final FlowFile flowFile, final KinesisClientRecord... expectedRecords) { - assertFlowFileRecords(flowFile, List.of(expectedRecords)); - } - - static void assertFlowFileRecords(final FlowFile flowFile, final List expectedRecords) { - final List expectedPayloads = expectedRecords.stream() - .map(KinesisRecordPayload::extract) - .toList(); - assertFlowFileRecordPayloads(flowFile, expectedPayloads); - } - - static void assertFlowFileRecordPayloads(final FlowFile flowFile, final String... 
expectedPayloads) { - assertFlowFileRecordPayloads(flowFile, List.of(expectedPayloads)); - } - - static void assertFlowFileRecordPayloads(final FlowFile flowFile, final List expectedPayloads) { - try { - final MockFlowFile mockFlowFile = assertInstanceOf(MockFlowFile.class, flowFile, "A passed FlowFile should be an instance of MockFlowFile"); - - final JsonNode node = MAPPER.readTree(mockFlowFile.getContent()); - final ArrayNode array = assertInstanceOf(ArrayNode.class, node, "FlowFile content is expected to be an array"); - - assertEquals(expectedPayloads.size(), array.size(), "Array size mismatch"); - - for (int i = 0; i < expectedPayloads.size(); i++) { - final JsonNode recordNode = array.get(i); - final JsonNode expectedNode = MAPPER.readTree(expectedPayloads.get(i)); - assertEquals(recordNode, expectedNode, "Record at index " + i + " does not match expected JSON"); - } - - } catch (final JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - private JsonRecordAssert() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java new file mode 100644 index 000000000000..0c48f0099d02 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java @@ -0,0 +1,600 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static 
org.mockito.Mockito.when; + +class KinesisConsumerClientTest { + + /** + * Verifies that when an EFO subscription expires and is renewed, the renewal uses the + * maximum of lastAcknowledged, lastQueued, and DynamoDB checkpoint. Here the acknowledged + * sequence (99999) exceeds the checkpoint (11111), so the renewal should use 99999. + */ + @Test + void testSubscriptionRenewalUsesLastAcknowledgedSequenceNumber() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("11111"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(1, capturedRequests.size()); + assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, capturedRequests.get(0).startingPosition().type()); + assertEquals("11111", capturedRequests.get(0).startingPosition().sequenceNumber(), + "Initial subscription should use the DynamoDB checkpoint"); + + simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", "99999"); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, capturedRequests.get(1).startingPosition().type()); + assertEquals("99999", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastAcknowledged, checkpoint) = lastAcknowledged"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that when a ShardConsumer has no acknowledged data, the renewal falls back to + * the DynamoDB checkpoint. 
+ */ + @Test + void testSubscriptionRenewalFallsBackToCheckpointWhenNoQueuedData() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("55555"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", null); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("55555", capturedRequests.get(0).startingPosition().sequenceNumber()); + assertEquals("55555", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal with no acknowledged data should fall back to DynamoDB checkpoint"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that acknowledged sequence tracking is monotonic. If acknowledgements are observed + * out-of-order for a shard, renewal must use the highest acknowledged sequence. 
+ */ + @Test + void testSubscriptionRenewalUsesHighestAcknowledgedSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("10000"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + client.acknowledgeResults(List.of( + shardFetchResult("shardId-000000000001", "20000"), + shardFetchResult("shardId-000000000001", "15000"))); + + simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", null); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("20000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use the highest acknowledged sequence"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that renewal always uses the maximum of lastQueued, lastAcknowledged, and + * the DynamoDB checkpoint. This prevents one-event replay duplicates caused by races + * between onNext counter updates and concurrent handler onError callbacks. 
+ */ + @Test + void testSubscriptionRenewalAlwaysUsesMaxSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000", "70000", 2); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=90000, lastAcked=70000, checkpoint=50000) = 90000"); + + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "95000", "80000", 0); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(3, capturedRequests.size()); + assertEquals("95000", capturedRequests.get(2).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=95000, lastAcked=80000, checkpoint=50000) = 95000"); + + verify(mockShardManager, times(3)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that polling a queued result does not affect the renewal position. The renewal + * uses max(lastQueued, lastAcknowledged, checkpoint) regardless of whether results have + * been polled or acknowledged. 
+ */ + @Test + void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000", "70000", 1); + client.enqueueResult(shardFetchResult("shardId-000000000001", "90000")); + + final ShardFetchResult polled = client.pollShardResult("shardId-000000000001"); + assertNotNull(polled, "Expected queued result to be polled"); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=90000, lastAcked=70000, checkpoint=50000) = 90000"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that acknowledging multiple fetched results from the same shard requests only one + * additional EFO event for that shard. 
+ */ + @Test + void testAcknowledgeResultsRequestsNextOncePerShard() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EfoKinesisClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + final Object shardConsumer = getShardConsumer(client, "shardId-000000000001"); + final Subscription subscription = mock(Subscription.class); + setField(shardConsumer, "subscription", subscription); + setField(shardConsumer, "queuedResultCount", new AtomicInteger(2)); + + client.acknowledgeResults(List.of( + shardFetchResult("shardId-000000000001", "60000"), + shardFetchResult("shardId-000000000001", "61000"))); + + verify(subscription, times(1)).request(1); + } + + /** + * Verifies that concurrent startup calls do not create duplicate initial subscriptions + * for the same shard. 
+ */ + @Test + void testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn(null); + + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + final List capturedRequests = new ArrayList<>(); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenAnswer(invocation -> { + capturedRequests.add(invocation.getArgument(0)); + Thread.sleep(50L); + return new CompletableFuture<>(); + }); + + final EfoKinesisClient client = new EfoKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + setField(client, "kinesisAsyncClient", mockAsyncClient); + setField(client, "consumerArn", "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + final CountDownLatch startLatch = new CountDownLatch(1); + + final Thread t1 = new Thread(() -> runStartFetches(client, shards, mockShardManager, startLatch)); + final Thread t2 = new Thread(() -> runStartFetches(client, shards, mockShardManager, startLatch)); + t1.start(); + t2.start(); + startLatch.countDown(); + t1.join(5_000L); + t2.join(5_000L); + + assertEquals(1, capturedRequests.size(), + "Concurrent startup should create only one initial SubscribeToShard request per shard"); + } + + private static EfoKinesisClient createEfoClient(final List capturedRequests) throws Exception { + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenAnswer(invocation -> { + capturedRequests.add(invocation.getArgument(0)); + return CompletableFuture.completedFuture(null); + }); + + final EfoKinesisClient client = 
new EfoKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + setField(client, "kinesisAsyncClient", mockAsyncClient); + setField(client, "consumerArn", "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); + return client; + } + + /** + * After the initial subscription is created with a completed future (immediately expired), + * this method configures the ShardConsumer so it appears ready for renewal: sets + * lastAcknowledgedSequenceNumber, clears the subscribing guard, and resets the backoff timer. + * + * @param client the consumer client whose internal ShardConsumer state is being configured + * @param shardId the shard whose ShardConsumer to modify + * @param lastAckedSeq the sequence number to set, or null to leave unset + */ + @SuppressWarnings("unchecked") + private static void simulateExpiredSubscriptionWithAcknowledgedData( + final EfoKinesisClient client, final String shardId, final String lastAckedSeq) throws Exception { + final Field consumersField = EfoKinesisClient.class.getDeclaredField("shardConsumers"); + consumersField.setAccessible(true); + final Map consumers = (Map) consumersField.get(client); + final Object shardConsumer = consumers.get(shardId); + final Class scClass = shardConsumer.getClass(); + + setField(shardConsumer, scClass, "subscribing", new AtomicBoolean(false)); + setField(shardConsumer, scClass, "lastSubscribeAttemptNanos", 0L); + if (lastAckedSeq != null) { + getAtomicRef(shardConsumer, scClass, "lastAcknowledgedSequenceNumber").set(lastAckedSeq); + } + } + + /** + * Verifies that a stale error callback from an old subscription does not corrupt the state + * of a newer subscription. This tests the generation counter mechanism that prevents a race + * between the response handler's onError and the subscriber's onError when they fire on + * different threads for the same error, with a new subscription created in between. + * + *

The race without the generation counter: + *

    + *
  1. Response handler onError fires on Netty thread: subscribing = false
  2. + *
  3. NiFi thread creates new subscription B: subscribing = true
  4. + *
  5. Old subscriber onError fires: nulls B's subscription, subscribing = false
  6. + *
  7. B's state is corrupted, leading to duplicate subscriptions
  8. + *
+ */ + @Test + void testStaleErrorCallbackDoesNotCorruptNewSubscription() throws Exception { + final ComponentLog mockLogger = mock(ComponentLog.class); + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + final EfoKinesisClient.ShardConsumer consumer = + new EfoKinesisClient.ShardConsumer("shardId-000000000001", result -> { }, mockLogger); + + final StartingPosition pos = StartingPosition.builder() + .type(ShardIteratorType.TRIM_HORIZON) + .build(); + + consumer.subscribe(mockAsyncClient, "test-arn", pos); + final int gen1 = getGeneration(consumer); + assertEquals(1, gen1); + + simulateSubscriberActive(consumer); + + final Subscription gen1Subscription = getSubscription(consumer); + assertNotNull(gen1Subscription, "Subscription should be set after onSubscribe"); + + endSubscriptionIfCurrent(consumer, gen1); + assertFalse(getSubscribing(consumer), "subscribing should be false after endSubscription"); + + consumer.subscribe(mockAsyncClient, "test-arn", pos); + final int gen2 = getGeneration(consumer); + assertEquals(2, gen2); + + simulateSubscriberActive(consumer); + + final Subscription gen2Subscription = getSubscription(consumer); + assertNotNull(gen2Subscription, "New subscription should be set"); + + endSubscriptionIfCurrent(consumer, gen1); + + assertNotNull(getSubscription(consumer), + "Stale callback (gen1) must NOT null out gen2's subscription"); + assertTrue(getSubscribing(consumer), + "Stale callback (gen1) must NOT reset gen2's subscribing flag"); + + endSubscriptionIfCurrent(consumer, gen2); + + assertFalse(getSubscribing(consumer), + "Current-generation callback should clean up normally"); + } + + private static void simulateSubscriberActive(final Object shardConsumer) throws Exception { + final Subscription mockSub = mock(Subscription.class); + 
final Field subField = shardConsumer.getClass().getDeclaredField("subscription"); + subField.setAccessible(true); + subField.set(shardConsumer, mockSub); + } + + private static Subscription getSubscription(final Object shardConsumer) throws Exception { + final Field field = shardConsumer.getClass().getDeclaredField("subscription"); + field.setAccessible(true); + return (Subscription) field.get(shardConsumer); + } + + private static boolean getSubscribing(final Object shardConsumer) throws Exception { + final Field field = shardConsumer.getClass().getDeclaredField("subscribing"); + field.setAccessible(true); + return ((AtomicBoolean) field.get(shardConsumer)).get(); + } + + private static int getGeneration(final Object shardConsumer) throws Exception { + final Field field = shardConsumer.getClass().getDeclaredField("subscriptionGeneration"); + field.setAccessible(true); + return ((AtomicInteger) field.get(shardConsumer)).get(); + } + + private static void endSubscriptionIfCurrent(final Object shardConsumer, final int generation) throws Exception { + final java.lang.reflect.Method method = shardConsumer.getClass().getDeclaredMethod("endSubscriptionIfCurrent", int.class); + method.setAccessible(true); + method.invoke(shardConsumer, generation); + } + + private static void setField(final Object target, final String fieldName, final Object value) throws Exception { + setField(target, target.getClass(), fieldName, value); + } + + private static void setField(final Object target, final Class clazz, final String fieldName, final Object value) throws Exception { + final Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); + } + + @SuppressWarnings("unchecked") + private static AtomicReference getAtomicRef( + final Object target, final Class clazz, final String fieldName) throws Exception { + final Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + return (AtomicReference) field.get(target); + 
} + + @SuppressWarnings("unchecked") + private static void simulateExpiredSubscriptionWithState( + final EfoKinesisClient client, + final String shardId, + final String lastQueuedSeq, + final String lastAckedSeq, + final int queuedCount) throws Exception { + final Field consumersField = EfoKinesisClient.class.getDeclaredField("shardConsumers"); + consumersField.setAccessible(true); + final Map consumers = (Map) consumersField.get(client); + final Object shardConsumer = consumers.get(shardId); + final Class scClass = shardConsumer.getClass(); + + setField(shardConsumer, scClass, "subscribing", new AtomicBoolean(false)); + setField(shardConsumer, scClass, "lastSubscribeAttemptNanos", 0L); + setField(shardConsumer, scClass, "lastQueuedSequenceNumber", lastQueuedSeq); + getAtomicRef(shardConsumer, scClass, "lastAcknowledgedSequenceNumber").set(lastAckedSeq); + setField(shardConsumer, scClass, "queuedResultCount", new AtomicInteger(queuedCount)); + } + + private static ShardFetchResult shardFetchResult(final String shardId, final String sequenceNumber) { + final DeaggregatedRecord record = new DeaggregatedRecord(shardId, sequenceNumber, 0, "pk", "{}".getBytes(), null); + return new ShardFetchResult(shardId, List.of(record), 0L); + } + + @SuppressWarnings("unchecked") + private static Object getShardConsumer(final EfoKinesisClient client, final String shardId) throws Exception { + final Field consumersField = EfoKinesisClient.class.getDeclaredField("shardConsumers"); + consumersField.setAccessible(true); + final Map consumers = (Map) consumersField.get(client); + return consumers.get(shardId); + } + + /** + * Verifies the per-shard exclusive processing lock: claimShard returns false when a shard is + * already claimed, and releaseShards makes it claimable again. Different shards are independent. 
+ */ + @Test + void testShardClaimExclusivity() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + assertTrue(client.claimShard("shard-1"), "First claim should succeed"); + assertFalse(client.claimShard("shard-1"), "Duplicate claim should fail"); + assertTrue(client.claimShard("shard-2"), "Different shard claim should succeed independently"); + + client.releaseShards(List.of("shard-1")); + assertTrue(client.claimShard("shard-1"), "Claim after release should succeed"); + assertFalse(client.claimShard("shard-2"), "shard-2 was not released and should still be claimed"); + } + + /** + * Verifies that enqueueResult stores results in per-shard queues and that + * pollShardResult retrieves them in FIFO order. + */ + @Test + void testEnqueueResultIsPolledFromShardQueue() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + final ShardFetchResult result = shardFetchResult("shard-1", "12345"); + + client.enqueueResult(result); + + final ShardFetchResult polled = client.pollShardResult("shard-1"); + assertNotNull(polled, "Enqueued result should be available for polling from its shard queue"); + assertEquals("shard-1", polled.shardId()); + } + + /** + * Verifies that per-shard queues preserve FIFO order. When multiple results for the + * same shard are enqueued (potentially interleaved with other shards), polling that + * shard's queue must return them in enqueue order. This prevents the out-of-order + * delivery that occurred with the old flat-queue requeue-to-back design. 
+ */ + @Test + void testPerShardOrderingPreservedAcrossEnqueues() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + client.enqueueResult(shardFetchResult("shard-5", "100")); + client.enqueueResult(shardFetchResult("shard-3", "500")); + client.enqueueResult(shardFetchResult("shard-5", "200")); + client.enqueueResult(shardFetchResult("shard-3", "600")); + client.enqueueResult(shardFetchResult("shard-5", "300")); + + assertEquals("100", client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals("200", client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals("300", client.pollShardResult("shard-5").firstSequenceNumber()); + assertNull(client.pollShardResult("shard-5"), "Queue should be empty after draining"); + + assertEquals("500", client.pollShardResult("shard-3").firstSequenceNumber()); + assertEquals("600", client.pollShardResult("shard-3").firstSequenceNumber()); + assertNull(client.pollShardResult("shard-3"), "Queue should be empty after draining"); + } + + /** + * Verifies that getShardIdsWithResults returns only shards with non-empty queues, + * and that totalQueuedResults reflects the actual count across all shard queues. 
+ */ + @Test + void testQueueIntrospectionMethods() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + assertFalse(client.hasQueuedResults()); + assertEquals(0, client.totalQueuedResults()); + assertTrue(client.getShardIdsWithResults().isEmpty()); + + client.enqueueResult(shardFetchResult("shard-1", "10")); + client.enqueueResult(shardFetchResult("shard-2", "20")); + client.enqueueResult(shardFetchResult("shard-1", "30")); + + assertTrue(client.hasQueuedResults()); + assertEquals(3, client.totalQueuedResults()); + assertEquals(Set.of("shard-1", "shard-2"), client.getShardIdsWithResults()); + + client.pollShardResult("shard-1"); + client.pollShardResult("shard-1"); + + assertEquals(1, client.totalQueuedResults()); + assertEquals(Set.of("shard-2"), client.getShardIdsWithResults()); + } + + /** + * Reproduces the scenario from the original out-of-order delivery bug. With a single + * flat queue and requeue-to-back, result B would end up behind C after being requeued, + * causing C to be delivered first. Per-shard queues eliminate this: when Task-2 cannot + * claim shard-1, it simply skips the shard queue. B and C remain in FIFO order for the + * next consumer. 
+ */ + @Test + void testPerShardQueuesPreventOutOfOrderDeliveryAcrossInvocations() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + final ShardFetchResult resultA = shardFetchResult("shard-1", "100"); + final ShardFetchResult resultB = shardFetchResult("shard-1", "200"); + final ShardFetchResult resultC = shardFetchResult("shard-1", "300"); + final ShardFetchResult otherShardResult = shardFetchResult("shard-2", "999"); + + client.enqueueResult(resultA); + client.enqueueResult(resultB); + client.enqueueResult(otherShardResult); + client.enqueueResult(resultC); + + // Task-1: claims shard-1, polls A(100) from its per-shard queue + assertTrue(client.claimShard("shard-1"), "Task-1 should claim shard-1"); + final ShardFetchResult polledA = client.pollShardResult("shard-1"); + assertEquals("100", polledA.firstSequenceNumber()); + + // Task-2: cannot claim shard-1 (held by Task-1), so it skips shard-1 entirely. + // B and C remain at the head of shard-1's queue, undisturbed. 
+ assertFalse(client.claimShard("shard-1"), "Task-2 cannot claim shard-1 (held by Task-1)"); + + // Task-2: claims shard-2 and polls its result + assertTrue(client.claimShard("shard-2")); + final ShardFetchResult polledOther = client.pollShardResult("shard-2"); + assertEquals("999", polledOther.firstSequenceNumber()); + + // Task-1 commits A and releases shard-1 + client.releaseShards(List.of("shard-1")); + + // Next invocation: shard-1's queue still has B(200) then C(300) in correct order + final ShardFetchResult firstPoll = client.pollShardResult("shard-1"); + final ShardFetchResult secondPoll = client.pollShardResult("shard-1"); + + assertNotNull(firstPoll, "Expected shard-1 queue to have B"); + assertNotNull(secondPoll, "Expected shard-1 queue to have C"); + assertEquals("200", firstPoll.firstSequenceNumber(), "First poll must be B(200), not C(300)"); + assertEquals("300", secondPoll.firstSequenceNumber(), "Second poll must be C(300)"); + assertNull(client.pollShardResult("shard-1"), "shard-1 queue should be empty"); + } + + private static void runStartFetches(final KinesisConsumerClient client, final List shards, + final KinesisShardManager shardManager, final CountDownLatch startLatch) { + try { + startLatch.await(); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", shardManager); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java deleted file mode 100644 index 8a03d71d244f..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * 
contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import static java.nio.charset.StandardCharsets.UTF_8; - -final class KinesisRecordPayload { - - static String extract(final KinesisClientRecord record) { - record.data().rewind(); - - final byte[] buffer = new byte[record.data().remaining()]; - record.data().get(buffer); - - return new String(buffer, UTF_8); - } - - private KinesisRecordPayload() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java new file mode 100644 index 000000000000..2fdc6df5c639 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; +import software.amazon.awssdk.services.dynamodb.model.CreateTableResponse; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableResponse; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import 
software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.TableDescription; +import software.amazon.awssdk.services.dynamodb.model.TableStatus; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse; +import software.amazon.awssdk.services.kinesis.KinesisClient; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KinesisShardManagerTest { + + private DynamoDbClient dynamoDb; + private KinesisShardManager manager; + + @BeforeEach + void setUp() { + dynamoDb = mock(DynamoDbClient.class); + manager = new KinesisShardManager(mock(KinesisClient.class), dynamoDb, mock(ComponentLog.class), "test-table", "test-stream"); + } + + /** + * Verifies that writeCheckpoints enforces monotonic checkpoint advancement. + * A later call with a lower sequence number must not overwrite a higher one. 
+ */ + @Test + void testCheckpointMonotonicity() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("30000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("70000", 0))); + + final ArgumentCaptor<UpdateItemRequest> captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(2)).updateItem(captor.capture()); + + final List<UpdateItemRequest> requests = captor.getAllValues(); + assertEquals("50000", requests.get(0).expressionAttributeValues().get(":seq").s()); + assertEquals("70000", requests.get(1).expressionAttributeValues().get(":seq").s(), + "Only increasing checkpoints should be written to DynamoDB"); + } + + /** + * Verifies that checkpoints for different shards are tracked independently. + * A lower checkpoint on shard-1 must not affect shard-2. 
+ */ + @Test + void testCheckpointMonotonicityPerShard() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of( + "shard-1", new ShardCheckpoint("50000", 0), + "shard-2", new ShardCheckpoint("20000", 0))); + manager.writeCheckpoints(Map.of( + "shard-1", new ShardCheckpoint("30000", 0), + "shard-2", new ShardCheckpoint("40000", 0))); + + final ArgumentCaptor<UpdateItemRequest> captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(3)).updateItem(captor.capture()); + + long shard1Writes = 0; + long shard2Writes = 0; + for (final UpdateItemRequest request : captor.getAllValues()) { + if ("shard-1".equals(request.key().get("shardId").s())) { + shard1Writes++; + } else if ("shard-2".equals(request.key().get("shardId").s())) { + shard2Writes++; + } + } + + assertEquals(1, shard1Writes, "shard-1 regression (30000 < 50000) should be skipped"); + assertEquals(2, shard2Writes, "shard-2 advance (40000 > 20000) should be written"); + } + + /** + * Verifies that close() resets the monotonic checkpoint guard, allowing a fresh start. + */ + @Test + void testCloseResetsCheckpointGuard() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); + manager.close(); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("30000", 0))); + + verify(dynamoDb, times(2)).updateItem(any(UpdateItemRequest.class)); + } + + /** + * Verifies that readCheckpoint returns null when no checkpoint item exists in DynamoDB. + * A wrong return value here would cause the processor to either skip data (if it returned + * a stale sequence) or re-process from TRIM_HORIZON unnecessarily. 
+ */ + @Test + void testReadCheckpointReturnsNullForMissingCheckpoint() { + final GetItemResponse emptyResponse = GetItemResponse.builder().item(Map.of()).build(); + when(dynamoDb.getItem(any(GetItemRequest.class))).thenReturn(emptyResponse); + + assertNull(manager.readCheckpoint("shard-1"), + "Missing checkpoint must return null so the processor starts from the initial stream position"); + } + + /** + * Verifies that readCheckpoint ignores a non-numeric sequence number stored in DynamoDB. + * Passing a corrupt value to GetShardIterator would cause an API error and potentially + * force a full re-read from TRIM_HORIZON, creating massive duplication. + */ + @Test + void testReadCheckpointIgnoresInvalidSequenceNumber() { + final GetItemResponse response = GetItemResponse.builder() + .item(Map.of( + "streamName", AttributeValue.builder().s("test-stream").build(), + "shardId", AttributeValue.builder().s("shard-1").build(), + "sequenceNumber", AttributeValue.builder().s("NOT_A_NUMBER").build())) + .build(); + when(dynamoDb.getItem(any(GetItemRequest.class))).thenReturn(response); + + assertNull(manager.readCheckpoint("shard-1"), + "Non-numeric sequence number must be treated as missing to avoid API errors"); + } + + /** + * Verifies that a ConditionalCheckFailedException during checkpoint write (indicating + * lease loss) does not crash the processor and does not corrupt the in-memory monotonic + * guard. A subsequent higher-valued checkpoint must still be attempted. 
+ */ + @Test + void testWriteCheckpointHandlesLostLeaseGracefully() { + final ConditionalCheckFailedException lostLease = ConditionalCheckFailedException.builder().message("lost lease").build(); + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenThrow(lostLease).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("70000", 0))); + + final ArgumentCaptor<UpdateItemRequest> captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(2)).updateItem(captor.capture()); + + assertEquals("50000", captor.getAllValues().get(0).expressionAttributeValues().get(":seq").s()); + assertEquals("70000", captor.getAllValues().get(1).expressionAttributeValues().get(":seq").s(), + "After a lost-lease failure, the next higher checkpoint must still be attempted"); + } + + /** + * Verifies that writeCheckpoints stores the subSequenceNumber alongside the sequenceNumber. + */ + @Test + void testCheckpointStoresSubSequenceNumber() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 7))); + + final ArgumentCaptor<UpdateItemRequest> captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(1)).updateItem(captor.capture()); + + final UpdateItemRequest request = captor.getValue(); + assertEquals("50000", request.expressionAttributeValues().get(":seq").s()); + assertEquals("7", request.expressionAttributeValues().get(":subSeq").n(), + "subSequenceNumber must be persisted in the DynamoDB checkpoint"); + } + + /** + * Verifies that for the same sequence number, a higher sub-sequence number is written + * and a lower one is skipped. 
+ */ + @Test + void testCheckpointMonotonicityWithSubSequenceNumber() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 3))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 1))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 5))); + + final ArgumentCaptor<UpdateItemRequest> captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(2)).updateItem(captor.capture()); + + final List<UpdateItemRequest> requests = captor.getAllValues(); + assertEquals("3", requests.get(0).expressionAttributeValues().get(":subSeq").n()); + assertEquals("5", requests.get(1).expressionAttributeValues().get(":subSeq").n(), + "Only increasing sub-sequence checkpoints within the same sequence should be written"); + } + + /** + * Verifies that when no checkpoint table exists and no orphaned migration table exists, + * a fresh table is created with the configured name (no suffix). 
+ */ + @Test + void testEnsureCheckpointTableCreatesNewTableWhenNotFound() { + final AtomicInteger mainTableDescribeCount = new AtomicInteger(); + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + if (mainTableDescribeCount.incrementAndGet() <= 2) { + throw ResourceNotFoundException.builder().build(); + } + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + when(dynamoDb.createTable(any(CreateTableRequest.class))).thenReturn(CreateTableResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + final ArgumentCaptor<CreateTableRequest> captor = ArgumentCaptor.forClass(CreateTableRequest.class); + verify(dynamoDb).createTable(captor.capture()); + assertEquals("test-table", captor.getValue().tableName(), "Fresh table should use the configured name, not a migration suffix"); + } + + /** + * Verifies that when the configured table already exists with the new composite-key + * schema and no migration table is lingering, no creation or migration occurs. + */ + @Test + void testEnsureCheckpointTableUsesExistingNewTable() { + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + + manager.ensureCheckpointTableExists(); + + verify(dynamoDb, never()).createTable(any(CreateTableRequest.class)); + verify(dynamoDb, never()).deleteTable(any(DeleteTableRequest.class)); + } + + /** + * Verifies that when the configured table is NOT_FOUND but an orphaned migration table + * exists (crash during a previous migration), the migration table is renamed to the + * configured name by creating a new table, copying items, and deleting the migration table. 
+ */ + @Test + void testEnsureCheckpointTableRenamesOrphanedMigrationTable() { + final AtomicInteger mainTableDescribeCount = new AtomicInteger(); + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + if (mainTableDescribeCount.incrementAndGet() <= 3) { + throw ResourceNotFoundException.builder().build(); + } + + return newSchemaActiveResponse(); + } + + if ("test-table_migration".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + + throw ResourceNotFoundException.builder().build(); + }); + + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(UpdateItemResponse.builder().build()); + when(dynamoDb.deleteTable(any(DeleteTableRequest.class))).thenAnswer(invocation -> { + final DeleteTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + throw ResourceNotFoundException.builder().build(); + } + return DeleteTableResponse.builder().build(); + }); + + when(dynamoDb.createTable(any(CreateTableRequest.class))).thenReturn(CreateTableResponse.builder().build()); + + final Map<String, AttributeValue> checkpointItem = Map.of( + "streamName", AttributeValue.builder().s("test-stream").build(), + "shardId", AttributeValue.builder().s("shard-1").build(), + "sequenceNumber", AttributeValue.builder().s("12345").build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(checkpointItem).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + final ArgumentCaptor<CreateTableRequest> createCaptor = ArgumentCaptor.forClass(CreateTableRequest.class); + verify(dynamoDb).createTable(createCaptor.capture()); + assertEquals("test-table", createCaptor.getValue().tableName(), "Renamed table should use the configured name"); + + final ArgumentCaptor<PutItemRequest> putCaptor = 
ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb).putItem(putCaptor.capture()); + assertEquals("test-table", putCaptor.getValue().tableName(), "Items should be copied to the configured table name"); + assertEquals("12345", putCaptor.getValue().item().get("sequenceNumber").s(), "Checkpoint data should be preserved during rename"); + + final ArgumentCaptor<DeleteTableRequest> deleteCaptor = ArgumentCaptor.forClass(DeleteTableRequest.class); + verify(dynamoDb, times(2)).deleteTable(deleteCaptor.capture()); + final String deletedMigrationTable = deleteCaptor.getAllValues().stream() + .map(DeleteTableRequest::tableName) + .filter("test-table_migration"::equals) + .findFirst() + .orElseThrow(); + + assertEquals("test-table_migration", deletedMigrationTable, "Migration table should be deleted after rename"); + } + + /** + * Verifies that when the configured table has the new schema but a lingering migration + * table exists (crash after copy but before migration table deletion), the items are + * copied and the migration table is cleaned up. 
+ */ + @Test + void testEnsureCheckpointTableCleansUpLingeringMigrationTable() { + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName()) || "test-table_migration".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + + when(dynamoDb.deleteTable(any(DeleteTableRequest.class))).thenReturn(DeleteTableResponse.builder().build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + verify(dynamoDb, never()).createTable(any(CreateTableRequest.class)); + final ArgumentCaptor<DeleteTableRequest> deleteCaptor = ArgumentCaptor.forClass(DeleteTableRequest.class); + verify(dynamoDb).deleteTable(deleteCaptor.capture()); + assertEquals("test-table_migration", deleteCaptor.getValue().tableName(), "Lingering migration table should be deleted"); + } + + private static DescribeTableResponse newSchemaActiveResponse() { + final KeySchemaElement hashKey = KeySchemaElement.builder() + .attributeName("streamName") + .keyType(KeyType.HASH) + .build(); + final KeySchemaElement rangeKey = KeySchemaElement.builder() + .attributeName("shardId") + .keyType(KeyType.RANGE) + .build(); + final TableDescription table = TableDescription.builder() + .keySchema(hashKey, rangeKey) + .tableStatus(TableStatus.ACTIVE) + .build(); + return DescribeTableResponse.builder().table(table).build(); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java new file mode 100644 index 000000000000..9d3d976ad01d --- /dev/null +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.kinesis.retrieval.AggregatorUtil; +import software.amazon.kinesis.retrieval.KinesisClientRecord; +import software.amazon.kinesis.retrieval.kpl.Messages; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.time.Instant; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class KplDeaggregatorTest { + + private static final Instant ARRIVAL = Instant.parse("2025-06-15T12:00:00Z"); + 
private static final String TEST_SHARD_ID = "shardId-000000000000"; + + @Test + void testNonAggregatedPassthrough() { + final byte[] payload = "hello".getBytes(StandardCharsets.UTF_8); + final Record record = buildKinesisRecord("seq-001", "pk-1", payload); + + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + final DeaggregatedRecord dr = result.getFirst(); + assertEquals(TEST_SHARD_ID, dr.shardId()); + assertEquals("seq-001", dr.sequenceNumber()); + assertEquals(0, dr.subSequenceNumber()); + assertEquals("pk-1", dr.partitionKey()); + assertArrayEquals(payload, dr.data()); + assertEquals(ARRIVAL, dr.approximateArrivalTimestamp()); + } + + @Test + void testSingleSubRecord() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-A"), + List.of(new SubRecord(0, "data-A".getBytes(StandardCharsets.UTF_8)))); + final Record record = buildKinesisRecord("seq-100", "agg-pk", aggregated); + + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + final DeaggregatedRecord dr = result.getFirst(); + assertEquals("seq-100", dr.sequenceNumber()); + assertEquals(0, dr.subSequenceNumber()); + assertEquals("pk-A", dr.partitionKey()); + assertArrayEquals("data-A".getBytes(StandardCharsets.UTF_8), dr.data()); + } + + @Test + void testMultipleSubRecords() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-X", "pk-Y"), + List.of( + new SubRecord(0, "first".getBytes(StandardCharsets.UTF_8)), + new SubRecord(1, "second".getBytes(StandardCharsets.UTF_8)), + new SubRecord(0, "third".getBytes(StandardCharsets.UTF_8)))); + final Record record = buildKinesisRecord("seq-200", "agg-pk", aggregated); + + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(3, result.size()); + + assertEquals("pk-X", result.get(0).partitionKey()); + assertEquals(0, 
result.get(0).subSequenceNumber()); + assertArrayEquals("first".getBytes(StandardCharsets.UTF_8), result.get(0).data()); + + assertEquals("pk-Y", result.get(1).partitionKey()); + assertEquals(1, result.get(1).subSequenceNumber()); + assertArrayEquals("second".getBytes(StandardCharsets.UTF_8), result.get(1).data()); + + assertEquals("pk-X", result.get(2).partitionKey()); + assertEquals(2, result.get(2).subSequenceNumber()); + assertArrayEquals("third".getBytes(StandardCharsets.UTF_8), result.get(2).data()); + + for (final DeaggregatedRecord dr : result) { + assertEquals("seq-200", dr.sequenceNumber()); + assertEquals(ARRIVAL, dr.approximateArrivalTimestamp()); + } + } + + @Test + void testMixedAggregatedAndNonAggregated() throws Exception { + final byte[] plainPayload = "plain-data".getBytes(StandardCharsets.UTF_8); + final Record plainRecord = buildKinesisRecord("seq-001", "pk-plain", plainPayload); + + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-agg"), + List.of(new SubRecord(0, "agg-data".getBytes(StandardCharsets.UTF_8)))); + final Record aggRecord = buildKinesisRecord("seq-002", "pk-outer", aggregated); + + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + + assertEquals(2, result.size()); + assertEquals("seq-001", result.get(0).sequenceNumber()); + assertArrayEquals(plainPayload, result.get(0).data()); + assertEquals("seq-002", result.get(1).sequenceNumber()); + assertArrayEquals("agg-data".getBytes(StandardCharsets.UTF_8), result.get(1).data()); + } + + @Test + void testCorruptedProtobufFallsBackToPassthrough() { + final byte[] corrupted = new byte[KplDeaggregator.KPL_MAGIC.length + 20 + 16]; + System.arraycopy(KplDeaggregator.KPL_MAGIC, 0, corrupted, 0, KplDeaggregator.KPL_MAGIC.length); + final byte[] protobufPart = new byte[20]; + protobufPart[0] = (byte) 0xFF; + System.arraycopy(protobufPart, 0, corrupted, KplDeaggregator.KPL_MAGIC.length, 20); + try { + final byte[] md5 = 
MessageDigest.getInstance("MD5").digest(protobufPart); + System.arraycopy(md5, 0, corrupted, KplDeaggregator.KPL_MAGIC.length + 20, 16); + } catch (final Exception e) { + throw new RuntimeException(e); + } + + final Record record = buildKinesisRecord("seq-bad", "pk-bad", corrupted); + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + assertEquals("seq-bad", result.get(0).sequenceNumber()); + assertEquals(0, result.get(0).subSequenceNumber()); + assertArrayEquals(corrupted, result.get(0).data()); + } + + @Test + void testMd5MismatchFallsBackToPassthrough() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-1"), + List.of(new SubRecord(0, "data".getBytes(StandardCharsets.UTF_8)))); + + aggregated[aggregated.length - 1] ^= 0xFF; + + final Record record = buildKinesisRecord("seq-md5", "pk-md5", aggregated); + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + assertEquals(0, result.get(0).subSequenceNumber()); + assertArrayEquals(aggregated, result.get(0).data()); + } + + @Test + void testIsAggregatedDetection() { + assertFalse(KplDeaggregator.isAggregated(new byte[0])); + assertFalse(KplDeaggregator.isAggregated(new byte[]{0x01, 0x02})); + assertFalse(KplDeaggregator.isAggregated("regular data".getBytes(StandardCharsets.UTF_8))); + + final byte[] withMagic = new byte[KplDeaggregator.KPL_MAGIC.length + 16 + 1]; + System.arraycopy(KplDeaggregator.KPL_MAGIC, 0, withMagic, 0, KplDeaggregator.KPL_MAGIC.length); + assertTrue(KplDeaggregator.isAggregated(withMagic)); + } + + @Test + void testEmptyRecordList() { + final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of()); + assertTrue(result.isEmpty()); + } + + // ---- KCL cross-validation tests ---- + // These tests use the KCL's own protobuf Messages class to create aggregated records, + // then verify that our KplDeaggregator produces the 
same results as the KCL's AggregatorUtil. + + @Test + void testKclAggregatedSingleRecord() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-kcl-1") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("hello from KPL"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-kcl-1", "outer-pk", payload); + + final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(1, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + assertDeaggregatedMatchesKcl(ourResult.getFirst(), kclResult.getFirst()); + assertEquals("pk-kcl-1", ourResult.getFirst().partitionKey()); + assertArrayEquals("hello from KPL".getBytes(StandardCharsets.UTF_8), ourResult.getFirst().data()); + } + + @Test + void testKclAggregatedMultipleRecords() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-alpha") + .addPartitionKeyTable("pk-beta") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-0"))) + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(1) + .setData(ByteString.copyFromUtf8("record-1"))) + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-2"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-kcl-multi", "outer-pk", payload); + + final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(3, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + for (int i = 0; i < 
ourResult.size(); i++) { + assertDeaggregatedMatchesKcl(ourResult.get(i), kclResult.get(i)); + } + + assertEquals("pk-alpha", ourResult.get(0).partitionKey()); + assertEquals(0, ourResult.get(0).subSequenceNumber()); + assertArrayEquals("record-0".getBytes(StandardCharsets.UTF_8), ourResult.get(0).data()); + + assertEquals("pk-beta", ourResult.get(1).partitionKey()); + assertEquals(1, ourResult.get(1).subSequenceNumber()); + assertArrayEquals("record-1".getBytes(StandardCharsets.UTF_8), ourResult.get(1).data()); + + assertEquals("pk-alpha", ourResult.get(2).partitionKey()); + assertEquals(2, ourResult.get(2).subSequenceNumber()); + assertArrayEquals("record-2".getBytes(StandardCharsets.UTF_8), ourResult.get(2).data()); + } + + @Test + void testKclAggregatedMixedWithPlainRecords() { + final Record plainRecord = buildKinesisRecord("seq-plain", "pk-plain", + "plain-data".getBytes(StandardCharsets.UTF_8)); + + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-inner") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("agg-data"))) + .build(); + final Record aggRecord = buildKinesisRecord("seq-agg", "outer-pk", + wrapAsKplPayload(aggProto.toByteArray())); + + final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + + final List kclPlain = deaggregateViaKcl(plainRecord); + final List kclAgg = deaggregateViaKcl(aggRecord); + + assertEquals(2, ourResult.size()); + assertEquals(1, kclPlain.size()); + assertEquals(1, kclAgg.size()); + + assertDeaggregatedMatchesKcl(ourResult.get(0), kclPlain.getFirst()); + assertDeaggregatedMatchesKcl(ourResult.get(1), kclAgg.getFirst()); + } + + @Test + void testKclAggregatedWithExplicitHashKeys() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-0") + .addExplicitHashKeyTable("12345678901234567890") + 
.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setExplicitHashKeyIndex(0) + .setData(ByteString.copyFromUtf8("with-ehk"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-ehk", "outer-pk", payload); + + final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(1, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + assertDeaggregatedMatchesKcl(ourResult.getFirst(), kclResult.getFirst()); + assertArrayEquals("with-ehk".getBytes(StandardCharsets.UTF_8), ourResult.getFirst().data()); + } + + @Test + void testKclAggregatedLargeBatch() { + final Messages.AggregatedRecord.Builder builder = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-batch"); + for (int i = 0; i < 100; i++) { + builder.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-" + i))); + } + final byte[] payload = wrapAsKplPayload(builder.build().toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-batch", "outer-pk", payload); + + final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(100, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + for (int i = 0; i < ourResult.size(); i++) { + assertDeaggregatedMatchesKcl(ourResult.get(i), kclResult.get(i)); + assertEquals(i, ourResult.get(i).subSequenceNumber()); + assertArrayEquals(("record-" + i).getBytes(StandardCharsets.UTF_8), ourResult.get(i).data()); + } + } + + // ---- Test helpers ---- + + private record SubRecord(int partitionKeyIndex, byte[] data) { } + + private static Record buildKinesisRecord(final String sequenceNumber, final String partitionKey, final byte[] data) { + return 
Record.builder() + .sequenceNumber(sequenceNumber) + .partitionKey(partitionKey) + .data(SdkBytes.fromByteArray(data)) + .approximateArrivalTimestamp(ARRIVAL) + .build(); + } + + /** + * Wraps raw protobuf bytes in the KPL envelope format (magic + protobuf + MD5). + * + * @param protobufBytes serialized protobuf content + * @return complete KPL aggregated payload + */ + private static byte[] wrapAsKplPayload(final byte[] protobufBytes) { + try { + final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); + final ByteArrayOutputStream result = new ByteArrayOutputStream(); + result.write(KplDeaggregator.KPL_MAGIC); + result.write(protobufBytes); + result.write(md5); + return result.toByteArray(); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Deaggregates a single SDK v2 Record using the KCL's own {@link AggregatorUtil}. + * + * @param record the SDK v2 Kinesis record + * @return deaggregated KCL records + */ + private static List deaggregateViaKcl(final Record record) { + final KinesisClientRecord kcr = KinesisClientRecord.fromRecord(record); + return new AggregatorUtil().deaggregate(List.of(kcr)); + } + + private static void assertDeaggregatedMatchesKcl(final DeaggregatedRecord ours, final KinesisClientRecord kcl) { + assertEquals(kcl.sequenceNumber(), ours.sequenceNumber(), "sequence number mismatch"); + assertEquals(kcl.subSequenceNumber(), ours.subSequenceNumber(), "sub-sequence number mismatch"); + assertEquals(kcl.partitionKey(), ours.partitionKey(), "partition key mismatch"); + assertArrayEquals(toBytes(kcl.data()), ours.data(), "data mismatch"); + } + + private static byte[] toBytes(final ByteBuffer buffer) { + final ByteBuffer dup = buffer.duplicate(); + final byte[] bytes = new byte[dup.remaining()]; + dup.get(bytes); + return bytes; + } + + private static byte[] buildAggregatedPayload(final List partitionKeys, final List subRecords) + throws Exception { + final ByteArrayOutputStream protobufBuffer = 
new ByteArrayOutputStream(); + final CodedOutputStream cos = CodedOutputStream.newInstance(protobufBuffer); + + for (final String pk : partitionKeys) { + cos.writeString(1, pk); + } + + for (final SubRecord sub : subRecords) { + final int innerSize = computeInnerRecordSize(sub); + cos.writeTag(3, 2); + cos.writeUInt32NoTag(innerSize); + cos.writeUInt64(1, sub.partitionKeyIndex); + cos.writeByteArray(3, sub.data); + } + + cos.flush(); + final byte[] protobufBytes = protobufBuffer.toByteArray(); + return wrapAsKplPayload(protobufBytes); + } + + private static int computeInnerRecordSize(final SubRecord sub) { + int size = 0; + size += CodedOutputStream.computeUInt64Size(1, sub.partitionKeyIndex); + size += CodedOutputStream.computeByteArraySize(3, sub.data); + return size; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java deleted file mode 100644 index b6005bf8b99d..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java +++ /dev/null @@ -1,792 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.documentation.init.NopComponentLog; -import org.apache.nifi.processors.aws.kinesis.MemoryBoundRecordBuffer.Lease; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import software.amazon.awssdk.services.kinesis.model.Record; -import software.amazon.kinesis.exceptions.InvalidStateException; -import software.amazon.kinesis.exceptions.KinesisClientLibDependencyException; -import software.amazon.kinesis.exceptions.ShutdownException; -import software.amazon.kinesis.exceptions.ThrottlingException; -import software.amazon.kinesis.processor.Checkpointer; -import software.amazon.kinesis.processor.PreparedCheckpointer; -import software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.IntStream; - 
-import static java.util.concurrent.TimeUnit.SECONDS; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class MemoryBoundRecordBufferTest { - - private static final long MAX_MEMORY_BYTES = 1024L; - private static final String SHARD_ID_1 = "shard-1"; - private static final String SHARD_ID_2 = "shard-2"; - private static final Duration CHECKPOINT_INTERVAL = Duration.ZERO; - - private MemoryBoundRecordBuffer recordBuffer; - private TestCheckpointer checkpointer1; - private TestCheckpointer checkpointer2; - - @BeforeEach - void setUp() { - recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), MAX_MEMORY_BYTES, CHECKPOINT_INTERVAL); - checkpointer1 = new TestCheckpointer(); - checkpointer2 = new TestCheckpointer(); - } - - @Test - void testCreateBuffer() { - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - assertEquals(SHARD_ID_1, bufferId1.shardId()); - - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - assertEquals(SHARD_ID_2, bufferId2.shardId()); - - final ShardBufferId newBufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - assertEquals(SHARD_ID_1, newBufferId1.shardId()); - - assertNotEquals(bufferId1, bufferId2); - assertNotEquals(bufferId1, newBufferId1); - } - - @Test - void testAddRecordsToBuffer() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - // Buffer without records is not available for leasing. 
- assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - // Should be able to get buffer ID from pool since buffer has records. - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease.shardId()); - } - - @Test - void testAddEmptyRecordsList() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List emptyRecords = Collections.emptyList(); - - recordBuffer.addRecords(bufferId, emptyRecords, checkpointer1); - - // Should not be able to get buffer ID from pool since no records were added. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testConsumeRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(3); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertEquals(records, consumedRecords); - // Just consuming record should not checkpoint them. 
- assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testCommitConsumedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - assertEquals(records.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testCommitConsumedRecords_withRecordsAddedBeforeCommit() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List originalRecords = createTestRecords(2); - - recordBuffer.addRecords(bufferId, originalRecords, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - - // Simulating new records added in parallel, before a commit. - final List newRecords = createTestRecords(5); - recordBuffer.addRecords(bufferId, newRecords, checkpointer1); - - recordBuffer.commitConsumedRecords(lease); - - // Only originalRecords, which were consumed, are checkpointed. - assertEquals(originalRecords.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testRollbackConsumedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(3); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - final List messages = consumedRecords.stream() - .map(this::readContent) - .toList(); - - recordBuffer.rollbackConsumedRecords(lease); - - // Checkpointer should not be called during rollback. 
- assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - - final List rolledBackMessages = recordBuffer.consumeRecords(lease).stream() - .map(this::readContent) - .toList(); - assertEquals(messages, rolledBackMessages); - } - - private String readContent(final KinesisClientRecord record) { - final ByteBuffer data = record.data(); - final byte[] buffer = new byte[data.remaining()]; - data.get(buffer); - return new String(buffer, StandardCharsets.UTF_8); - } - - @Test - void testReturnBufferIdToPool() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - // Consume some records but don't return buffer id to the pool. - recordBuffer.consumeRecords(lease1); - - recordBuffer.addRecords(bufferId, createTestRecords(1), checkpointer2); - - // The buffer is still unavailable. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - - // After returning id to the pool it's possible to get the buffer from pool again. - recordBuffer.returnBufferLease(lease1); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - } - - @Test - void testReturnBufferIdToPool_multipleReturns() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.returnBufferLease(lease1); - recordBuffer.returnBufferLease(lease1); - - // Can retrieve the id only once. 
- final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testReturnBufferIdToPool_withUncommittedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - // Consume some records, but don't commit them. - final List lease1Records = recordBuffer.consumeRecords(lease1); - recordBuffer.returnBufferLease(lease1); - - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - final List lease2Records = recordBuffer.consumeRecords(lease2); - - // Until committed, the records stay in the buffer. - assertEquals(lease1Records, lease2Records); - } - - @Test - void testReturnBufferIdToPoolWithNoPendingRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - // Get buffer from pool and consume all records. - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - // Return buffer ID to pool - should not make buffer available since no pending records. - recordBuffer.returnBufferLease(lease); - - // Should not be able to get buffer from pool. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testConsumerLeaseLost() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - // Before lease lost buffer should be available in the pool. 
- final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease.shardId()); - - // Simulate lease lost. - recordBuffer.consumerLeaseLost(bufferId); - - // Should not be able to consume records from invalidated buffer. - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - - // Should not be able to commit records for invalidated buffer. - recordBuffer.commitConsumedRecords(lease); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - - // Buffer should not be available in pool. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testCheckpointEndedShard() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - recordBuffer.checkpointEndedShard(bufferId, checkpointer2); - assertEquals(TestCheckpointer.LATEST_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - // Buffer should be removed and not available for operations. 
- final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - } - - @Test - void testShutdownShardConsumption_forEmptyBuffer() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - recordBuffer.shutdownShardConsumption(bufferId, checkpointer2); - assertEquals(TestCheckpointer.LATEST_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - // Buffer should be removed and not available for operations. - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - } - - @Test - void testShutdownShardConsumption_forBufferWithRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - recordBuffer.shutdownShardConsumption(bufferId, checkpointer2); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - assertTrue(recordBuffer.acquireBufferLease().isEmpty(), "Buffer should not be available after shutdown"); - } - - @Test - void testShutdownShardConsumption_whileOtherShardIsValid() { - final int bufferSize = 100; - - // Create buffer with small memory limit. 
- final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), bufferSize, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - - final List records1 = List.of(createRecordWithSize(bufferSize)); - recordBuffer.addRecords(bufferId1, records1, checkpointer1); - - // Shutting down a buffer with a record. - recordBuffer.shutdownShardConsumption(bufferId1, checkpointer1); - - // Adding records to another buffer. - final List records2 = List.of(createRecordWithSize(bufferSize)); - assertTimeoutPreemptively( - Duration.ofSeconds(1), - () -> recordBuffer.addRecords(bufferId2, records2, checkpointer2), - "Records should be added to a buffer without memory backpressure"); - - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_2, lease.shardId(), "Expected to acquire a lease for " + SHARD_ID_2); - assertEquals(records2, recordBuffer.consumeRecords(lease)); - } - - @Test - @Timeout(value = 5, unit = SECONDS) - void testMemoryBackpressure() throws InterruptedException { - // Create buffer with small memory limit. - final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), 100L, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // Still fits into the buffer. - final List initialRecords = List.of(createRecordWithSize(80), createRecordWithSize(20)); - recordBuffer.addRecords(bufferId, initialRecords, checkpointer1); - - final CountDownLatch startLatch = new CountDownLatch(1); - - final List notFittingRecords = List.of(createRecordWithSize(50)); - // Thread that tries to add records (should block due to memory limit). - final Thread addRecordsThread = new Thread(() -> { - startLatch.countDown(); - // Doesn't fit into the buffer. 
- recordBuffer.addRecords(bufferId, notFittingRecords, checkpointer1); - }); - - addRecordsThread.start(); - startLatch.await(); - - // Wait for thread to try to add records and get blocked. - Thread.sleep(200); - assertTrue(addRecordsThread.isAlive(), "Thread should be blocked waiting for memory"); - - // Commit records in the buffer to free memory. - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(initialRecords, recordBuffer.consumeRecords(lease1)); - recordBuffer.commitConsumedRecords(lease1); - recordBuffer.returnBufferLease(lease1); - - // Thread should get unblocked and add the message. - addRecordsThread.join(); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(notFittingRecords, recordBuffer.consumeRecords(lease2)); - } - - @Test - void testMemoryBackpressure_forRecordsLargerThanMaxBuffer() { - // Create buffer with small memory limit. - final int bufferSize = 10; - final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), bufferSize, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // A single batch with size double the buffer size. - final List reallyLargeBatch = List.of(createRecordWithSize(bufferSize), createRecordWithSize(bufferSize)); - - // It's possible to insert a batch that exceeds the buffer size. - assertDoesNotThrow(() -> recordBuffer.addRecords(bufferId, reallyLargeBatch, checkpointer1)); - } - - @Test - @Timeout(value = 5, unit = SECONDS) - void testConcurrentBufferAccess() throws InterruptedException { - final int numberOfShards = 10; - final int numberOfRecordConsumers = 10; - final int totalThreads = numberOfShards + numberOfRecordConsumers; - - final int recordsPerShard = 5; - - final ExecutorService executor = Executors.newFixedThreadPool(totalThreads); - // All threads start at the same moment. 
- final CountDownLatch startLatch = new CountDownLatch(totalThreads); - final CountDownLatch finishLatch = new CountDownLatch(totalThreads); - - final List shardIds = IntStream.range(0, numberOfShards).boxed().toList(); - final List bufferIds = shardIds.stream() - .map(shardId -> recordBuffer.createBuffer("shard-" + shardId)) - .toList(); - - // Every producer returns a checkpointer for the records it produced. - final List checkpointers = shardIds.stream() - .map(shardId -> { - final ShardBufferId bufferId = bufferIds.get(shardId); - final List records = createTestRecords(recordsPerShard); - final TestCheckpointer threadCheckpointer = new TestCheckpointer(); - - executor.submit(() -> { - try { - startLatch.countDown(); - startLatch.await(); - - recordBuffer.addRecords(bufferId, records, threadCheckpointer); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } finally { - finishLatch.countDown(); - } - }); - - return threadCheckpointer; - }) - .toList(); - - // Every consumer returns a list of shardId:sequenceNumber strings for the records it consumed. - final List>> processedRecordsFutures = IntStream.range(0, numberOfRecordConsumers) - .mapToObj(__ -> executor.submit(() -> { - try { - startLatch.countDown(); - startLatch.await(); - - Optional maybeLease = Optional.empty(); - while (maybeLease.isEmpty()) { - maybeLease = recordBuffer.acquireBufferLease(); - Thread.sleep(100); // Wait for records to be added by producers. 
- } - - final Lease lease = maybeLease.orElseThrow(); - final List consumedRecords = recordBuffer.consumeRecords(lease); - - recordBuffer.commitConsumedRecords(lease); - - return consumedRecords.stream() - .map(record -> lease.shardId() + ":" + record.sequenceNumber()) - .toList(); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } finally { - finishLatch.countDown(); - } - })) - .toList(); - - finishLatch.await(); - executor.shutdown(); - - assertAll( - checkpointers.stream().map(it -> () -> - assertNotEquals( - TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, - it.latestCheckpointedSequenceNumber(), - "Every checkpointer should have been called")) - ); - - final long uniqueRecordsProcessed = processedRecordsFutures.stream() - .map(future -> { - try { - return future.get(); - } catch (final Exception e) { - throw new RuntimeException(e); - } - }) - .flatMap(Collection::stream) - .distinct() - .count(); - assertEquals(numberOfShards * recordsPerShard, uniqueRecordsProcessed); - } - - @Test - void testGetMultipleBuffersFromPool() { - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - - // Add records to both buffers. - recordBuffer.addRecords(bufferId1, createTestRecords(2), checkpointer1); - recordBuffer.addRecords(bufferId2, createTestRecords(3), checkpointer2); - - // Should be able to get both buffer IDs from pool. - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - - assertEquals(SHARD_ID_1, lease1.shardId()); - assertEquals(SHARD_ID_2, lease2.shardId()); - - // Should get different buffer IDs. - assertNotEquals(lease1.shardId(), lease2.shardId()); - } - - @Test - void testCommitRecordsWhileNewRecordsArrive() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // Add records to pending queue. 
- final List batch1 = createTestRecords(2); - recordBuffer.addRecords(bufferId, batch1, checkpointer1); - - // Consume batch1 records (moves from pending to in-progress). - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertEquals(batch1, consumedRecords); - - // Add more records while others are in-progress. - final List batch2 = createTestRecords(1); - recordBuffer.addRecords(bufferId, batch2, checkpointer2); - - // Commit in-progress records. - recordBuffer.commitConsumedRecords(lease); - - // Consume batch2 records. - final List remainingRecords = recordBuffer.consumeRecords(lease); - assertEquals(batch2, remainingRecords); - } - - @ParameterizedTest - @ValueSource(classes = { - KinesisClientLibDependencyException.class, - InvalidStateException.class, - ThrottlingException.class, - }) - void testRetriableCheckpointExceptions(final Class exceptionClass) throws Exception { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - final Exception exception = exceptionClass.getDeclaredConstructor(String.class).newInstance("Thrown from test"); - final TestCheckpointer failingCheckpointer = new TestCheckpointer(exception, 2); - - recordBuffer.addRecords(bufferId, records, failingCheckpointer); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - // Should handle all exception types gracefully. 
- recordBuffer.commitConsumedRecords(lease); - - assertEquals(records.getLast().sequenceNumber(), failingCheckpointer.latestCheckpointedSequenceNumber()); - } - - @Test - void testShutdownDuringCheckpoint() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - final ShutdownException exception = new ShutdownException("Test shutdown exception"); - final TestCheckpointer failingCheckpointer = new TestCheckpointer(exception, 1); - - recordBuffer.addRecords(bufferId, records, failingCheckpointer); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - recordBuffer.commitConsumedRecords(lease); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, failingCheckpointer.latestCheckpointedSequenceNumber()); - } - - @Test - @Timeout(value = 3, unit = SECONDS) - void testCheckpointEndedShardWaitsForPendingRecords() throws InterruptedException { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - final CountDownLatch finishStarted = new CountDownLatch(1); - - // Start finishConsumption in another thread. - final Thread finishThread = new Thread(() -> { - finishStarted.countDown(); - recordBuffer.checkpointEndedShard(bufferId, checkpointer2); - }); - - finishThread.start(); - finishStarted.await(); - - // Give finishThread time to get blocked. - Thread.sleep(200); - assertTrue(finishThread.isAlive(), "finishConsumption should block until records are committed"); - - // Commit records to unblock finishConsumption. 
- recordBuffer.commitConsumedRecords(lease); - assertEquals(records.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - - finishThread.join(); - assertEquals( - TestCheckpointer.LATEST_SEQUENCE_NUMBER, - checkpointer2.latestCheckpointedSequenceNumber(), - "Checkpointer should be called after finishConsumption unblocks"); - } - - private List createTestRecords(int count) { - return IntStream.range(0, count) - .mapToObj(i -> { - final String data = "test-record-" + i; - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data.getBytes(StandardCharsets.UTF_8)).asReadOnlyBuffer()) - .partitionKey("partition-" + i) - .sequenceNumber(String.valueOf(i)) - .build(); - }) - .toList(); - } - - private KinesisClientRecord createRecordWithSize(int sizeBytes) { - final byte[] data = new byte[sizeBytes]; - Arrays.fill(data, (byte) 'X'); - - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data).asReadOnlyBuffer()) - .partitionKey("partition") - .sequenceNumber("1") - .build(); - } - - /** - * Test implementation of RecordProcessorCheckpointer that tracks checkpoint calls - * and can simulate various exception scenarios. 
- */ - private static class TestCheckpointer implements RecordProcessorCheckpointer { - - static final String NO_CHECKPOINT_SEQUENCE_NUMBER = "NONE"; - static final String LATEST_SEQUENCE_NUMBER = "LATEST"; - - private final Exception exceptionToThrow; - private final AtomicInteger throwsLeft; - - private volatile String latestCheckpointedSequenceNumber = NO_CHECKPOINT_SEQUENCE_NUMBER; - - TestCheckpointer() { - this.exceptionToThrow = null; - this.throwsLeft = new AtomicInteger(0); - } - - TestCheckpointer(final Exception exceptionToThrow, final int maxThrows) { - this.exceptionToThrow = exceptionToThrow; - this.throwsLeft = new AtomicInteger(maxThrows); - } - - @Override - public void checkpoint() throws KinesisClientLibDependencyException, InvalidStateException, ThrottlingException, ShutdownException { - doCheckpoint(LATEST_SEQUENCE_NUMBER); - } - - @Override - public void checkpoint(Record record) { - throw notImplemented(); - } - - @Override - public void checkpoint(String sequenceNumber) { - throw notImplemented(); - } - - @Override - public void checkpoint(String sequenceNumber, long subSequenceNumber) throws ShutdownException, InvalidStateException { - doCheckpoint(sequenceNumber); - } - - private void doCheckpoint(final String sequenceNumber) throws InvalidStateException, ShutdownException { - if (exceptionToThrow != null && throwsLeft.decrementAndGet() == 0) { - switch (exceptionToThrow) { - case KinesisClientLibDependencyException e -> throw e; - case InvalidStateException e -> throw e; - case ThrottlingException e -> throw e; - case ShutdownException e -> throw e; - default -> throw new RuntimeException(exceptionToThrow); - } - } - - latestCheckpointedSequenceNumber = sequenceNumber; - } - - @Override - public PreparedCheckpointer prepareCheckpoint() { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer 
prepareCheckpoint(Record record) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(Record record, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, long subSequenceNumber) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, long subSequenceNumber, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public Checkpointer checkpointer() { - throw notImplemented(); - } - - String latestCheckpointedSequenceNumber() { - return latestCheckpointedSequenceNumber; - } - - private static RuntimeException notImplemented() { - return new UnsupportedOperationException("Not implemented for test"); - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java new file mode 100644 index 000000000000..3277cb9ce958 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.nifi.processors.aws.kinesis;

import org.apache.nifi.logging.ComponentLog;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kinesis.KinesisClient;
import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link PollingKinesisClient} covering the lifecycle of the background
 * fetch loops: delivery of the final batch from exhausted shards, recovery from dead
 * loops and transient AWS errors, rollback-driven iterator re-acquisition, and clean
 * shutdown semantics. All AWS interactions are mocked.
 */
class PollingKinesisClientTest {

    private static final GetShardIteratorResponse SHARD_ITERATOR_RESPONSE =
            GetShardIteratorResponse.builder().shardIterator("iter-1").build();

    private KinesisClient kinesisClient;
    private KinesisShardManager shardManager;
    private PollingKinesisClient client;

    @BeforeEach
    void setUp() {
        kinesisClient = mock(KinesisClient.class);
        shardManager = mock(KinesisShardManager.class);
        // Short backoff intervals (1 ms) keep retry-based tests fast.
        client = new PollingKinesisClient(kinesisClient, mock(ComponentLog.class), 1L, 1L);
    }

    @AfterEach
    void tearDown() {
        if (client != null) {
            client.close();
        }
    }

    /**
     * A shard is exhausted when GetRecords returns a null nextShardIterator. The records
     * carried by that final response must still be queued before the fetch loop exits;
     * a regression here would silently lose the last batch of every exhausted shard.
     */
    @Test
    void testExhaustedShardDeliversAllRecords() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse finalBatch = GetRecordsResponse.builder()
                .records(record("100", "A"), record("200", "B"), record("300", "C"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(finalBatch);

        client.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult fetched = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(fetched, "Last batch of an exhausted shard must not be dropped");
        assertEquals(3, fetched.records().size());
        assertEquals("100", fetched.firstSequenceNumber());
        assertEquals("300", fetched.lastSequenceNumber());

        assertEventuallyNoPendingFetches();
    }

    /**
     * If a fetch loop thread dies from an uncaught Throwable, a subsequent startFetches
     * call must notice the dead loop and relaunch it so data keeps flowing. The Error
     * thrown from readCheckpoint escapes the Exception catch in getShardIterator,
     * propagates through runFetchLoop, and is trapped by the Throwable guard in
     * launchFetchLoop.
     */
    @Test
    void testDeadLoopRecoveryRestoresDataFlow() throws Exception {
        when(shardManager.readCheckpoint(anyString()))
                .thenThrow(new AssertionError("simulated loop death"))
                .thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(batch);

        final List<Shard> singleShard = shards("shard-1");
        client.startFetches(singleShard, "test-stream", 1000, "TRIM_HORIZON", shardManager);
        // Give the first (doomed) loop time to die before asking for a restart.
        Thread.sleep(100);

        client.startFetches(singleShard, "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult fetched = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(fetched, "Dead loop must be restarted and produce records");
        assertEquals("100", fetched.firstSequenceNumber());
    }

    /**
     * After a rollback (e.g., a failed session commit), the fetch loop must re-acquire
     * its shard iterator from the checkpointed position rather than continuing from
     * wherever it stopped, so that in-flight records are re-consumed without loss.
     */
    @Test
    void testRollbackCausesIteratorReAcquisitionFromCheckpoint() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null).thenReturn("500");
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("600", "data"), record("700", "data2"))
                .nextShardIterator("iter-next")
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(batch);

        client.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult firstResult = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(firstResult);

        client.rollbackResults(List.of(firstResult));

        final ArgumentCaptor<GetShardIteratorRequest> requestCaptor = ArgumentCaptor.forClass(GetShardIteratorRequest.class);
        verify(kinesisClient, timeout(5000).atLeast(2)).getShardIterator(requestCaptor.capture());

        final GetShardIteratorRequest reAcquisition = requestCaptor.getAllValues().getLast();
        assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, reAcquisition.shardIteratorType(),
                "After rollback, iterator must resume from checkpoint, not from where it left off");
        assertEquals("500", reAcquisition.startingSequenceNumber());
    }

    /**
     * An expired shard iterator must be replaced transparently by re-acquiring from the
     * checkpoint; records still arrive after the transient failure.
     */
    @Test
    void testExpiredIteratorRecovery() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final ExpiredIteratorException expired = ExpiredIteratorException.builder().message("expired").build();
        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenThrow(expired).thenReturn(batch);

        client.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult fetched = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(fetched, "Records must arrive after expired iterator recovery");

        // One initial acquisition plus at least one re-acquisition after expiry.
        verify(kinesisClient, timeout(5000).atLeast(2)).getShardIterator(any(GetShardIteratorRequest.class));
    }

    /**
     * Throttling (ProvisionedThroughputExceededException) is a transient condition:
     * the loop must retry and records must eventually arrive without data loss.
     */
    @Test
    void testThrottledFetchRetriesWithoutDataLoss() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final ProvisionedThroughputExceededException throttled =
                ProvisionedThroughputExceededException.builder().message("throttled").build();
        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenThrow(throttled).thenReturn(batch);

        client.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult fetched = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(fetched, "Throttled fetch must retry and eventually deliver records");
    }

    /**
     * A RuntimeException from readCheckpoint (e.g., a DynamoDB throttle on startup)
     * must not permanently kill the fetch loop: getShardIterator catches it and
     * returns null, the loop backs off, retries, and records eventually arrive.
     */
    @Test
    void testCheckpointReadFailureRetriesWithoutKillingLoop() throws Exception {
        when(shardManager.readCheckpoint(anyString()))
                .thenThrow(new RuntimeException("DynamoDB throttle"))
                .thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(batch);

        client.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        final ShardFetchResult fetched = client.pollAnyResult(5, TimeUnit.SECONDS);
        assertNotNull(fetched, "Checkpoint read failure must not permanently kill the fetch loop");
    }

    /**
     * close() must terminate every fetch loop and leave the consumer reporting no
     * pending work, so the processor cannot spin-wait after shutdown.
     */
    @Test
    void testCloseTerminatesAllFetchingAndDrainsQueue() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator("iter-next")
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(batch);

        client.startFetches(shards("shard-1", "shard-2"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        // Wait until fetching has demonstrably started before closing.
        verify(kinesisClient, timeout(5000).atLeast(1)).getRecords(any(GetRecordsRequest.class));

        client.close();

        assertFalse(client.hasPendingFetches(), "After close, hasPendingFetches must return false");

        // Already closed; prevent tearDown from closing again.
        client = null;
    }

    /**
     * Once every shard is exhausted and the result queue has been drained,
     * hasPendingFetches must report false so the processor's poll loop exits
     * promptly instead of spin-waiting indefinitely.
     */
    @Test
    void testHasPendingFetchesFalseWhenAllShardsExhausted() throws Exception {
        when(shardManager.readCheckpoint(anyString())).thenReturn(null);
        when(kinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(SHARD_ITERATOR_RESPONSE);

        final GetRecordsResponse batch = GetRecordsResponse.builder()
                .records(record("100", "data"))
                .nextShardIterator(null)
                .millisBehindLatest(0L)
                .build();
        when(kinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(batch);

        client.startFetches(shards("shard-1", "shard-2"), "test-stream", 1000, "TRIM_HORIZON", shardManager);

        drainAllResults();
        assertEventuallyNoPendingFetches();
    }

    // Polls until the result queue yields nothing within the short timeout.
    private void drainAllResults() throws InterruptedException {
        while (client.pollAnyResult(500, TimeUnit.MILLISECONDS) != null) {
            // discard results; we only care that the queue empties
        }
    }

    // Busy-waits (with sleeps) up to 5 seconds for hasPendingFetches to flip to false.
    private void assertEventuallyNoPendingFetches() throws InterruptedException {
        final long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5);
        while (client.hasPendingFetches() && System.nanoTime() < deadline) {
            Thread.sleep(50);
        }
        assertFalse(client.hasPendingFetches(), "Expected hasPendingFetches to become false");
    }

    // Builds one Shard per supplied shard id.
    private static List<Shard> shards(final String... shardIds) {
        return Arrays.stream(shardIds)
                .map(shardId -> Shard.builder().shardId(shardId).build())
                .toList();
    }

    // Builds a Kinesis Record whose partition key is derived from the sequence number.
    private static Record record(final String sequenceNumber, final String data) {
        return Record.builder()
                .sequenceNumber(sequenceNumber)
                .partitionKey("pk-" + sequenceNumber)
                .approximateArrivalTimestamp(Instant.now())
                .data(SdkBytes.fromString(data, StandardCharsets.UTF_8))
                .build();
    }
}
- */ -package org.apache.nifi.processors.aws.kinesis; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -class ProvenanceTransitUriFormatTest { - - @ParameterizedTest - @CsvSource(""" - streamA,shardId123,kinesis:stream/streamA/shardId123 - stream-name-b,shardId-000000013,kinesis:stream/stream-name-b/shardId-000000013 - kinesis,shardId-00000001,kinesis:stream/kinesis/shardId-00000001 - """) - void toTransitUri(final String streamName, final String shardId, final String expectedTransitUri) { - final String transitUri = ProvenanceTransitUriFormat.toTransitUri(streamName, shardId); - assertEquals(expectedTransitUri, transitUri); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java deleted file mode 100644 index c4023148aa50..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java +++ /dev/null @@ -1,397 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.json.JsonRecordSetWriter; -import org.apache.nifi.json.JsonTreeReader; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processors.aws.kinesis.ReaderRecordProcessor.ProcessingResult; -import org.apache.nifi.processors.aws.kinesis.converter.ValueRecordConverter; -import org.apache.nifi.reporting.InitializationException; -import org.apache.nifi.schema.access.SchemaAccessUtils; -import org.apache.nifi.schema.inference.SchemaInferenceUtil; -import org.apache.nifi.serialization.MalformedRecordException; -import org.apache.nifi.serialization.RecordReader; -import org.apache.nifi.serialization.RecordReaderFactory; -import org.apache.nifi.serialization.record.MockRecordParser; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordSchema; -import org.apache.nifi.util.MockFlowFile; -import org.apache.nifi.util.MockProcessSession; -import org.apache.nifi.util.SharedSessionState; -import org.apache.nifi.util.TestRunner; -import org.apache.nifi.util.TestRunners; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.time.Instant; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.APPROXIMATE_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SEQUENCE_NUMBER; -import static 
org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.PARTITION_KEY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.JsonRecordAssert.assertFlowFileRecords; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class ReaderRecordProcessorTest { - - private static final String TEST_STREAM_NAME = "stream-test"; - private static final String TEST_SHARD_ID = "shardId-test"; - - private static final String USER_JSON_1 = "{\"name\":\"John\",\"age\":30}"; - private static final String USER_JSON_2 = "{\"name\":\"Jane\",\"age\":25}"; - private static final String USER_JSON_3 = "{\"name\":\"Bob\",\"age\":35}"; - - private static final String CITY_JSON_1 = "{\"name\":\"Seattle\",\"country\":\"US\"}"; - private static final String CITY_JSON_2 = "{\"name\":\"Warsaw\",\"country\":\"PL\"}"; - - private static final String INVALID_JSON = "{invalid json}"; - - private MockProcessSession session; - private ComponentLog logger; - - private JsonRecordSetWriter jsonWriter; - private ReaderRecordProcessor processor; - - 
@BeforeEach - void setUp() throws InitializationException { - final TestRunner runner = TestRunners.newTestRunner(ConsumeKinesis.class); - final SharedSessionState sharedState = new SharedSessionState(runner.getProcessor(), new AtomicLong(0)); - session = new MockProcessSession(sharedState, runner.getProcessor()); - logger = runner.getLogger(); - - final JsonTreeReader jsonReader = new JsonTreeReader(); - runner.addControllerService("json-reader", jsonReader); - runner.setProperty(jsonReader, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaInferenceUtil.INFER_SCHEMA.getValue()); - runner.enableControllerService(jsonReader); - - jsonWriter = new JsonRecordSetWriter(); - runner.addControllerService("json-writer", jsonWriter); - runner.setProperty(jsonWriter, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaAccessUtils.INHERIT_RECORD_SCHEMA.getValue()); - runner.enableControllerService(jsonWriter); - - processor = new ReaderRecordProcessor(jsonReader, new ValueRecordConverter(), jsonWriter, logger); - } - - @Test - void testProcessSingleRecord() { - final KinesisClientRecord record = KinesisClientRecord.builder() - .data(ByteBuffer.wrap(USER_JSON_1.getBytes(UTF_8))) - .sequenceNumber("1") - .subSequenceNumber(2) - .approximateArrivalTimestamp(Instant.now()) - .partitionKey("key-123") - .build(); - final List records = List.of(record); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - - assertEquals(TEST_STREAM_NAME, successFlowFile.getAttribute(STREAM_NAME)); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - - assertEquals(record.sequenceNumber(), successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(record.subSequenceNumber()), 
successFlowFile.getAttribute(FIRST_SUB_SEQUENCE_NUMBER)); - assertEquals(record.sequenceNumber(), successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(record.subSequenceNumber()), successFlowFile.getAttribute(LAST_SUB_SEQUENCE_NUMBER)); - - assertEquals(record.partitionKey(), successFlowFile.getAttribute(PARTITION_KEY)); - assertEquals(String.valueOf(record.approximateArrivalTimestamp().toEpochMilli()), successFlowFile.getAttribute(APPROXIMATE_ARRIVAL_TIMESTAMP)); - - assertEquals("application/json", successFlowFile.getAttribute(MIME_TYPE)); - assertEquals("1", successFlowFile.getAttribute(RECORD_COUNT)); - - assertFlowFileRecords(successFlowFile, records); - } - - @Test - void testProcessMultipleRecordsWithSameSchema() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(USER_JSON_2, "2"), - createKinesisRecord(USER_JSON_3, "3") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - assertEquals(TEST_STREAM_NAME, successFlowFile.getAttribute(STREAM_NAME)); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - assertEquals("3", successFlowFile.getAttribute(RECORD_COUNT)); - - assertEquals(records.getFirst().sequenceNumber(), successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(records.getFirst().subSequenceNumber()), successFlowFile.getAttribute(FIRST_SUB_SEQUENCE_NUMBER)); - assertEquals(records.getLast().sequenceNumber(), successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(records.getLast().subSequenceNumber()), successFlowFile.getAttribute(LAST_SUB_SEQUENCE_NUMBER)); - - assertFlowFileRecords(successFlowFile, records); - } - - @Test - void testEmptyRecordsList() 
{ - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, Collections.emptyList()); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - } - - @Test - void testSchemaChangeCreatesNewFlowFile() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(CITY_JSON_1, "2") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(2, result.successFlowFiles().size()); // Two different schemas = two FlowFiles - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("1", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(firstFlowFile, records.getFirst()); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("1", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(secondFlowFile, records.get(1)); - } - - @Test - void testSchemaChangeWithMultipleRecordsInBetween() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(USER_JSON_2, "2"), - createKinesisRecord(CITY_JSON_1, "3"), - createKinesisRecord(CITY_JSON_2, "4") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(2, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(firstFlowFile, records.subList(0, 2)); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("2", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(secondFlowFile, records.subList(2, 4)); - } - - @Test - void 
testSingleMalformedRecord() { - final List records = List.of( - createKinesisRecord(INVALID_JSON, "1") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertEquals(TEST_SHARD_ID, failureFlowFile.getAttribute(SHARD_ID)); - assertEquals("1", failureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("1", failureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - - failureFlowFile.assertContentEquals(INVALID_JSON, UTF_8); - } - - @Test - void testMalformedRecordBetweenValid() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(INVALID_JSON, "2"), - createKinesisRecord(USER_JSON_2, "3"), - createKinesisRecord(INVALID_JSON, "4"), - createKinesisRecord(USER_JSON_3, "5") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(2, result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - assertEquals("3", successFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("1", successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("5", successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(successFlowFile, records.get(0), records.get(2), records.get(4)); - - assertAll(result.parseFailureFlowFiles().stream().map( - failureFlowFile -> () -> { - assertNotNull(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, failureFlowFile.getAttribute(SHARD_ID)); - } - )); - } - 
- @Test - void testIOExceptionDuringReaderCreation() { - final RecordReaderFactory failingReaderFactory = new MockRecordParser() { - @Override - public RecordReader createRecordReader(Map variables, InputStream in, long inputLength, ComponentLog logger) throws IOException { - throw new IOException("Failed to create reader"); - } - }; - - final ReaderRecordProcessor processor = new ReaderRecordProcessor(failingReaderFactory, new ValueRecordConverter(), jsonWriter, logger); - - final KinesisClientRecord record = createKinesisRecord(USER_JSON_1, "1"); - final List records = List.of(record); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertTrue(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE).contains("Failed to create reader")); - failureFlowFile.assertContentEquals(KinesisRecordPayload.extract(record), UTF_8); - } - - @Test - void testMalformedRecordExceptionDuringReading() { - final ReaderRecordProcessor processor = new ReaderRecordProcessor(getMalformedRecordExceptionReader(), new ValueRecordConverter(), jsonWriter, logger); - - final KinesisClientRecord record = createKinesisRecord(USER_JSON_1, "1"); - final List records = Collections.singletonList(record); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertTrue(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE).contains("Test exception")); - failureFlowFile.assertContentEquals(KinesisRecordPayload.extract(record), UTF_8); - } - - @Test - void 
testInvalidRecordsWithSchemaEvolution() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), // Schema A - createKinesisRecord(USER_JSON_2, "2"), // Schema A - createKinesisRecord(INVALID_JSON, "3"), - createKinesisRecord(CITY_JSON_1, "4"), // Schema B - createKinesisRecord(INVALID_JSON, "5"), - createKinesisRecord(CITY_JSON_2, "6"), // Schema B - createKinesisRecord("{\"category\":\"electronics\",\"price\":99.99}", "7") // Schema C - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(3, result.successFlowFiles().size()); - assertEquals(2, result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("1", firstFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("2", firstFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(firstFlowFile, records.subList(0, 2)); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("2", secondFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("4", secondFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("6", secondFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(secondFlowFile, records.get(3), records.get(5)); - - final FlowFile thirdFlowFile = result.successFlowFiles().get(2); - assertEquals("1", thirdFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("7", thirdFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("7", thirdFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(thirdFlowFile, records.get(6)); - - final List failureFlowFiles = result.parseFailureFlowFiles(); - - final MockFlowFile firstFailureFlowFile = (MockFlowFile) failureFlowFiles.getFirst(); - assertEquals("3", firstFailureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("3", 
firstFailureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(firstFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, firstFailureFlowFile.getAttribute(SHARD_ID)); - firstFailureFlowFile.assertContentEquals(KinesisRecordPayload.extract(records.get(2)), UTF_8); - - final MockFlowFile secondFailureFlowFile = (MockFlowFile) failureFlowFiles.get(1); - assertEquals("5", secondFailureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("5", secondFailureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(secondFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, secondFailureFlowFile.getAttribute(SHARD_ID)); - secondFailureFlowFile.assertContentEquals(KinesisRecordPayload.extract(records.get(4)), UTF_8); - } - - private static KinesisClientRecord createKinesisRecord(final String data, final String sequenceNumber) { - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data.getBytes(UTF_8))) - .sequenceNumber(sequenceNumber) - .partitionKey("key") - .approximateArrivalTimestamp(Instant.now()) - .build(); - } - - private static RecordReaderFactory getMalformedRecordExceptionReader() { - return new MockRecordParser() { - @Override - public RecordReader createRecordReader(Map variables, InputStream in, long inputLength, ComponentLog logger) { - return new RecordReader() { - @Override - public void close() { - } - - @Override - public Record nextRecord(boolean coerceTypes, boolean dropUnknownFields) throws MalformedRecordException { - throw new MalformedRecordException("Test exception"); - } - - @Override - public RecordSchema getSchema() { - return null; - } - }; - } - }; - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java deleted file mode 100644 index c4261d3087c9..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_RECORD; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.KINESIS_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.SCHEMA_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.createTestKinesisRecord; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.verifyMetadata; -import static org.junit.jupiter.api.Assertions.assertEquals; - -class InjectMetadataRecordConverterTest { - - private static final RecordSchema EXPECTED_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField("name", RecordFieldType.STRING.getDataType()), - new RecordField("age", RecordFieldType.INT.getDataType()), - new RecordField(KINESIS_METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)) - )); - - private static final InjectMetadataRecordConverter CONVERTER = new InjectMetadataRecordConverter(); - - @Test - void 
testConvertWithApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(TEST_ARRIVAL_TIMESTAMP); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.remove(KINESIS_METADATA); - assertEquals(INPUT_RECORD.toMap(), recordValues); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = true; - verifyMetadata(metadata, expectTimestamp); - } - - @Test - void testConvertWithoutApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(null); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.remove(KINESIS_METADATA); - assertEquals(INPUT_RECORD.toMap(), recordValues); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = false; - verifyMetadata(metadata, expectTimestamp); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java deleted file mode 100644 index 12f736656af2..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import jakarta.annotation.Nullable; -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.time.Instant; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.APPROX_ARRIVAL_TIMESTAMP; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; - -public class KinesisRecordConverterTestUtil { - - static final String KINESIS_METADATA = "kinesisMetadata"; - - static final String TEST_STREAM_NAME = "test-stream"; - static final String TEST_SHARD_ID = "shardId-000000000001"; - static final String TEST_SEQUENCE_NUMBER = "49590338271490256608559692538361571095921575989136588801"; - static final long TEST_SUB_SEQUENCE_NUMBER = 2; - static final String TEST_PARTITION_KEY = "test-partition-key"; 
- static final Instant TEST_ARRIVAL_TIMESTAMP = Instant.ofEpochMilli(1640995200000L); - - static final String EXPECTED_SHARDED_SEQUENCE_NUMBER = "4959033827149025660855969253836157109592157598913658880100000000000000000002"; - - static final RecordSchema INPUT_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField("name", RecordFieldType.STRING.getDataType()), - new RecordField("age", RecordFieldType.INT.getDataType()) - )); - - static final Record INPUT_RECORD = new MapRecord(INPUT_SCHEMA, Map.of( - "name", "John Doe", - "age", 30 - )); - - static final RecordSchema SCHEMA_METADATA = new SimpleRecordSchema(List.of( - new RecordField("stream", RecordFieldType.STRING.getDataType()), - new RecordField("shardId", RecordFieldType.STRING.getDataType()), - new RecordField("sequenceNumber", RecordFieldType.STRING.getDataType()), - new RecordField("subSequenceNumber", RecordFieldType.LONG.getDataType()), - new RecordField("shardedSequenceNumber", RecordFieldType.STRING.getDataType()), - new RecordField("partitionKey", RecordFieldType.STRING.getDataType()), - new RecordField(APPROX_ARRIVAL_TIMESTAMP, RecordFieldType.TIMESTAMP.getDataType()) - )); - - private KinesisRecordConverterTestUtil() { - // Utility class - } - - static KinesisClientRecord createTestKinesisRecord(final @Nullable Instant arrivalTimestamp) { - return KinesisClientRecord.builder() - .data(ByteBuffer.allocate(0)) - .sequenceNumber(TEST_SEQUENCE_NUMBER) - .subSequenceNumber(TEST_SUB_SEQUENCE_NUMBER) - .partitionKey(TEST_PARTITION_KEY) - .approximateArrivalTimestamp(arrivalTimestamp) - .build(); - } - - static void verifyMetadata(final Record metadata, final boolean expectTimestamp) { - assertEquals(TEST_STREAM_NAME, metadata.getValue("stream")); - assertEquals(TEST_SHARD_ID, metadata.getValue("shardId")); - assertEquals(TEST_SEQUENCE_NUMBER, metadata.getValue("sequenceNumber")); - assertEquals(TEST_SUB_SEQUENCE_NUMBER, metadata.getValue("subSequenceNumber")); - 
assertEquals(EXPECTED_SHARDED_SEQUENCE_NUMBER, metadata.getValue("shardedSequenceNumber")); - assertEquals(TEST_PARTITION_KEY, metadata.getValue("partitionKey")); - - if (expectTimestamp) { - assertEquals(TEST_ARRIVAL_TIMESTAMP.toEpochMilli(), metadata.getValue(APPROX_ARRIVAL_TIMESTAMP)); - } else { - assertNull(metadata.getValue(APPROX_ARRIVAL_TIMESTAMP)); - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java deleted file mode 100644 index b15e1efc5c45..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_RECORD; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_SCHEMA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.KINESIS_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.SCHEMA_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.createTestKinesisRecord; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.verifyMetadata; -import static org.junit.jupiter.api.Assertions.assertEquals; - -class WrapperRecordConverterTest { - - private static final RecordSchema EXPECTED_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField(KINESIS_METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)), - new RecordField("value", RecordFieldType.RECORD.getRecordDataType(INPUT_SCHEMA)) - )); - - private static final WrapperRecordConverter CONVERTER = new WrapperRecordConverter(); - - @Test - void 
testConvertWithApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(TEST_ARRIVAL_TIMESTAMP); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - assertEquals(INPUT_RECORD, record.getValue("value")); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = true; - verifyMetadata(metadata, expectTimestamp); - } - - @Test - void testConvertWithoutApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(null); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - assertEquals(INPUT_RECORD, record.getValue("value")); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = false; - verifyMetadata(metadata, expectTimestamp); - } -} From 144a8b842a0f9e097fba9742b1afb08cf98517a0 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Mon, 9 Mar 2026 15:58:34 -0400 Subject: [PATCH 2/7] NIFI-15669: Addressed review feedback --- .../aws/kinesis/CheckpointTableUtils.java | 69 +--- .../aws/kinesis/ConsumeKinesis.java | 334 +++++++++++------- .../aws/kinesis/EfoKinesisClient.java | 202 ++++++----- .../aws/kinesis/KinesisConsumerClient.java | 28 +- .../aws/kinesis/KinesisShardManager.java | 7 +- .../aws/kinesis/KplDeaggregator.java | 24 +- .../aws/kinesis/LegacyCheckpointMigrator.java | 6 +- .../aws/kinesis/PollingKinesisClient.java | 95 +++-- .../aws/kinesis/ShardCheckpoint.java | 4 +- .../aws/kinesis/ShardFetchResult.java | 9 +- .../aws/kinesis/CheckpointTableUtilsTest.java | 75 ++-- .../aws/kinesis/ConsumeKinesisTest.java | 24 +- .../kinesis/KinesisConsumerClientTest.java | 226 ++++-------- .../aws/kinesis/KinesisShardManagerTest.java | 31 +- 
.../aws/kinesis/PollingKinesisClientTest.java | 151 +++++++- 15 files changed, 707 insertions(+), 578 deletions(-) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java index 6661c2cd34d2..de038de782f3 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java @@ -63,17 +63,18 @@ enum TableSchema { static TableSchema getTableSchema(final DynamoDbClient client, final String tableName) { try { - final DescribeTableResponse describe = client.describeTable( - DescribeTableRequest.builder().tableName(tableName).build()); + final DescribeTableResponse describe = client.describeTable(DescribeTableRequest.builder().tableName(tableName).build()); final List keySchema = describe.table().keySchema(); if (keySchema.size() == 2 && hasKey(keySchema, "streamName", KeyType.HASH) && hasKey(keySchema, "shardId", KeyType.RANGE)) { return TableSchema.NEW; } + if (keySchema.size() == 1 && hasKey(keySchema, "leaseKey", KeyType.HASH)) { return TableSchema.LEGACY; } + return TableSchema.UNKNOWN; } catch (final ResourceNotFoundException notFound) { return TableSchema.NOT_FOUND; @@ -87,8 +88,7 @@ static void createNewSchemaTable(final DynamoDbClient client, final ComponentLog return; } if (tableSchema == TableSchema.LEGACY || tableSchema == TableSchema.UNKNOWN) { - throw new ProcessException( - "Checkpoint table [%s] exists but does not match expected schema".formatted(tableName)); + throw new ProcessException("Checkpoint table [%s] exists but does not match expected schema".formatted(tableName)); } logger.info("Creating DynamoDB 
checkpoint table [{}]", tableName); @@ -103,6 +103,7 @@ static void createNewSchemaTable(final DynamoDbClient client, final ComponentLog AttributeDefinition.builder().attributeName("shardId").attributeType(ScalarAttributeType.S).build()) .billingMode(BillingMode.PAY_PER_REQUEST) .build(); + client.createTable(request); } catch (final ResourceInUseException alreadyCreating) { logger.info("DynamoDB checkpoint table [{}] is already being created by another node", tableName); @@ -125,8 +126,8 @@ static void waitForTableActive(final DynamoDbClient client, final ComponentLog l throw new ProcessException("Interrupted while waiting for DynamoDB table to become ACTIVE", e); } } - throw new ProcessException("DynamoDB checkpoint table [%s] did not become ACTIVE within %d seconds" - .formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + + throw new ProcessException("DynamoDB checkpoint table [%s] did not become ACTIVE within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); } static void deleteTable(final DynamoDbClient client, final ComponentLog logger, final String tableName) { @@ -155,18 +156,13 @@ static void waitForTableDeleted(final DynamoDbClient client, final ComponentLog throw new ProcessException("Interrupted while waiting for DynamoDB table deletion", e); } } - throw new ProcessException( - "DynamoDB table [%s] was not deleted within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + + throw new ProcessException("DynamoDB table [%s] was not deleted within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); } static void copyCheckpointItems(final DynamoDbClient client, final ComponentLog logger, final String sourceTableName, final String destTableName) { logger.info("Copying checkpoint items from [{}] to [{}]", sourceTableName, destTableName); - final TableSchema destinationSchema = getTableSchema(client, destTableName); - if (destinationSchema == TableSchema.NOT_FOUND || destinationSchema == TableSchema.UNKNOWN) { - throw new 
ProcessException("Cannot copy checkpoint items to [%s]: destination schema is %s" - .formatted(destTableName, destinationSchema)); - } Map exclusiveStartKey = null; int copied = 0; @@ -186,16 +182,9 @@ static void copyCheckpointItems(final DynamoDbClient client, final ComponentLog } } - final Map destinationItem = convertItemForDestinationSchema(item, destinationSchema); - if (destinationItem == null) { - logger.debug("Skipping checkpoint item during copy because it cannot be converted for {} schema: keys={}", - destinationSchema, item.keySet()); - continue; - } - client.putItem(PutItemRequest.builder() .tableName(destTableName) - .item(destinationItem) + .item(item) .build()); copied++; } @@ -206,44 +195,6 @@ static void copyCheckpointItems(final DynamoDbClient client, final ComponentLog logger.info("Copied {} checkpoint item(s) from [{}] to [{}]", copied, sourceTableName, destTableName); } - private static Map convertItemForDestinationSchema(final Map item, - final TableSchema destinationSchema) { - return switch (destinationSchema) { - case NEW -> item; - case LEGACY -> convertToLegacyItem(item); - case NOT_FOUND, UNKNOWN -> null; - }; - } - - private static Map convertToLegacyItem(final Map item) { - if (item.containsKey("leaseKey")) { - return item; - } - - final AttributeValue streamName = item.get("streamName"); - final AttributeValue shardId = item.get("shardId"); - if (streamName == null || shardId == null) { - return null; - } - - final String shardIdValue = shardId.s(); - if (shardIdValue == null || shardIdValue.isEmpty() - || shardIdValue.startsWith(NODE_HEARTBEAT_PREFIX) - || MIGRATION_MARKER_SHARD_ID.equals(shardIdValue)) { - return null; - } - - final AttributeValue sequenceNumber = item.get("sequenceNumber"); - final String leaseKey = streamName.s() + ":" + shardIdValue; - if (sequenceNumber != null && sequenceNumber.s() != null) { - return Map.of( - "leaseKey", AttributeValue.builder().s(leaseKey).build(), - "checkpoint", 
AttributeValue.builder().s(sequenceNumber.s()).build()); - } - - return Map.of("leaseKey", AttributeValue.builder().s(leaseKey).build()); - } - private static boolean hasKey(final List keySchema, final String keyName, final KeyType keyType) { for (final KeySchemaElement element : keySchema) { if (keyName.equals(element.attributeName()) && keyType == element.keyType()) { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index 61c223877bd2..e920687edd99 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -24,15 +24,18 @@ import org.apache.nifi.annotation.configuration.DefaultSettings; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnRemoved; import org.apache.nifi.annotation.lifecycle.OnScheduled; import org.apache.nifi.annotation.lifecycle.OnStopped; import org.apache.nifi.components.DescribedValue; import org.apache.nifi.components.PropertyDescriptor; import org.apache.nifi.components.Validator; +import org.apache.nifi.controller.NodeTypeProvider; import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.flowfile.attributes.CoreAttributes; import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.migration.PropertyConfiguration; -import org.apache.nifi.migration.RelationshipConfiguration; +import org.apache.nifi.migration.ProxyServiceMigration; import org.apache.nifi.processor.AbstractProcessor; import org.apache.nifi.processor.DataUnit; import 
org.apache.nifi.processor.ProcessContext; @@ -58,6 +61,7 @@ import org.apache.nifi.serialization.record.RecordSchema; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.Protocol; import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.apache.ApacheHttpClient; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @@ -69,6 +73,7 @@ import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder; import software.amazon.awssdk.services.kinesis.KinesisClient; import software.amazon.awssdk.services.kinesis.KinesisClientBuilder; +import software.amazon.awssdk.services.kinesis.model.DeregisterStreamConsumerRequest; import software.amazon.awssdk.services.kinesis.model.Shard; import java.io.ByteArrayInputStream; @@ -90,6 +95,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import static org.apache.nifi.processors.aws.region.RegionUtil.CUSTOM_REGION; import static org.apache.nifi.processors.aws.region.RegionUtil.REGION; @@ -120,7 +126,7 @@ @WritesAttribute(attribute = "aws.kinesis.last.subsequence.number", description = "Sub-Sequence Number of the last Kinesis Record in the FlowFile"), @WritesAttribute(attribute = "aws.kinesis.approximate.arrival.timestamp.ms", - description = "Approximate arrival timestamp of the last Kinesis Record in the FlowFile"), + description = "Approximate arrival timestamp associated with the Kinesis Record or records in the FlowFile"), @WritesAttribute(attribute = "mime.type", description = "Sets the mime.type attribute to the MIME Type specified by the Record Writer (if configured)"), @WritesAttribute(attribute = "record.count", @@ -136,8 +142,16 @@ @SystemResourceConsideration(resource = SystemResource.NETWORK, description = """ The processor will continually poll for new Records.""") 
@SystemResourceConsideration(resource = SystemResource.MEMORY, description = """ - ConsumeKinesis buffers Kinesis Records in memory until they can be processed. \ - The maximum size of the buffer is controlled by the 'Max Batch Size' property.""") + Records are fetched from Kinesis in the background and buffered in memory until they \ + can be written to FlowFiles. Up to 200 fetch responses may be buffered per shard for \ + both Shared Throughput and Enhanced Fan-Out. Each Shared Throughput response can \ + contain up to the value of 'Max Records Per Request' records (default 100) and up \ + to 10 MB, so the theoretical maximum is 20,000 records or approximately 2 GB per \ + shard at default settings. Each Enhanced Fan-Out push event can hold up to roughly \ + 2 MB, for a theoretical maximum of approximately 400 MB per shard. In practice the \ + buffer is typically much smaller because fetch threads block when the queue is \ + full and most responses are well below the maximum size. + """) public class ConsumeKinesis extends AbstractProcessor { static final String ATTR_STREAM_NAME = "aws.kinesis.stream.name"; @@ -149,6 +163,7 @@ public class ConsumeKinesis extends AbstractProcessor { static final String ATTR_PARTITION_KEY = "aws.kinesis.partition.key"; static final String ATTR_ARRIVAL_TIMESTAMP = "aws.kinesis.approximate.arrival.timestamp.ms"; static final String ATTR_MILLIS_BEHIND = "kinesis.millis.behind"; + static final String ATTR_RECORD_ERROR_MESSAGE = "record.error.message"; private static final long QUEUE_POLL_TIMEOUT_MILLIS = 100; private static final Duration API_CALL_TIMEOUT = Duration.ofSeconds(30); @@ -165,7 +180,10 @@ public class ConsumeKinesis extends AbstractProcessor { static final PropertyDescriptor APPLICATION_NAME = new PropertyDescriptor.Builder() .name("Application Name") - .description("The name of the Kinesis application. Used as the DynamoDB table name for checkpoint storage.") + .description(""" + The name of the Kinesis application. 
Used as the DynamoDB table name for checkpoint storage. \ + When Consumer Type is Enhanced Fan-Out, this value is also used as the registered consumer \ + name. This value should be unique for the stream.""") .required(true) .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) .build(); @@ -248,8 +266,9 @@ Specifies the string (interpreted as UTF-8) used to separate multiple Kinesis me .name("Max Records Per Request") .description("The maximum number of records to retrieve per GetRecords call. Maximum is 10,000.") .required(true) - .defaultValue("1000") + .defaultValue("100") .addValidator(StandardValidators.createLongValidator(1, 10000, true)) + .dependsOn(CONSUMER_TYPE, ConsumerType.SHARED_THROUGHPUT) .build(); static final PropertyDescriptor MAX_BATCH_DURATION = new PropertyDescriptor.Builder() @@ -330,6 +349,8 @@ Specifies the string (interpreted as UTF-8) used to separate multiple Kinesis me private volatile long maxBatchBytes; private volatile ProcessingStrategy processingStrategy = ProcessingStrategy.valueOf(PROCESSING_STRATEGY.getDefaultValue()); + private volatile String efoConsumerArn; + private final AtomicLong shardRoundRobinCounter = new AtomicLong(); @Override protected List getSupportedPropertyDescriptors() { @@ -353,16 +374,12 @@ public void onPropertyModified(final PropertyDescriptor descriptor, final String @Override public void migrateProperties(final PropertyConfiguration config) { + ProxyServiceMigration.renameProxyConfigurationServiceProperty(config); config.renameProperty("Max Bytes to Buffer", "Max Batch Size"); config.removeProperty("Checkpoint Interval"); config.removeProperty("Metrics Publishing"); } - @Override - public void migrateRelationships(final RelationshipConfiguration config) { - config.renameRelationship("parse failure", "parse.failure"); - } - @OnScheduled public void onScheduled(final ProcessContext context) { final Region region = RegionUtil.getRegion(context); @@ -403,15 +420,15 @@ public void onScheduled(final 
ProcessContext context) { final String checkpointTableName = context.getProperty(APPLICATION_NAME).getValue(); streamName = context.getProperty(STREAM_NAME).getValue(); - maxRecordsPerRequest = context.getProperty(MAX_RECORDS_PER_REQUEST).asInteger(); initialStreamPosition = context.getProperty(INITIAL_STREAM_POSITION).getValue(); maxBatchNanos = context.getProperty(MAX_BATCH_DURATION).asTimePeriod(TimeUnit.NANOSECONDS); maxBatchBytes = context.getProperty(MAX_BATCH_SIZE).asDataSize(DataUnit.B).longValue(); + final boolean efoMode = ConsumerType.ENHANCED_FAN_OUT.equals(context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class)); + maxRecordsPerRequest = efoMode ? 0 : context.getProperty(MAX_RECORDS_PER_REQUEST).asInteger(); + shardManager = createShardManager(kinesisClient, dynamoDbClient, getLogger(), checkpointTableName, streamName); shardManager.ensureCheckpointTableExists(); - - final boolean efoMode = ConsumerType.ENHANCED_FAN_OUT.equals(context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class)); consumerClient = createConsumerClient(kinesisClient, getLogger(), efoMode); final Instant timestampForPosition = resolveTimestampPosition(context); @@ -425,6 +442,7 @@ public void onScheduled(final ProcessContext context) { if (efoMode) { final NettyNioAsyncHttpClient.Builder nettyBuilder = NettyNioAsyncHttpClient.builder() + .protocol(Protocol.HTTP2) .maxConcurrency(500) .connectionAcquisitionTimeout(Duration.ofSeconds(60)); @@ -498,6 +516,9 @@ public void onStopped() { shardManager = null; } + if (consumerClient instanceof EfoKinesisClient efo) { + efoConsumerArn = efo.getConsumerArn(); + } if (consumerClient != null) { consumerClient.close(); consumerClient = null; @@ -524,9 +545,41 @@ public void onStopped() { dynamoHttpClient = null; } + @OnRemoved + public void onRemoved(final ProcessContext context) { + final String arn = efoConsumerArn; + efoConsumerArn = null; + if (arn == null) { + return; + } + + final Region region = 
RegionUtil.getRegion(context); + final AwsCredentialsProvider credentialsProvider = context.getProperty(AWS_CREDENTIALS_PROVIDER_SERVICE) + .asControllerService(AwsCredentialsProviderService.class).getAwsCredentialsProvider(); + final String endpointOverride = context.getProperty(ENDPOINT_OVERRIDE).getValue(); + + final KinesisClientBuilder builder = KinesisClient.builder() + .region(region) + .credentialsProvider(credentialsProvider); + + if (endpointOverride != null && !endpointOverride.isEmpty()) { + builder.endpointOverride(URI.create(endpointOverride)); + } + + try (final KinesisClient tempClient = builder.build()) { + tempClient.deregisterStreamConsumer(DeregisterStreamConsumerRequest.builder() + .consumerARN(arn) + .build()); + getLogger().info("Deregistered EFO consumer [{}]", arn); + } catch (final Exception e) { + getLogger().warn("Failed to deregister EFO consumer [{}]; manual cleanup may be required", arn, e); + } + } + @Override public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { - final int clusterMemberCount = Math.max(1, getNodeTypeProvider().getClusterMembers().size()); + final NodeTypeProvider nodeTypeProvider = getNodeTypeProvider(); + final int clusterMemberCount = nodeTypeProvider.isClustered() ? 
0 : Math.max(1, nodeTypeProvider.getClusterMembers().size()); shardManager.refreshLeasesIfNecessary(clusterMemberCount); final List ownedShards = shardManager.getOwnedShards(); @@ -545,49 +598,53 @@ public void onTrigger(final ProcessContext context, final ProcessSession session consumerClient.logDiagnostics(ownedShards.size(), shardManager.getCachedShardCount()); final Set claimedShards = new HashSet<>(); - final List consumed = consumeRecords(claimedShards); - final List accepted = discardRelinquishedResults(consumed, claimedShards); + try { + final List consumed = consumeRecords(claimedShards); + final List accepted = discardRelinquishedResults(consumed, claimedShards); - if (accepted.isEmpty()) { - consumerClient.releaseShards(claimedShards); - context.yield(); - return; - } + if (accepted.isEmpty()) { + consumerClient.releaseShards(claimedShards); + context.yield(); + return; + } - final PartitionedBatch batch = partitionByShardAndCheckpoint(accepted); + final PartitionedBatch batch = partitionByShardAndCheckpoint(accepted); - final WriteResult output; - try { - output = writeResults(session, context, batch.resultsByShard()); - } catch (final Exception e) { - handleWriteFailure(e, accepted, claimedShards, context); - return; - } + final WriteResult output; + try { + output = writeResults(session, context, batch.resultsByShard()); + } catch (final Exception e) { + handleWriteFailure(e, accepted, claimedShards, context); + return; + } - if (output.produced().isEmpty() && output.parseFailures().isEmpty()) { - consumerClient.releaseShards(claimedShards); - context.yield(); - return; - } + if (output.produced().isEmpty() && output.parseFailures().isEmpty()) { + consumerClient.releaseShards(claimedShards); + context.yield(); + return; + } - session.transfer(output.produced(), REL_SUCCESS); - if (!output.parseFailures().isEmpty()) { - session.transfer(output.parseFailures(), REL_PARSE_FAILURE); - session.adjustCounter("Records Parse Failure", 
output.parseFailures().size(), false); - } - session.adjustCounter("Records Consumed", output.totalRecordCount(), false); - final long dedupEvents = consumerClient.drainDeduplicatedEventCount(); - if (dedupEvents > 0) { - session.adjustCounter("EFO Deduplicated Events", dedupEvents, false); - } + session.transfer(output.produced(), REL_SUCCESS); + if (!output.parseFailures().isEmpty()) { + session.transfer(output.parseFailures(), REL_PARSE_FAILURE); + session.adjustCounter("Records Parse Failure", output.parseFailures().size(), false); + } + session.adjustCounter("Records Consumed", output.totalRecordCount(), false); + final long dedupEvents = consumerClient.drainDeduplicatedEventCount(); + if (dedupEvents > 0) { + session.adjustCounter("EFO Deduplicated Events", dedupEvents, false); + } - session.commitAsync( + session.commitAsync( () -> { try { shardManager.writeCheckpoints(batch.checkpoints()); - consumerClient.acknowledgeResults(accepted); } finally { - consumerClient.releaseShards(claimedShards); + try { + consumerClient.acknowledgeResults(accepted); + } finally { + consumerClient.releaseShards(claimedShards); + } } }, failure -> { @@ -598,6 +655,10 @@ public void onTrigger(final ProcessContext context, final ProcessSession session consumerClient.releaseShards(claimedShards); } }); + } catch (final Exception e) { + consumerClient.releaseShards(claimedShards); + throw e; + } } private List discardRelinquishedResults(final List consumedResults, final Set claimedShards) { @@ -614,8 +675,8 @@ private List discardRelinquishedResults(final List new ArrayList<>()).add(result); } for (final List shardResults : resultsByShard.values()) { - shardResults.sort(Comparator.comparing(r -> new BigInteger(r.firstSequenceNumber()))); + shardResults.sort(Comparator.comparing(ShardFetchResult::firstSequenceNumber)); } final Map checkpoints = new HashMap<>(); - for (final ShardFetchResult result : accepted) { - final ShardCheckpoint incoming = new 
ShardCheckpoint(result.lastSequenceNumber(), result.lastSubSequenceNumber()); - checkpoints.merge(result.shardId(), incoming, ShardCheckpoint::max); + for (final List shardResults : resultsByShard.values()) { + final ShardFetchResult last = shardResults.getLast(); + checkpoints.put(last.shardId(), new ShardCheckpoint(last.lastSequenceNumber(), last.lastSubSequenceNumber())); } return new PartitionedBatch(resultsByShard, checkpoints); @@ -647,39 +708,41 @@ private List consumeRecords(final Set claimedShards) { long estimatedBytes = 0; while (System.nanoTime() < startNanos + maxBatchNanos && estimatedBytes < maxBatchBytes) { - boolean foundAny = false; + final List readyShards = consumerClient.getShardIdsWithResults(); + if (readyShards.isEmpty()) { + if (!consumerClient.hasPendingFetches()) { + break; + } - for (final String shardId : consumerClient.getShardIdsWithResults()) { - if (estimatedBytes >= maxBatchBytes) { + try { + consumerClient.awaitResults(QUEUE_POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); break; } + continue; + } + + boolean foundAny = false; + final int shardCount = readyShards.size(); + final int startOffset = (int) (shardRoundRobinCounter.getAndIncrement() % shardCount); + for (int i = 0; i < shardCount && estimatedBytes < maxBatchBytes; i++) { + final String shardId = readyShards.get((startOffset + i) % shardCount); if (!claimedShards.contains(shardId) && !consumerClient.claimShard(shardId)) { continue; } claimedShards.add(shardId); - ShardFetchResult result; - while ((result = consumerClient.pollShardResult(shardId)) != null) { + final ShardFetchResult result = consumerClient.pollShardResult(shardId); + if (result != null) { results.add(result); estimatedBytes += estimateResultBytes(result); foundAny = true; - if (estimatedBytes >= maxBatchBytes) { - break; - } } } if (!foundAny) { - if (!consumerClient.hasPendingFetches()) { - break; - } - - try { - 
consumerClient.awaitResults(QUEUE_POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - break; - } + break; } } @@ -902,6 +965,8 @@ private void writeRecordBatch(final ProcessSession session, final RecordReaderFa final String streamName, final BatchAccumulator batch, final List output) { FlowFile flowFile = session.create(); + final Map attributes = new HashMap<>(); + try { flowFile = session.write(flowFile, new OutputStreamCallback() { @Override @@ -926,17 +991,17 @@ public void process(final OutputStream out) throws IOException { int recordIndex = 0; org.apache.nifi.serialization.record.Record nifiRecord; while ((nifiRecord = reader.nextRecord()) != null) { - if (recordIndex < records.size()) { - final DeaggregatedRecord record = records.get(recordIndex); - nifiRecord = decorateRecord(nifiRecord, record, record.shardId(), streamName, outputStrategy, writeSchema); - recordIndex++; - } + final DeaggregatedRecord record = records.get(recordIndex++); + nifiRecord = decorateRecord(nifiRecord, record, record.shardId(), streamName, outputStrategy, writeSchema); writer.write(nifiRecord); batch.incrementRecordCount(); } - writer.finishRecordSet(); + final org.apache.nifi.serialization.WriteResult writeResult = writer.finishRecordSet(); + attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), writer.getMimeType()); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); } } catch (final MalformedRecordException | SchemaNotFoundException e) { throw new IOException(e); @@ -944,7 +1009,8 @@ public void process(final OutputStream out) throws IOException { } }); - flowFile = session.putAllAttributes(flowFile, createFlowFileAttributes(streamName, batch)); + attributes.putAll(createFlowFileAttributes(streamName, batch)); + flowFile = session.putAllAttributes(flowFile, attributes); session.getProvenanceReporter().receive(flowFile, 
buildTransitUri(streamName, batch.getLastShardId())); output.add(flowFile); } catch (final Exception e) { @@ -984,19 +1050,18 @@ private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session final List output = new ArrayList<>(); final List parseFailureOutput = new ArrayList<>(); - final List unparseable = new ArrayList<>(); + final List unparseable = new ArrayList<>(); FlowFile currentFlowFile = null; OutputStream currentOut = null; RecordSetWriter currentWriter = null; RecordSchema currentReadSchema = null; RecordSchema currentWriteSchema = null; - int currentRecordCount = 0; try { for (final DeaggregatedRecord record : records) { if (record.data().length == 0) { - unparseable.add(record); + unparseable.add(new ParseFailureRecord(record, "Record content is empty")); continue; } @@ -1011,27 +1076,31 @@ private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session parsedRecords.add(nifiRecord); } } catch (final MalformedRecordException | SchemaNotFoundException | IOException e) { - getLogger().debug("Kinesis record seq {} classified as unparseable: {}", - record.sequenceNumber(), e.getMessage()); - unparseable.add(record); + getLogger().debug("Kinesis record seq {} classified as unparseable: {}", record.sequenceNumber(), e.getMessage()); + unparseable.add(new ParseFailureRecord(record, e.toString())); continue; } finally { closeQuietly(reader); } if (parsedRecords.isEmpty()) { - unparseable.add(record); + unparseable.add(new ParseFailureRecord(record, "Record content produced no parsed records")); continue; } if (currentWriter == null || !readSchema.equals(currentReadSchema)) { if (currentWriter != null) { - currentWriter.finishRecordSet(); + final org.apache.nifi.serialization.WriteResult writeResult = currentWriter.finishRecordSet(); + currentWriter.close(); currentOut.close(); - final Map attrs = createFlowFileAttributes(streamName, batch); - attrs.put("record.count", String.valueOf(currentRecordCount)); - currentFlowFile = 
session.putAllAttributes(currentFlowFile, attrs); + + final Map attributes = createFlowFileAttributes(streamName, batch); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), currentWriter.getMimeType()); + currentFlowFile = session.putAllAttributes(currentFlowFile, attributes); + session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); output.add(currentFlowFile); currentFlowFile = null; @@ -1043,26 +1112,30 @@ private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session currentOut = session.write(currentFlowFile); currentWriter = writerFactory.createWriter(getLogger(), currentWriteSchema, currentOut, Map.of()); currentWriter.beginRecordSet(); - currentRecordCount = 0; + batch.resetRanges(); } + batch.updateRecordRange(record); + for (final org.apache.nifi.serialization.record.Record parsed : parsedRecords) { - // TODO: We need to use decorateRecord above also. 
final org.apache.nifi.serialization.record.Record decorated = decorateRecord(parsed, record, record.shardId(), streamName, outputStrategy, currentWriteSchema); currentWriter.write(decorated); - currentRecordCount++; batch.incrementRecordCount(); } } if (currentWriter != null) { - currentWriter.finishRecordSet(); + final org.apache.nifi.serialization.WriteResult writeResult = currentWriter.finishRecordSet(); currentWriter.close(); currentOut.close(); - final Map attrs = createFlowFileAttributes(streamName, batch); - attrs.put("record.count", String.valueOf(currentRecordCount)); - currentFlowFile = session.putAllAttributes(currentFlowFile, attrs); + + final Map attributes = createFlowFileAttributes(streamName, batch); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), currentWriter.getMimeType()); + + currentFlowFile = session.putAllAttributes(currentFlowFile, attributes); session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); output.add(currentFlowFile); currentFlowFile = null; @@ -1135,10 +1208,11 @@ private static org.apache.nifi.serialization.record.Record decorateRecord( }; } - private void writeParseFailures(final ProcessSession session, final List unparseable, + private void writeParseFailures(final ProcessSession session, final List unparseable, final String streamName, final BatchAccumulator batch, final List parseFailureOutput) { - for (final DeaggregatedRecord record : unparseable) { + for (final ParseFailureRecord parseFailureRecord : unparseable) { + final DeaggregatedRecord record = parseFailureRecord.record(); FlowFile flowFile = session.create(); try { final byte[] rawBytes = record.data(); @@ -1155,6 +1229,7 @@ private void writeParseFailures(final ProcessSession session, final List createFlowFileAttributes(final String streamN if (batch.getLastPartitionKey() != null) { 
attributes.put(ATTR_PARTITION_KEY, batch.getLastPartitionKey()); } - if (batch.getEarliestArrivalTimestamp() != null) { - attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(batch.getEarliestArrivalTimestamp().toEpochMilli())); + if (batch.getLatestArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(batch.getLatestArrivalTimestamp().toEpochMilli())); } return attributes; } - private static void closeQuietly(final AutoCloseable closeable) { + private void closeQuietly(final AutoCloseable closeable) { if (closeable != null) { try { closeable.close(); - } catch (final Exception ignored) { + } catch (final Exception e) { + getLogger().warn("Failed to close Record Writer", e); } } } private static String buildTransitUri(final String streamName, final String shardId) { - if (shardId == null || shardId.isEmpty()) { - return "kinesis://" + streamName; - } return "kinesis://" + streamName + "/" + shardId; } @@ -1241,6 +1314,9 @@ protected KinesisConsumerClient createConsumerClient(final KinesisClient kinesis private record RecordBatchResult(List output, List parseFailures) { } + private record ParseFailureRecord(DeaggregatedRecord record, String reason) { + } + private static final class KinesisRecordInputStream extends InputStream { private final List chunks; private int chunkIndex; @@ -1339,12 +1415,12 @@ private static final class BatchAccumulator { private long bytesConsumed; private long recordCount; private long maxMillisBehind = -1; - private String minSequenceNumber; - private String maxSequenceNumber; + private BigInteger minSequenceNumber; + private BigInteger maxSequenceNumber; private long minSubSequenceNumber = Long.MAX_VALUE; private long maxSubSequenceNumber = Long.MIN_VALUE; private String lastPartitionKey; - private Instant earliestArrivalTimestamp; + private Instant latestArrivalTimestamp; private String lastShardId; long getBytesConsumed() { @@ -1360,11 +1436,11 @@ long getMaxMillisBehind() { } String 
getMinSequenceNumber() { - return minSequenceNumber; + return minSequenceNumber != null ? minSequenceNumber.toString() : null; } String getMaxSequenceNumber() { - return maxSequenceNumber; + return maxSequenceNumber != null ? maxSequenceNumber.toString() : null; } long getMinSubSequenceNumber() { @@ -1379,8 +1455,8 @@ String getLastPartitionKey() { return lastPartitionKey; } - Instant getEarliestArrivalTimestamp() { - return earliestArrivalTimestamp; + Instant getLatestArrivalTimestamp() { + return latestArrivalTimestamp; } String getLastShardId() { @@ -1408,17 +1484,18 @@ void updateMillisBehind(final long millisBehindLatest) { } void updateSequenceRange(final ShardFetchResult result) { - final String firstSeq = result.firstSequenceNumber(); - final String lastSeq = result.lastSequenceNumber(); - if (minSequenceNumber == null || new BigInteger(firstSeq).compareTo(new BigInteger(minSequenceNumber)) < 0) { + final BigInteger firstSeq = result.firstSequenceNumber(); + final BigInteger lastSeq = result.lastSequenceNumber(); + if (minSequenceNumber == null || firstSeq.compareTo(minSequenceNumber) < 0) { minSequenceNumber = firstSeq; } - if (maxSequenceNumber == null || new BigInteger(lastSeq).compareTo(new BigInteger(maxSequenceNumber)) > 0) { + if (maxSequenceNumber == null || lastSeq.compareTo(maxSequenceNumber) > 0) { maxSequenceNumber = lastSeq; } } void updateRecordRange(final DeaggregatedRecord record) { + updateSequenceFromRecord(record); final long subSeq = record.subSequenceNumber(); if (subSeq < minSubSequenceNumber) { minSubSequenceNumber = subSeq; @@ -1428,10 +1505,29 @@ void updateRecordRange(final DeaggregatedRecord record) { } lastPartitionKey = record.partitionKey(); final Instant arrival = record.approximateArrivalTimestamp(); - if (arrival != null && (earliestArrivalTimestamp == null || arrival.isBefore(earliestArrivalTimestamp))) { - earliestArrivalTimestamp = arrival; + if (arrival != null && (latestArrivalTimestamp == null || 
arrival.isAfter(latestArrivalTimestamp))) { + latestArrivalTimestamp = arrival; } } + + void updateSequenceFromRecord(final DeaggregatedRecord record) { + final BigInteger seqNum = new BigInteger(record.sequenceNumber()); + if (minSequenceNumber == null || seqNum.compareTo(minSequenceNumber) < 0) { + minSequenceNumber = seqNum; + } + if (maxSequenceNumber == null || seqNum.compareTo(maxSequenceNumber) > 0) { + maxSequenceNumber = seqNum; + } + } + + void resetRanges() { + minSequenceNumber = null; + maxSequenceNumber = null; + minSubSequenceNumber = Long.MAX_VALUE; + maxSubSequenceNumber = Long.MIN_VALUE; + lastPartitionKey = null; + latestArrivalTimestamp = null; + } } enum ConsumerType implements DescribedValue { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java index 32b6132662e4..7c9d2ab3b97b 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java @@ -25,7 +25,6 @@ import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisClient; import software.amazon.awssdk.services.kinesis.model.ConsumerStatus; -import software.amazon.awssdk.services.kinesis.model.DeregisterStreamConsumerRequest; import software.amazon.awssdk.services.kinesis.model.DescribeStreamConsumerRequest; import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; @@ -46,18 +45,18 @@ import java.io.IOException; import java.math.BigInteger; import java.time.Instant; -import java.util.HashMap; import 
java.util.List; import java.util.Map; +import java.util.Queue; import java.util.Set; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; /** @@ -70,10 +69,10 @@ final class EfoKinesisClient extends KinesisConsumerClient { private static final long SUBSCRIBE_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(5); private static final long CONSUMER_REGISTRATION_POLL_MILLIS = 1_000; private static final int CONSUMER_REGISTRATION_MAX_ATTEMPTS = 60; - private static final int MAX_SUBSCRIPTIONS_PER_TRIGGER = 10; private static final int MAX_QUEUED_RESULTS = 200; private final Map shardConsumers = new ConcurrentHashMap<>(); + private final Queue pausedConsumers = new ConcurrentLinkedQueue<>(); private volatile KinesisAsyncClient kinesisAsyncClient; private volatile String consumerArn; private volatile Instant timestampForInitialPosition; @@ -92,28 +91,29 @@ void initialize(final KinesisAsyncClient asyncClient, final String streamName, f registerEfoConsumer(streamName, consumerName); } + void initializeForTest(final KinesisAsyncClient asyncClient, final String theConsumerArn) { + this.kinesisAsyncClient = asyncClient; + this.consumerArn = theConsumerArn; + } + + ShardConsumer getShardConsumer(final String shardId) { + return shardConsumers.get(shardId); + } + @Override void startFetches(final List shards, final String streamName, final int batchSize, final String initialStreamPosition, final KinesisShardManager shardManager) { - if (totalQueuedResults() >= MAX_QUEUED_RESULTS) { - return; - } - - int subscriptionsCreated = 0; final long now = System.nanoTime(); for (final 
Shard shard : shards) { - if (subscriptionsCreated >= MAX_SUBSCRIPTIONS_PER_TRIGGER) { - break; - } - final String shardId = shard.shardId(); final ShardConsumer existing = shardConsumers.get(shardId); if (existing == null) { - final String lastSeq = shardManager.readCheckpoint(shardId); + final String checkpoint = shardManager.readCheckpoint(shardId); + final BigInteger lastSeq = checkpoint != null ? new BigInteger(checkpoint) : null; final StartingPosition startingPosition = buildStartingPosition(lastSeq, initialStreamPosition); logger.info("Creating EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), lastSeq); - final ShardConsumer sc = new ShardConsumer(shardId, EfoKinesisClient.this::enqueueResult, logger); + final ShardConsumer sc = new ShardConsumer(shardId, EfoKinesisClient.this::enqueueResult, pausedConsumers, logger); final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, sc); if (prior == null) { try { @@ -122,7 +122,6 @@ void startFetches(final List shards, final String streamName, final int b shardConsumers.remove(shardId, sc); throw e; } - subscriptionsCreated++; } } else if (existing.isSubscriptionExpired()) { final long lastAttempt = existing.getLastSubscribeAttemptNanos(); @@ -130,16 +129,17 @@ void startFetches(final List shards, final String streamName, final int b continue; } - final String resumeSeq = maxSequenceNumber( - existing.getLastQueuedSequenceNumber(), - existing.getLastAcknowledgedSequenceNumber(), - shardManager.readCheckpoint(shardId)); + final String checkpoint = shardManager.readCheckpoint(shardId); + final BigInteger checkpointSeq = checkpoint == null ? 
null : new BigInteger(checkpoint); + final BigInteger lastQueued = existing.getLastQueuedSequenceNumber(); + final BigInteger resumeSeq = maxSequenceNumber(lastQueued, checkpointSeq); final StartingPosition startingPosition = buildStartingPosition(resumeSeq, initialStreamPosition); logger.debug("Renewing expired EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), resumeSeq); existing.subscribe(kinesisAsyncClient, consumerArn, startingPosition); - subscriptionsCreated++; } } + + resumePausedConsumers(); } @Override @@ -158,25 +158,25 @@ long drainDeduplicatedEventCount() { @Override void acknowledgeResults(final List results) { - final Map consumersToRequest = new HashMap<>(); + resumePausedConsumers(); + } - for (final ShardFetchResult result : results) { - final ShardConsumer consumer = shardConsumers.get(result.shardId()); - if (consumer != null) { - consumer.acknowledgeSequenceNumber(result.lastSequenceNumber()); - consumer.markAcknowledged(); - consumersToRequest.putIfAbsent(result.shardId(), consumer); + private void resumePausedConsumers() { + int available = MAX_QUEUED_RESULTS - totalQueuedResults(); + int resumed = 0; + ShardConsumer consumer; + while (resumed < available && (consumer = pausedConsumers.poll()) != null) { + if (consumer.requestNextIfReady()) { + resumed++; } } - - for (final ShardConsumer consumer : consumersToRequest.values()) { - consumer.requestNext(); - } } @Override void rollbackResults(final List results) { for (final ShardFetchResult result : results) { + drainShardQueue(result.shardId()); + final ShardConsumer sc = shardConsumers.remove(result.shardId()); if (sc != null) { sc.cancel(); @@ -229,8 +229,6 @@ void close() { } shardConsumers.clear(); - deregisterEfoConsumer(); - if (kinesisAsyncClient != null) { kinesisAsyncClient.close(); kinesisAsyncClient = null; @@ -239,21 +237,8 @@ void close() { super.close(); } - private void deregisterEfoConsumer() { - final String arn = consumerArn; - 
consumerArn = null; - if (arn == null) { - return; - } - - try { - kinesisClient.deregisterStreamConsumer(DeregisterStreamConsumerRequest.builder() - .consumerARN(arn) - .build()); - logger.info("Deregistered EFO consumer [{}]", arn); - } catch (final Exception e) { - logger.warn("Failed to deregister EFO consumer [{}]; manual cleanup may be required", arn, e); - } + String getConsumerArn() { + return consumerArn; } private void registerEfoConsumer(final String streamName, final String consumerName) { @@ -319,26 +304,21 @@ private void waitForConsumerActive(final String theStreamArn, final String consu throw new ProcessException("EFO consumer [%s] did not become ACTIVE within %d seconds".formatted(consumerName, CONSUMER_REGISTRATION_MAX_ATTEMPTS)); } - private static String maxSequenceNumber(final String... candidates) { - BigInteger max = null; - String maxStr = null; - for (final String candidate : candidates) { - if (candidate != null) { - final BigInteger value = new BigInteger(candidate); - if (max == null || value.compareTo(max) > 0) { - max = value; - maxStr = candidate; - } - } + private static BigInteger maxSequenceNumber(final BigInteger a, final BigInteger b) { + if (a == null) { + return b; } - return maxStr; + if (b == null) { + return a; + } + return a.compareTo(b) >= 0 ? 
a : b; } - private StartingPosition buildStartingPosition(final String sequenceNumber, final String initialStreamPosition) { + private StartingPosition buildStartingPosition(final BigInteger sequenceNumber, final String initialStreamPosition) { if (sequenceNumber != null) { return StartingPosition.builder() .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER) - .sequenceNumber(sequenceNumber) + .sequenceNumber(sequenceNumber.toString()) .build(); } final ShardIteratorType iteratorType = ShardIteratorType.fromValue(initialStreamPosition); @@ -352,21 +332,21 @@ private StartingPosition buildStartingPosition(final String sequenceNumber, fina static final class ShardConsumer { private final String shardId; private final Consumer resultSink; + private final Queue pausedConsumers; private final ComponentLog consumerLogger; private final AtomicBoolean subscribing = new AtomicBoolean(false); + private final AtomicBoolean paused = new AtomicBoolean(false); private final AtomicInteger subscriptionGeneration = new AtomicInteger(); private final AtomicLong deduplicatedEvents = new AtomicLong(); private volatile Subscription subscription; private volatile CompletableFuture subscriptionFuture; private volatile long lastSubscribeAttemptNanos; - private volatile String lastQueuedSequenceNumber; - private volatile String lastOnNextMaxSequence; - private final AtomicReference lastAcknowledgedSequenceNumber = new AtomicReference<>(); - private final AtomicInteger queuedResultCount = new AtomicInteger(); - - ShardConsumer(final String shardId, final Consumer resultSink, final ComponentLog consumerLogger) { + private volatile BigInteger lastQueuedSequenceNumber; + ShardConsumer(final String shardId, final Consumer resultSink, + final Queue pausedConsumers, final ComponentLog consumerLogger) { this.shardId = shardId; this.resultSink = resultSink; + this.pausedConsumers = pausedConsumers; this.consumerLogger = consumerLogger; } @@ -407,6 +387,17 @@ void requestNext() { } } + boolean 
requestNextIfReady() { + if (paused.compareAndSet(true, false)) { + final Subscription sub = subscription; + if (sub != null) { + sub.request(1); + return true; + } + } + return false; + } + boolean isSubscriptionExpired() { final CompletableFuture future = subscriptionFuture; return future == null || future.isDone(); @@ -416,11 +407,7 @@ long getLastSubscribeAttemptNanos() { return lastSubscribeAttemptNanos; } - String getLastAcknowledgedSequenceNumber() { - return lastAcknowledgedSequenceNumber.get(); - } - - String getLastQueuedSequenceNumber() { + BigInteger getLastQueuedSequenceNumber() { return lastQueuedSequenceNumber; } @@ -428,28 +415,6 @@ long drainDeduplicatedEventCount() { return deduplicatedEvents.getAndSet(0); } - void markAcknowledged() { - queuedResultCount.updateAndGet(value -> value > 0 ? value - 1 : 0); - } - - void acknowledgeSequenceNumber(final String sequenceNumber) { - if (sequenceNumber == null) { - return; - } - - final BigInteger incoming = new BigInteger(sequenceNumber); - String existing; - while (true) { - existing = lastAcknowledgedSequenceNumber.get(); - if (existing != null && incoming.compareTo(new BigInteger(existing)) <= 0) { - return; - } - if (lastAcknowledgedSequenceNumber.compareAndSet(existing, sequenceNumber)) { - return; - } - } - } - void cancel() { final CompletableFuture future = subscriptionFuture; if (future != null) { @@ -458,6 +423,37 @@ void cancel() { subscription = null; } + void resetForRenewal() { + subscribing.set(false); + lastSubscribeAttemptNanos = 0; + } + + void setLastQueuedSequenceNumber(final BigInteger seq) { + lastQueuedSequenceNumber = seq; + } + + void setSubscription(final Subscription sub) { + subscription = sub; + } + + Subscription getSubscription() { + return subscription; + } + + int getSubscriptionGeneration() { + return subscriptionGeneration.get(); + } + + boolean isSubscribing() { + return subscribing.get(); + } + + void pause() { + if (paused.compareAndSet(false, true)) { + 
pausedConsumers.add(this); + } + } + private void logSubscriptionError(final Throwable t) { if (isCancellation(t)) { consumerLogger.debug("EFO subscription cancelled for shard {}", shardId); @@ -504,7 +500,7 @@ private static boolean isRetryableStreamDisconnect(final Throwable t) { return cause != null && cause != t && isRetryableStreamDisconnect(cause); } - private void endSubscriptionIfCurrent(final int generation) { + void endSubscriptionIfCurrent(final int generation) { if (subscriptionGeneration.get() == generation) { subscription = null; subscribing.set(false); @@ -512,12 +508,10 @@ private void endSubscriptionIfCurrent(final int generation) { } private List deduplicateRecords(final List records) { - final String prevMax = lastOnNextMaxSequence; - if (prevMax == null) { + final BigInteger threshold = lastQueuedSequenceNumber; + if (threshold == null) { return records; } - - final BigInteger threshold = new BigInteger(prevMax); int firstNewIndex = records.size(); for (int i = 0; i < records.size(); i++) { if (new BigInteger(records.get(i).sequenceNumber()).compareTo(threshold) > 0) { @@ -551,6 +545,7 @@ private class DemandDrivenSubscriber implements Subscriber queue = shardQueues.get(shardId); - return queue != null ? queue.poll() : null; + final ShardFetchResult result = queue != null ? 
queue.poll() : null; + if (result != null) { + onResultPolled(); + } + return result; + } + + protected void onResultPolled() { + } + + int drainShardQueue(final String shardId) { + final Queue queue = shardQueues.get(shardId); + if (queue == null) { + return 0; + } + int drained = 0; + while (queue.poll() != null) { + drained++; + } + return drained; } ShardFetchResult pollAnyResult(final long timeout, final TimeUnit unit) throws InterruptedException { @@ -97,6 +116,7 @@ ShardFetchResult pollAnyResult(final long timeout, final TimeUnit unit) throws I for (final Queue queue : shardQueues.values()) { final ShardFetchResult result = queue.poll(); if (result != null) { + onResultPolled(); return result; } } @@ -110,8 +130,8 @@ ShardFetchResult pollAnyResult(final long timeout, final TimeUnit unit) throws I return null; } - Set getShardIdsWithResults() { - final Set ids = new HashSet<>(); + List getShardIdsWithResults() { + final List ids = new ArrayList<>(); for (final Map.Entry> entry : shardQueues.entrySet()) { if (!entry.getValue().isEmpty()) { ids.add(entry.getKey()); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java index 322389317982..5885de7b25c7 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java @@ -31,7 +31,6 @@ import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; import software.amazon.awssdk.services.kinesis.model.Shard; -import java.math.BigInteger; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; @@ -499,8 +498,6 @@ private int 
countActiveNodes(final long now) { } private void writeCheckpoint(final String shardId, final ShardCheckpoint checkpoint) { - final BigInteger incomingSeq = new BigInteger(checkpoint.sequenceNumber()); - final ShardCheckpoint written = highestWrittenCheckpoints.compute(shardId, (key, existing) -> { if (existing != null && checkpoint.max(existing) == existing) { return existing; @@ -515,7 +512,7 @@ private void writeCheckpoint(final String shardId, final ShardCheckpoint checkpo + " lastUpdateTimestamp = :ts, leaseExpiry = :exp") .conditionExpression("leaseOwner = :owner") .expressionAttributeValues(Map.of( - ":seq", AttributeValue.builder().s(checkpoint.sequenceNumber()).build(), + ":seq", AttributeValue.builder().s(checkpoint.sequenceNumber().toString()).build(), ":subSeq", AttributeValue.builder().n(String.valueOf(checkpoint.subSequenceNumber())).build(), ":ts", AttributeValue.builder().n(String.valueOf(now)).build(), ":exp", AttributeValue.builder().n(String.valueOf(now + leaseDurationMillis)).build(), @@ -532,7 +529,7 @@ private void writeCheckpoint(final String shardId, final ShardCheckpoint checkpo return existing; }); - if (written != null && incomingSeq.compareTo(new BigInteger(written.sequenceNumber())) < 0) { + if (written != null && checkpoint.sequenceNumber().compareTo(written.sequenceNumber()) < 0) { logger.debug("Skipped checkpoint regression for shard {} (highest: {}, attempted: {})", shardId, written.sequenceNumber(), checkpoint.sequenceNumber()); } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java index 0e9918f129c7..c5ca7fe2a447 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java @@ -85,16 +85,16 @@ private static void deaggregateRecord(final String shardId, final Record record, return; } - final byte[] protobufBytes = Arrays.copyOfRange(data, KPL_MAGIC.length, data.length - MD5_DIGEST_LENGTH); - final byte[] trailingMd5 = Arrays.copyOfRange(data, data.length - MD5_DIGEST_LENGTH, data.length); + final int protobufOffset = KPL_MAGIC.length; + final int protobufLength = data.length - KPL_MAGIC.length - MD5_DIGEST_LENGTH; - if (!verifyMd5(protobufBytes, trailingMd5)) { + if (!verifyMd5(data, protobufOffset, protobufLength)) { out.add(passthrough(shardId, record, data)); return; } try { - parseAggregatedRecord(shardId, record, protobufBytes, out); + parseAggregatedRecord(shardId, record, data, protobufOffset, protobufLength, out); } catch (final Exception e) { out.add(passthrough(shardId, record, data)); } @@ -110,22 +110,26 @@ static boolean isAggregated(final byte[] data) { && data[3] == KPL_MAGIC[3]; } - private static boolean verifyMd5(final byte[] protobufBytes, final byte[] expectedMd5) { + private static boolean verifyMd5(final byte[] data, final int protobufOffset, final int protobufLength) { try { - final byte[] computed = MessageDigest.getInstance("MD5").digest(protobufBytes); - return Arrays.equals(computed, expectedMd5); + final MessageDigest md5 = MessageDigest.getInstance("MD5"); + md5.update(data, protobufOffset, protobufLength); + final byte[] computed = md5.digest(); + final int md5Offset = protobufOffset + protobufLength; + return Arrays.equals(computed, 0, MD5_DIGEST_LENGTH, data, md5Offset, md5Offset + MD5_DIGEST_LENGTH); } catch (final NoSuchAlgorithmException e) { return false; } } - private static void parseAggregatedRecord(final String shardId, final Record kinesisRecord, final byte[] protobufBytes, - final List out) throws Exception { + private static void parseAggregatedRecord(final String 
shardId, final Record kinesisRecord, + final byte[] data, final int protobufOffset, final int protobufLength, final List out) throws Exception { + final List partitionKeyTable = new ArrayList<>(); final List subRecordDataList = new ArrayList<>(); final List subRecordPkIndexList = new ArrayList<>(); - final CodedInputStream input = CodedInputStream.newInstance(protobufBytes); + final CodedInputStream input = CodedInputStream.newInstance(data, protobufOffset, protobufLength); while (!input.isAtEnd()) { final int tag = input.readTag(); final int fieldNumber = WireFormat.getTagFieldNumber(tag); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java index 69ed6065584a..d74b76451559 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java @@ -82,8 +82,7 @@ void cleanupLingeringMigration() { if (lingeringMigration == null) { return; } - logger.info("Cleaning up lingering migration table [{}]", lingeringMigration); - CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, lingeringMigration, checkpointTableName); + logger.info("Deleting orphaned migration table [{}]; legacy checkpoint table [{}] retains original data", lingeringMigration, checkpointTableName); CheckpointTableUtils.deleteTable(dynamoDbClient, logger, lingeringMigration); } @@ -332,9 +331,10 @@ private void migrateLegacyCheckpoints(final String sourceTableName, final String .key(Map.of( "streamName", AttributeValue.builder().s(streamName).build(), "shardId", AttributeValue.builder().s(shardId).build())) - .updateExpression("SET 
sequenceNumber = :seq, lastUpdateTimestamp = :ts") + .updateExpression("SET sequenceNumber = :seq, subSequenceNumber = :subSeq, lastUpdateTimestamp = :ts") .expressionAttributeValues(Map.of( ":seq", AttributeValue.builder().s(checkpoint).build(), + ":subSeq", AttributeValue.builder().n("0").build(), ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) .build(); dynamoDbClient.updateItem(request); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java index 6c735c176df8..592f50db9558 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java @@ -19,9 +19,12 @@ import org.apache.nifi.logging.ComponentLog; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; +import software.amazon.awssdk.services.kinesis.model.LimitExceededException; +import software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException; import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; @@ -45,19 +48,21 @@ * *

Concurrency is bounded by a semaphore with {@value #MAX_CONCURRENT_FETCHES} permits so * that at most that many GetRecords HTTP calls are in flight at any moment, preventing - * connection-pool exhaustion. Additionally, when the shared result queue exceeds - * {@value #MAX_QUEUED_RESULTS} entries the fetch loop sleeps until the processor drains results. + * connection-pool exhaustion. A second fair semaphore with {@value #MAX_QUEUED_RESULTS} + * permits ensures that fetch threads block when the result queue is full, with FIFO ordering + * guaranteeing that all shard threads get equal opportunity to enqueue results. */ final class PollingKinesisClient extends KinesisConsumerClient { private static final long DEFAULT_EMPTY_SHARD_BACKOFF_NANOS = TimeUnit.MILLISECONDS.toNanos(500); private static final long DEFAULT_ERROR_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(2); - static final int MAX_QUEUED_RESULTS = 500; - static final int MAX_CONCURRENT_FETCHES = 50; + static final int MAX_QUEUED_RESULTS = 200; + static final int MAX_CONCURRENT_FETCHES = 25; private final ExecutorService fetchExecutor = Executors.newVirtualThreadPerTaskExecutor(); private final Map pollingShardStates = new ConcurrentHashMap<>(); private final Semaphore fetchPermits = new Semaphore(MAX_CONCURRENT_FETCHES, true); + private final Semaphore queuePermits = new Semaphore(MAX_QUEUED_RESULTS, true); private final long emptyShardBackoffNanos; private final long errorBackoffNanos; private volatile Instant timestampForInitialPosition; @@ -166,6 +171,11 @@ cachedShardCount, ownedCount, totalQueuedResults(), pollingShardStates.size(), active, exhausted, stopped, dead, MAX_CONCURRENT_FETCHES - fetchPermits.availablePermits()); } + @Override + protected void onResultPolled() { + queuePermits.release(); + } + @Override void close() { for (final PollingShardState state : pollingShardStates.values()) { @@ -213,6 +223,10 @@ private void runFetchLoop(final PollingShardState state, final String shardId, if 
(state.isResetRequested()) { state.clearReset(); + final int drained = drainShardQueue(shardId); + if (drained > 0) { + queuePermits.release(drained); + } state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager)); } @@ -224,42 +238,52 @@ private void runFetchLoop(final PollingShardState state, final String shardId, } } - if (totalQueuedResults() >= MAX_QUEUED_RESULTS) { - sleepNanos(emptyShardBackoffNanos); - continue; - } - try { - fetchPermits.acquire(); + queuePermits.acquire(); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); return; } - final GetRecordsResponse response; + boolean queuePermitConsumed = false; try { - response = fetchRecords(shardId, state, batchSize); - } finally { - fetchPermits.release(); - } - if (response == null) { - continue; - } + try { + fetchPermits.acquire(); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } - final List records = response.records(); - if (!records.isEmpty()) { - final long millisBehind = response.millisBehindLatest() != null ? response.millisBehindLatest() : -1; - enqueueResult(createFetchResult(shardId, records, millisBehind)); - } + final GetRecordsResponse response; + try { + response = fetchRecords(shardId, state, batchSize); + } finally { + fetchPermits.release(); + } + if (response == null) { + continue; + } - state.setIterator(response.nextShardIterator()); - if (state.getIterator() == null) { - state.markExhausted(); - return; - } + final List records = response.records(); + if (!records.isEmpty()) { + final long millisBehind = response.millisBehindLatest() != null ? 
response.millisBehindLatest() : -1; + enqueueResult(createFetchResult(shardId, records, millisBehind)); + queuePermitConsumed = true; + } + + state.setIterator(response.nextShardIterator()); + if (state.getIterator() == null) { + state.markExhausted(); + return; + } - if (records.isEmpty()) { - sleepNanos(emptyShardBackoffNanos); + if (records.isEmpty()) { + sleepNanos(emptyShardBackoffNanos); + } + } finally { + if (!queuePermitConsumed) { + queuePermits.release(); + } } } catch (final Exception e) { if (!state.isStopped()) { @@ -279,26 +303,25 @@ private GetRecordsResponse fetchRecords(final String shardId, final PollingShard try { return kinesisClient.getRecords(request); - } catch (final software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException - | software.amazon.awssdk.services.kinesis.model.LimitExceededException e) { + } catch (final ProvisionedThroughputExceededException | LimitExceededException e) { logger.debug("GetRecords throttled for shard {}; will retry after backoff", shardId); sleepNanos(errorBackoffNanos); return null; - } catch (final software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException e) { + } catch (final ExpiredIteratorException e) { logger.info("Shard iterator expired for shard {}; will re-acquire", shardId); - state.setIterator(null); + state.requestReset(); sleepNanos(errorBackoffNanos); return null; } catch (final SdkClientException e) { if (!state.isStopped()) { - logger.warn("GetRecords timed out for shard {}; will retry with existing iterator", shardId); + logger.warn("GetRecords failed for shard {}; will retry with existing iterator", shardId, e); sleepNanos(errorBackoffNanos); } return null; } catch (final Exception e) { if (!state.isStopped()) { logger.error("GetRecords failed for shard {}", shardId, e); - state.setIterator(null); + state.requestReset(); sleepNanos(errorBackoffNanos); } return null; diff --git 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java index 0a84d411f6ed..60292b2c1233 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java @@ -26,14 +26,14 @@ * @param sequenceNumber the Kinesis record sequence number * @param subSequenceNumber the sub-record index within a KPL aggregate (0 for non-aggregated) */ -record ShardCheckpoint(String sequenceNumber, long subSequenceNumber) { +record ShardCheckpoint(BigInteger sequenceNumber, long subSequenceNumber) { /** * Returns the higher of two checkpoints. Comparison is first by sequence number, * then by sub-sequence number within the same aggregate. 
*/ ShardCheckpoint max(final ShardCheckpoint other) { - final int comparison = new BigInteger(this.sequenceNumber).compareTo(new BigInteger(other.sequenceNumber)); + final int comparison = this.sequenceNumber.compareTo(other.sequenceNumber); if (comparison > 0) { return this; } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java index 343b393f956f..e72bd6289c97 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java @@ -16,16 +16,17 @@ */ package org.apache.nifi.processors.aws.kinesis; +import java.math.BigInteger; import java.util.List; record ShardFetchResult(String shardId, List records, long millisBehindLatest) { - String firstSequenceNumber() { - return records.getFirst().sequenceNumber(); + BigInteger firstSequenceNumber() { + return new BigInteger(records.getFirst().sequenceNumber()); } - String lastSequenceNumber() { - return records.getLast().sequenceNumber(); + BigInteger lastSequenceNumber() { + return new BigInteger(records.getLast().sequenceNumber()); } long lastSubSequenceNumber() { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java index e5fdff3b30cc..5a3bd1279f00 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java @@ -21,16 +21,10 @@ import org.mockito.ArgumentCaptor; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; -import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; -import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; -import software.amazon.awssdk.services.dynamodb.model.KeyType; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.ScanResponse; -import software.amazon.awssdk.services.dynamodb.model.TableDescription; -import software.amazon.awssdk.services.dynamodb.model.TableStatus; import java.util.List; import java.util.Map; @@ -38,6 +32,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -45,42 +40,29 @@ class CheckpointTableUtilsTest { @Test - void testCopyCheckpointItemsConvertsNewSchemaItemsForLegacyDestination() { + void testCopyCheckpointItemsCopiesShardItems() { final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); final ComponentLog logger = mock(ComponentLog.class); - when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { - final DescribeTableRequest request = invocation.getArgument(0); - if ("legacy-table".equals(request.tableName())) { - return legacySchemaResponse(); - } - return newSchemaResponse(); - }); 
- - final Map newSchemaItem = Map.of( + final Map item = Map.of( "streamName", AttributeValue.builder().s("my-stream").build(), "shardId", AttributeValue.builder().s("shardId-0001").build(), "sequenceNumber", AttributeValue.builder().s("12345").build()); - when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(newSchemaItem).build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(item).build()); when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); - CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "migration-table", "legacy-table"); + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); verify(dynamoDb, times(1)).putItem(putCaptor.capture()); - - final Map copiedItem = putCaptor.getValue().item(); - assertEquals("my-stream:shardId-0001", copiedItem.get("leaseKey").s()); - assertEquals("12345", copiedItem.get("checkpoint").s()); + assertEquals(item, putCaptor.getValue().item()); } @Test - void testCopyCheckpointItemsSkipsNodeAndMigrationMarkersForLegacyDestination() { + void testCopyCheckpointItemsSkipsNodeAndMigrationMarkers() { final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); final ComponentLog logger = mock(ComponentLog.class); - when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenReturn(legacySchemaResponse()); - final Map nodeItem = Map.of( "streamName", AttributeValue.builder().s("my-stream").build(), "shardId", AttributeValue.builder().s("__node__#node-a").build()); @@ -96,38 +78,27 @@ void testCopyCheckpointItemsSkipsNodeAndMigrationMarkersForLegacyDestination() { ScanResponse.builder().items(List.of(nodeItem, migrationMarkerItem, shardItem)).build()); when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); - 
CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "migration-table", "legacy-table"); + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); verify(dynamoDb, times(1)).putItem(putCaptor.capture()); - assertEquals("my-stream:shardId-0002", putCaptor.getValue().item().get("leaseKey").s()); + assertEquals(shardItem, putCaptor.getValue().item()); } - private static DescribeTableResponse newSchemaResponse() { - final KeySchemaElement hashKey = KeySchemaElement.builder() - .attributeName("streamName") - .keyType(KeyType.HASH) - .build(); - final KeySchemaElement rangeKey = KeySchemaElement.builder() - .attributeName("shardId") - .keyType(KeyType.RANGE) - .build(); - final TableDescription table = TableDescription.builder() - .keySchema(hashKey, rangeKey) - .tableStatus(TableStatus.ACTIVE) - .build(); - return DescribeTableResponse.builder().table(table).build(); - } + @Test + void testCopyCheckpointItemsSkipsAllMarkers() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + final Map nodeItem = Map.of( + "streamName", AttributeValue.builder().s("my-stream").build(), + "shardId", AttributeValue.builder().s("__node__#node-b").build()); + + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( + ScanResponse.builder().items(List.of(nodeItem)).build()); + + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); - private static DescribeTableResponse legacySchemaResponse() { - final KeySchemaElement hashKey = KeySchemaElement.builder() - .attributeName("leaseKey") - .keyType(KeyType.HASH) - .build(); - final TableDescription table = TableDescription.builder() - .keySchema(hashKey) - .tableStatus(TableStatus.ACTIVE) - .build(); - return DescribeTableResponse.builder().table(table).build(); + verify(dynamoDb, 
never()).putItem(any(PutItemRequest.class)); } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java index 478648e4b51b..773a96389904 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java @@ -204,6 +204,23 @@ void testDemarcatorDeliversAllRecords() throws Exception { success.assertAttributeEquals("record.count", "3"); } + @Test + void testDelimitedUsesLatestArrivalTimestamp() throws Exception { + final Instant firstArrival = Instant.parse("2025-01-15T00:00:00Z"); + final Instant secondArrival = Instant.parse("2025-01-15T00:00:05Z"); + final Instant thirdArrival = Instant.parse("2025-01-15T00:00:03Z"); + final List records = List.of( + testRecord("1", "line-one", firstArrival), + testRecord("2", "line-two", secondArrival), + testRecord("3", "line-three", thirdArrival)); + + triggerWithStrategy(records, "LINE_DELIMITED", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals(ConsumeKinesis.ATTR_ARRIVAL_TIMESTAMP, String.valueOf(secondArrival.toEpochMilli())); + } + @Test void testMultipleShardsNoDataLoss() throws Exception { final ShardFetchResult shard1Result = new ShardFetchResult("shard-A", @@ -377,6 +394,7 @@ private void assertInvalidRecordAtPosition(final String expectedFailureSequence, final MockFlowFile failure = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE).getFirst(); 
failure.assertContentEquals(expectedFailureContent); failure.assertAttributeEquals(ConsumeKinesis.ATTR_FIRST_SEQUENCE, expectedFailureSequence); + assertNotNull(failure.getAttribute(ConsumeKinesis.ATTR_RECORD_ERROR_MESSAGE)); } private void triggerWithOutputStrategy(final List records, final String outputStrategy) throws Exception { @@ -454,13 +472,17 @@ private static KinesisShardManager buildShardManager(final String... shardIds) { } private static DeaggregatedRecord testRecord(final String sequenceNumber, final String data) { + return testRecord(sequenceNumber, data, Instant.now()); + } + + private static DeaggregatedRecord testRecord(final String sequenceNumber, final String data, final Instant arrivalTimestamp) { return new DeaggregatedRecord( "shardId-000000000001", sequenceNumber, 0, "pk-" + sequenceNumber, data.getBytes(StandardCharsets.UTF_8), - Instant.now()); + arrivalTimestamp); } static class TestableConsumeKinesis extends ConsumeKinesis { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java index 0c48f0099d02..a773b88f8ee2 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java @@ -27,16 +27,14 @@ import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; -import java.lang.reflect.Field; +import java.math.BigInteger; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; import 
java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -73,14 +71,16 @@ void testSubscriptionRenewalUsesLastAcknowledgedSequenceNumber() throws Exceptio assertEquals("11111", capturedRequests.get(0).startingPosition().sequenceNumber(), "Initial subscription should use the DynamoDB checkpoint"); - simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", "99999"); + final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.setLastQueuedSequenceNumber(new BigInteger("99999")); + consumer.resetForRenewal(); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); assertEquals(2, capturedRequests.size()); assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, capturedRequests.get(1).startingPosition().type()); assertEquals("99999", capturedRequests.get(1).startingPosition().sequenceNumber(), - "Renewal should use max(lastAcknowledged, checkpoint) = lastAcknowledged"); + "Renewal should use max(lastQueued, checkpoint) = lastQueued"); verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); } @@ -100,7 +100,9 @@ void testSubscriptionRenewalFallsBackToCheckpointWhenNoQueuedData() throws Excep final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", null); + + final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.resetForRenewal(); client.startFetches(shards, 
"test-stream", 100, "TRIM_HORIZON", mockShardManager); @@ -113,11 +115,10 @@ void testSubscriptionRenewalFallsBackToCheckpointWhenNoQueuedData() throws Excep } /** - * Verifies that acknowledged sequence tracking is monotonic. If acknowledgements are observed - * out-of-order for a shard, renewal must use the highest acknowledged sequence. + * Verifies that renewal uses the lastQueuedSequenceNumber when it exceeds the checkpoint. */ @Test - void testSubscriptionRenewalUsesHighestAcknowledgedSequence() throws Exception { + void testSubscriptionRenewalUsesLastQueuedSequence() throws Exception { final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("10000"); @@ -127,24 +128,23 @@ void testSubscriptionRenewalUsesHighestAcknowledgedSequence() throws Exception { client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - client.acknowledgeResults(List.of( - shardFetchResult("shardId-000000000001", "20000"), - shardFetchResult("shardId-000000000001", "15000"))); + final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.setLastQueuedSequenceNumber(new BigInteger("20000")); + consumer.resetForRenewal(); - simulateExpiredSubscriptionWithAcknowledgedData(client, "shardId-000000000001", null); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); assertEquals(2, capturedRequests.size()); assertEquals("20000", capturedRequests.get(1).startingPosition().sequenceNumber(), - "Renewal should use the highest acknowledged sequence"); + "Renewal should use max(lastQueued=20000, checkpoint=10000) = lastQueued"); verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); } /** - * Verifies that renewal always uses the maximum of lastQueued, lastAcknowledged, and - * the DynamoDB checkpoint. 
This prevents one-event replay duplicates caused by races - * between onNext counter updates and concurrent handler onError callbacks. + * Verifies that renewal always uses the maximum of lastQueued and the DynamoDB checkpoint. + * This prevents one-event replay duplicates caused by races between onNext counter updates + * and concurrent handler onError callbacks. */ @Test void testSubscriptionRenewalAlwaysUsesMaxSequence() throws Exception { @@ -157,27 +157,26 @@ void testSubscriptionRenewalAlwaysUsesMaxSequence() throws Exception { client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000", "70000", 2); + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000"); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); assertEquals(2, capturedRequests.size()); assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), - "Renewal should use max(lastQueued=90000, lastAcked=70000, checkpoint=50000) = 90000"); + "Renewal should use max(lastQueued=90000, checkpoint=50000) = 90000"); - simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "95000", "80000", 0); + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "95000"); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); assertEquals(3, capturedRequests.size()); assertEquals("95000", capturedRequests.get(2).startingPosition().sequenceNumber(), - "Renewal should use max(lastQueued=95000, lastAcked=80000, checkpoint=50000) = 95000"); + "Renewal should use max(lastQueued=95000, checkpoint=50000) = 95000"); verify(mockShardManager, times(3)).readCheckpoint("shardId-000000000001"); } /** * Verifies that polling a queued result does not affect the renewal position. 
The renewal - * uses max(lastQueued, lastAcknowledged, checkpoint) regardless of whether results have - * been polled or acknowledged. + * uses max(lastQueued, checkpoint) regardless of whether results have been polled. */ @Test void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws Exception { @@ -189,7 +188,7 @@ void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws E final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000", "70000", 1); + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000"); client.enqueueResult(shardFetchResult("shardId-000000000001", "90000")); final ShardFetchResult polled = client.pollShardResult("shardId-000000000001"); @@ -199,7 +198,7 @@ void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws E assertEquals(2, capturedRequests.size()); assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), - "Renewal should use max(lastQueued=90000, lastAcked=70000, checkpoint=50000) = 90000"); + "Renewal should use max(lastQueued=90000, checkpoint=50000) = 90000"); verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); } @@ -218,10 +217,10 @@ void testAcknowledgeResultsRequestsNextOncePerShard() throws Exception { final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - final Object shardConsumer = getShardConsumer(client, "shardId-000000000001"); + final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); final Subscription subscription = mock(Subscription.class); - setField(shardConsumer, "subscription", subscription); - setField(shardConsumer, "queuedResultCount", new 
AtomicInteger(2)); + consumer.setSubscription(subscription); + consumer.pause(); client.acknowledgeResults(List.of( shardFetchResult("shardId-000000000001", "60000"), @@ -250,8 +249,7 @@ void testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws }); final EfoKinesisClient client = new EfoKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); - setField(client, "kinesisAsyncClient", mockAsyncClient); - setField(client, "consumerArn", "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); + client.initializeForTest(mockAsyncClient, "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); final CountDownLatch startLatch = new CountDownLatch(1); @@ -268,7 +266,7 @@ void testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws "Concurrent startup should create only one initial SubscribeToShard request per shard"); } - private static EfoKinesisClient createEfoClient(final List capturedRequests) throws Exception { + private static EfoKinesisClient createEfoClient(final List capturedRequests) { final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) @@ -278,34 +276,17 @@ private static EfoKinesisClient createEfoClient(final List consumers = (Map) consumersField.get(client); - final Object shardConsumer = consumers.get(shardId); - final Class scClass = shardConsumer.getClass(); - - setField(shardConsumer, scClass, "subscribing", new AtomicBoolean(false)); - setField(shardConsumer, scClass, "lastSubscribeAttemptNanos", 0L); - if (lastAckedSeq != null) { - getAtomicRef(shardConsumer, scClass, "lastAcknowledgedSequenceNumber").set(lastAckedSeq); - } + private static void simulateExpiredSubscriptionWithState( + final EfoKinesisClient client, + final String shardId, + final String 
lastQueuedSeq) { + final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer(shardId); + consumer.resetForRenewal(); + consumer.setLastQueuedSequenceNumber(new BigInteger(lastQueuedSeq)); } /** @@ -331,128 +312,49 @@ void testStaleErrorCallbackDoesNotCorruptNewSubscription() throws Exception { .thenReturn(CompletableFuture.completedFuture(null)); final EfoKinesisClient.ShardConsumer consumer = - new EfoKinesisClient.ShardConsumer("shardId-000000000001", result -> { }, mockLogger); + new EfoKinesisClient.ShardConsumer("shardId-000000000001", result -> { }, new ConcurrentLinkedQueue<>(), mockLogger); final StartingPosition pos = StartingPosition.builder() .type(ShardIteratorType.TRIM_HORIZON) .build(); consumer.subscribe(mockAsyncClient, "test-arn", pos); - final int gen1 = getGeneration(consumer); + final int gen1 = consumer.getSubscriptionGeneration(); assertEquals(1, gen1); - simulateSubscriberActive(consumer); + consumer.setSubscription(mock(Subscription.class)); - final Subscription gen1Subscription = getSubscription(consumer); - assertNotNull(gen1Subscription, "Subscription should be set after onSubscribe"); + assertNotNull(consumer.getSubscription(), "Subscription should be set after onSubscribe"); - endSubscriptionIfCurrent(consumer, gen1); - assertFalse(getSubscribing(consumer), "subscribing should be false after endSubscription"); + consumer.endSubscriptionIfCurrent(gen1); + assertFalse(consumer.isSubscribing(), "subscribing should be false after endSubscription"); consumer.subscribe(mockAsyncClient, "test-arn", pos); - final int gen2 = getGeneration(consumer); + final int gen2 = consumer.getSubscriptionGeneration(); assertEquals(2, gen2); - simulateSubscriberActive(consumer); + consumer.setSubscription(mock(Subscription.class)); - final Subscription gen2Subscription = getSubscription(consumer); - assertNotNull(gen2Subscription, "New subscription should be set"); + assertNotNull(consumer.getSubscription(), "New subscription should be set"); - 
endSubscriptionIfCurrent(consumer, gen1); + consumer.endSubscriptionIfCurrent(gen1); - assertNotNull(getSubscription(consumer), + assertNotNull(consumer.getSubscription(), "Stale callback (gen1) must NOT null out gen2's subscription"); - assertTrue(getSubscribing(consumer), + assertTrue(consumer.isSubscribing(), "Stale callback (gen1) must NOT reset gen2's subscribing flag"); - endSubscriptionIfCurrent(consumer, gen2); + consumer.endSubscriptionIfCurrent(gen2); - assertFalse(getSubscribing(consumer), + assertFalse(consumer.isSubscribing(), "Current-generation callback should clean up normally"); } - private static void simulateSubscriberActive(final Object shardConsumer) throws Exception { - final Subscription mockSub = mock(Subscription.class); - final Field subField = shardConsumer.getClass().getDeclaredField("subscription"); - subField.setAccessible(true); - subField.set(shardConsumer, mockSub); - } - - private static Subscription getSubscription(final Object shardConsumer) throws Exception { - final Field field = shardConsumer.getClass().getDeclaredField("subscription"); - field.setAccessible(true); - return (Subscription) field.get(shardConsumer); - } - - private static boolean getSubscribing(final Object shardConsumer) throws Exception { - final Field field = shardConsumer.getClass().getDeclaredField("subscribing"); - field.setAccessible(true); - return ((AtomicBoolean) field.get(shardConsumer)).get(); - } - - private static int getGeneration(final Object shardConsumer) throws Exception { - final Field field = shardConsumer.getClass().getDeclaredField("subscriptionGeneration"); - field.setAccessible(true); - return ((AtomicInteger) field.get(shardConsumer)).get(); - } - - private static void endSubscriptionIfCurrent(final Object shardConsumer, final int generation) throws Exception { - final java.lang.reflect.Method method = shardConsumer.getClass().getDeclaredMethod("endSubscriptionIfCurrent", int.class); - method.setAccessible(true); - 
method.invoke(shardConsumer, generation); - } - - private static void setField(final Object target, final String fieldName, final Object value) throws Exception { - setField(target, target.getClass(), fieldName, value); - } - - private static void setField(final Object target, final Class clazz, final String fieldName, final Object value) throws Exception { - final Field field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - field.set(target, value); - } - - @SuppressWarnings("unchecked") - private static AtomicReference getAtomicRef( - final Object target, final Class clazz, final String fieldName) throws Exception { - final Field field = clazz.getDeclaredField(fieldName); - field.setAccessible(true); - return (AtomicReference) field.get(target); - } - - @SuppressWarnings("unchecked") - private static void simulateExpiredSubscriptionWithState( - final EfoKinesisClient client, - final String shardId, - final String lastQueuedSeq, - final String lastAckedSeq, - final int queuedCount) throws Exception { - final Field consumersField = EfoKinesisClient.class.getDeclaredField("shardConsumers"); - consumersField.setAccessible(true); - final Map consumers = (Map) consumersField.get(client); - final Object shardConsumer = consumers.get(shardId); - final Class scClass = shardConsumer.getClass(); - - setField(shardConsumer, scClass, "subscribing", new AtomicBoolean(false)); - setField(shardConsumer, scClass, "lastSubscribeAttemptNanos", 0L); - setField(shardConsumer, scClass, "lastQueuedSequenceNumber", lastQueuedSeq); - getAtomicRef(shardConsumer, scClass, "lastAcknowledgedSequenceNumber").set(lastAckedSeq); - setField(shardConsumer, scClass, "queuedResultCount", new AtomicInteger(queuedCount)); - } - private static ShardFetchResult shardFetchResult(final String shardId, final String sequenceNumber) { final DeaggregatedRecord record = new DeaggregatedRecord(shardId, sequenceNumber, 0, "pk", "{}".getBytes(), null); return new ShardFetchResult(shardId, 
List.of(record), 0L); } - @SuppressWarnings("unchecked") - private static Object getShardConsumer(final EfoKinesisClient client, final String shardId) throws Exception { - final Field consumersField = EfoKinesisClient.class.getDeclaredField("shardConsumers"); - consumersField.setAccessible(true); - final Map consumers = (Map) consumersField.get(client); - return consumers.get(shardId); - } - /** * Verifies the per-shard exclusive processing lock: claimShard returns false when a shard is * already claimed, and releaseShards makes it claimable again. Different shards are independent. @@ -502,13 +404,13 @@ void testPerShardOrderingPreservedAcrossEnqueues() { client.enqueueResult(shardFetchResult("shard-3", "600")); client.enqueueResult(shardFetchResult("shard-5", "300")); - assertEquals("100", client.pollShardResult("shard-5").firstSequenceNumber()); - assertEquals("200", client.pollShardResult("shard-5").firstSequenceNumber()); - assertEquals("300", client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals(new BigInteger("100"), client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals(new BigInteger("200"), client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals(new BigInteger("300"), client.pollShardResult("shard-5").firstSequenceNumber()); assertNull(client.pollShardResult("shard-5"), "Queue should be empty after draining"); - assertEquals("500", client.pollShardResult("shard-3").firstSequenceNumber()); - assertEquals("600", client.pollShardResult("shard-3").firstSequenceNumber()); + assertEquals(new BigInteger("500"), client.pollShardResult("shard-3").firstSequenceNumber()); + assertEquals(new BigInteger("600"), client.pollShardResult("shard-3").firstSequenceNumber()); assertNull(client.pollShardResult("shard-3"), "Queue should be empty after draining"); } @@ -530,13 +432,13 @@ void testQueueIntrospectionMethods() { assertTrue(client.hasQueuedResults()); assertEquals(3, client.totalQueuedResults()); - 
assertEquals(Set.of("shard-1", "shard-2"), client.getShardIdsWithResults()); + assertEquals(Set.of("shard-1", "shard-2"), new HashSet<>(client.getShardIdsWithResults())); client.pollShardResult("shard-1"); client.pollShardResult("shard-1"); assertEquals(1, client.totalQueuedResults()); - assertEquals(Set.of("shard-2"), client.getShardIdsWithResults()); + assertEquals(List.of("shard-2"), client.getShardIdsWithResults()); } /** @@ -563,7 +465,7 @@ void testPerShardQueuesPreventOutOfOrderDeliveryAcrossInvocations() { // Task-1: claims shard-1, polls A(100) from its per-shard queue assertTrue(client.claimShard("shard-1"), "Task-1 should claim shard-1"); final ShardFetchResult polledA = client.pollShardResult("shard-1"); - assertEquals("100", polledA.firstSequenceNumber()); + assertEquals(new BigInteger("100"), polledA.firstSequenceNumber()); // Task-2: cannot claim shard-1 (held by Task-1), so it skips shard-1 entirely. // B and C remain at the head of shard-1's queue, undisturbed. @@ -572,7 +474,7 @@ void testPerShardQueuesPreventOutOfOrderDeliveryAcrossInvocations() { // Task-2: claims shard-2 and polls its result assertTrue(client.claimShard("shard-2")); final ShardFetchResult polledOther = client.pollShardResult("shard-2"); - assertEquals("999", polledOther.firstSequenceNumber()); + assertEquals(new BigInteger("999"), polledOther.firstSequenceNumber()); // Task-1 commits A and releases shard-1 client.releaseShards(List.of("shard-1")); @@ -583,8 +485,8 @@ void testPerShardQueuesPreventOutOfOrderDeliveryAcrossInvocations() { assertNotNull(firstPoll, "Expected shard-1 queue to have B"); assertNotNull(secondPoll, "Expected shard-1 queue to have C"); - assertEquals("200", firstPoll.firstSequenceNumber(), "First poll must be B(200), not C(300)"); - assertEquals("300", secondPoll.firstSequenceNumber(), "Second poll must be C(300)"); + assertEquals(new BigInteger("200"), firstPoll.firstSequenceNumber(), "First poll must be B(200), not C(300)"); + assertEquals(new 
BigInteger("300"), secondPoll.firstSequenceNumber(), "Second poll must be C(300)"); assertNull(client.pollShardResult("shard-1"), "shard-1 queue should be empty"); } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java index 2fdc6df5c639..1d60e386e279 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java @@ -44,6 +44,7 @@ import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse; import software.amazon.awssdk.services.kinesis.KinesisClient; +import java.math.BigInteger; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -77,9 +78,9 @@ void testCheckpointMonotonicity() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("30000", 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("70000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("30000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("70000"), 0))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(2)).updateItem(captor.capture()); @@ -100,11 +101,11 @@ void 
testCheckpointMonotonicityPerShard() { when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); manager.writeCheckpoints(Map.of( - "shard-1", new ShardCheckpoint("50000", 0), - "shard-2", new ShardCheckpoint("20000", 0))); + "shard-1", new ShardCheckpoint(new BigInteger("50000"), 0), + "shard-2", new ShardCheckpoint(new BigInteger("20000"), 0))); manager.writeCheckpoints(Map.of( - "shard-1", new ShardCheckpoint("30000", 0), - "shard-2", new ShardCheckpoint("40000", 0))); + "shard-1", new ShardCheckpoint(new BigInteger("30000"), 0), + "shard-2", new ShardCheckpoint(new BigInteger("40000"), 0))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(3)).updateItem(captor.capture()); @@ -131,9 +132,9 @@ void testCloseResetsCheckpointGuard() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); manager.close(); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("30000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("30000"), 0))); verify(dynamoDb, times(2)).updateItem(any(UpdateItemRequest.class)); } @@ -182,8 +183,8 @@ void testWriteCheckpointHandlesLostLeaseGracefully() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenThrow(lostLease).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("70000", 0))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new 
ShardCheckpoint(new BigInteger("70000"), 0))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(2)).updateItem(captor.capture()); @@ -201,7 +202,7 @@ void testCheckpointStoresSubSequenceNumber() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 7))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 7))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(1)).updateItem(captor.capture()); @@ -221,9 +222,9 @@ void testCheckpointMonotonicityWithSubSequenceNumber() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 3))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 1))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint("50000", 5))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 3))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 1))); + manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 5))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(2)).updateItem(captor.capture()); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java index 3277cb9ce958..129a8e3eca8b 100644 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java @@ -33,11 +33,13 @@ import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -92,8 +94,8 @@ void testExhaustedShardDeliversAllRecords() throws Exception { final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); assertNotNull(result, "Last batch of an exhausted shard must not be dropped"); assertEquals(3, result.records().size()); - assertEquals("100", result.firstSequenceNumber()); - assertEquals("300", result.lastSequenceNumber()); + assertEquals(new BigInteger("100"), result.firstSequenceNumber()); + assertEquals(new BigInteger("300"), result.lastSequenceNumber()); assertEventuallyNoPendingFetches(); } @@ -123,7 +125,7 @@ void testDeadLoopRecoveryRestoresDataFlow() throws Exception { final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); assertNotNull(result, "Dead loop must be restarted and produce records"); - assertEquals("100", result.firstSequenceNumber()); + assertEquals(new BigInteger("100"), result.firstSequenceNumber()); } /** @@ -179,6 +181,74 @@ void testExpiredIteratorRecovery() throws Exception { verify(mockKinesisClient, timeout(5000).atLeast(2)).getShardIterator(any(GetShardIteratorRequest.class)); } + /** + * Verifies that iterator recovery does not reorder results within a shard when newer data + * 
is already queued ahead of replay from the persisted checkpoint. + */ + @Test + void testExpiredIteratorRecoveryDoesNotDeliverSameShardOutOfOrder() throws Exception { + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn("100"); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int call = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + call).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest req = invocation.getArgument(0); + + return switch (req.shardIterator()) { + case "iter-1" -> GetRecordsResponse.builder() + .records(record("200", "A")) + .nextShardIterator("iter-1a") + .millisBehindLatest(0L) + .build(); + case "iter-1a" -> GetRecordsResponse.builder() + .records(record("300", "B")) + .nextShardIterator("iter-1b") + .millisBehindLatest(0L) + .build(); + case "iter-1b" -> throw ExpiredIteratorException.builder().message("expired").build(); + case "iter-2" -> GetRecordsResponse.builder() + .records(record("200", "A-replay")) + .nextShardIterator("iter-2a") + .millisBehindLatest(0L) + .build(); + default -> GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(req.shardIterator()) + .millisBehindLatest(0L) + .build(); + }; + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "Initial result must be available"); + assertEquals(new BigInteger("200"), firstResult.firstSequenceNumber()); + + final long newerQueuedDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while 
(System.nanoTime() < newerQueuedDeadline && getRecordsCallCount.get() < 2) { + Thread.sleep(20); + } + Thread.sleep(50); + + final long replayQueuedDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < replayQueuedDeadline && (getShardIteratorCallCount.get() < 2 || getRecordsCallCount.get() < 4)) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRecovery = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRecovery, "Queue must contain results after iterator recovery"); + assertEquals(new BigInteger("200"), firstAfterRecovery.firstSequenceNumber(), + "Replayed checkpoint data must be delivered before any stale newer batch from the same shard"); + } + /** * Validates that throttling (ProvisionedThroughputExceededException) is transient: * the loop retries and records eventually arrive without data loss. @@ -267,6 +337,81 @@ void testHasPendingFetchesFalseWhenAllShardsExhausted() throws Exception { assertEventuallyNoPendingFetches(); } + /** + * Reproduces out-of-order delivery caused by stale results remaining in the per-shard + * queue after a rollback. The scenario for a single shard: + * + *

    + *
  1. Fetch loop enqueues R1 (seq 100-200) and R2 (seq 300-400).
  2. + *
  3. Consumer polls R1 only; R2 remains in the queue.
  4. + *
  5. Consumer calls rollbackResults on R1, which sets the reset flag. + * The fetch loop detects the flag, drains R2 from the queue, + * resets the shard iterator, and re-fetches.
  6. + *
  7. After the reset, the first result polled must come from the re-fetched + * sequence (seq 500), not the stale R2 (seq 300).
  8. + *
+ */ + @Test + void testRollbackDrainsStaleResultsFromQueue() throws Exception { + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int call = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + call).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest req = invocation.getArgument(0); + + if (req.shardIterator().equals("iter-1")) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (req.shardIterator().equals("iter-1a")) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (req.shardIterator().startsWith("iter-1")) { + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(req.shardIterator()).millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(record("500", "E"), record("600", "F")) + .nextShardIterator("iter-post-reset-next").millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult r1 = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(r1, "R1 must be available"); + assertEquals(new BigInteger("100"), r1.firstSequenceNumber()); + + final long r2Deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < r2Deadline && getRecordsCallCount.get() < 2) { + Thread.sleep(20); + } + Thread.sleep(50); 
+ + consumer.rollbackResults(List.of(r1)); + + final long resetDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (getShardIteratorCallCount.get() < 2 && System.nanoTime() < resetDeadline) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRollback = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRollback, "Queue must contain results after rollback + re-fetch"); + assertEquals(new BigInteger("500"), firstAfterRollback.firstSequenceNumber(), + "First result after rollback must be re-fetched data, not a stale pre-rollback result"); + } + private void drainAllResults() throws InterruptedException { ShardFetchResult discarded; do { From c40712d2b938f42674e6e5c8b1c7d5d7200f0301 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Tue, 10 Mar 2026 10:37:05 -0400 Subject: [PATCH 3/7] NIFI-15669: Simplified checkpointing by eliminating subsequences because we always include all sub-records within a single ProcessSession so we don't need to checkpoint partial sequences --- .../aws/kinesis/ConsumeKinesis.java | 6 +- .../aws/kinesis/KinesisShardManager.java | 77 +++++++++++-------- .../aws/kinesis/LegacyCheckpointMigrator.java | 3 +- .../aws/kinesis/ShardCheckpoint.java | 45 ----------- .../aws/kinesis/ShardFetchResult.java | 4 - .../aws/kinesis/KinesisShardManagerTest.java | 63 +++------------ 6 files changed, 58 insertions(+), 140 deletions(-) delete mode 100644 nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index e920687edd99..9a9fd306d09b 100644 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -693,10 +693,10 @@ private PartitionedBatch partitionByShardAndCheckpoint(final List checkpoints = new HashMap<>(); + final Map checkpoints = new HashMap<>(); for (final List shardResults : resultsByShard.values()) { final ShardFetchResult last = shardResults.getLast(); - checkpoints.put(last.shardId(), new ShardCheckpoint(last.lastSequenceNumber(), last.lastSubSequenceNumber())); + checkpoints.put(last.shardId(), last.lastSequenceNumber()); } return new PartitionedBatch(resultsByShard, checkpoints); @@ -1404,7 +1404,7 @@ public void reset() throws IOException { } } - private record PartitionedBatch(Map> resultsByShard, Map checkpoints) { + private record PartitionedBatch(Map> resultsByShard, Map checkpoints) { } private record WriteResult(List produced, List parseFailures, diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java index 5885de7b25c7..752e2fec62d3 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java @@ -31,6 +31,7 @@ import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; import software.amazon.awssdk.services.kinesis.model.Shard; +import java.math.BigInteger; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; @@ -98,7 +99,7 @@ final class KinesisShardManager { private final Set 
ownedShards = ConcurrentHashMap.newKeySet(); private final Map pendingRelinquishDeadlines = new ConcurrentHashMap<>(); private final AtomicBoolean leaseRefreshInProgress = new AtomicBoolean(false); - private final Map highestWrittenCheckpoints = new ConcurrentHashMap<>(); + private final Map highestWrittenCheckpoints = new ConcurrentHashMap<>(); private volatile Instant lastLeaseRefresh = Instant.EPOCH; private volatile String activeCheckpointTableName; @@ -315,8 +316,8 @@ String readCheckpoint(final String shardId) { return null; } - void writeCheckpoints(final Map checkpoints) { - for (final Map.Entry entry : checkpoints.entrySet()) { + void writeCheckpoints(final Map checkpoints) { + for (final Map.Entry entry : checkpoints.entrySet()) { writeCheckpoint(entry.getKey(), entry.getValue()); } } @@ -497,41 +498,49 @@ private int countActiveNodes(final long now) { return Math.max(1, activeNodes); } - private void writeCheckpoint(final String shardId, final ShardCheckpoint checkpoint) { - final ShardCheckpoint written = highestWrittenCheckpoints.compute(shardId, (key, existing) -> { - if (existing != null && checkpoint.max(existing) == existing) { - return existing; - } + private void writeCheckpoint(final String shardId, final BigInteger checkpoint) { + final BigInteger written = highestWrittenCheckpoints.compute(shardId, + (key, existing) -> persistIfHigher(shardId, checkpoint, existing)); - try { - final long now = Instant.now().toEpochMilli(); - final UpdateItemRequest checkpointRequest = UpdateItemRequest.builder() - .tableName(activeCheckpointTableName) - .key(checkpointKey(shardId)) - .updateExpression("SET sequenceNumber = :seq, subSequenceNumber = :subSeq," - + " lastUpdateTimestamp = :ts, leaseExpiry = :exp") - .conditionExpression("leaseOwner = :owner") - .expressionAttributeValues(Map.of( - ":seq", AttributeValue.builder().s(checkpoint.sequenceNumber().toString()).build(), - ":subSeq", 
AttributeValue.builder().n(String.valueOf(checkpoint.subSequenceNumber())).build(), - ":ts", AttributeValue.builder().n(String.valueOf(now)).build(), - ":exp", AttributeValue.builder().n(String.valueOf(now + leaseDurationMillis)).build(), - ":owner", AttributeValue.builder().s(nodeId).build())) - .build(); - dynamoDbClient.updateItem(checkpointRequest); - logger.debug("Checkpointed shard {} at sequence {}/{}", shardId, checkpoint.sequenceNumber(), checkpoint.subSequenceNumber()); - return checkpoint; - } catch (final ConditionalCheckFailedException e) { - logger.warn("Lost lease on shard {} during checkpoint; another node may have taken it", shardId); - } catch (final Exception e) { - logger.error("Failed to write checkpoint for shard {}", shardId, e); - } + if (written != null && checkpoint.compareTo(written) < 0) { + logger.debug("Skipped checkpoint regression for shard {} (highest: {}, attempted: {})", shardId, written, checkpoint); + } + } + + /** + * Writes the checkpoint to DynamoDB if it is higher than the existing value. Returns the + * new highest checkpoint on success, or {@code existing} if the checkpoint was lower or + * the write failed. 
+ */ + private BigInteger persistIfHigher(final String shardId, final BigInteger checkpoint, final BigInteger existing) { + if (existing != null && checkpoint.max(existing).equals(existing)) { return existing; - }); + } - if (written != null && checkpoint.sequenceNumber().compareTo(written.sequenceNumber()) < 0) { - logger.debug("Skipped checkpoint regression for shard {} (highest: {}, attempted: {})", shardId, written.sequenceNumber(), checkpoint.sequenceNumber()); + try { + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest checkpointRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("SET sequenceNumber = :seq, lastUpdateTimestamp = :ts, leaseExpiry = :exp") + .conditionExpression("leaseOwner = :owner") + .expressionAttributeValues(Map.of( + ":seq", AttributeValue.builder().s(checkpoint.toString()).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build(), + ":exp", AttributeValue.builder().n(String.valueOf(now + leaseDurationMillis)).build(), + ":owner", AttributeValue.builder().s(nodeId).build())) + .build(); + dynamoDbClient.updateItem(checkpointRequest); + logger.debug("Checkpointed shard {} at sequence {}", shardId, checkpoint); + + return checkpoint; + } catch (final ConditionalCheckFailedException e) { + logger.warn("Lost lease on shard {} during checkpoint; another node may have taken it", shardId); + } catch (final Exception e) { + logger.error("Failed to write checkpoint for shard {}", shardId, e); } + + return existing; } private Map checkpointKey(final String shardId) { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java index d74b76451559..d994094d2c47 100644 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java @@ -331,10 +331,9 @@ private void migrateLegacyCheckpoints(final String sourceTableName, final String .key(Map.of( "streamName", AttributeValue.builder().s(streamName).build(), "shardId", AttributeValue.builder().s(shardId).build())) - .updateExpression("SET sequenceNumber = :seq, subSequenceNumber = :subSeq, lastUpdateTimestamp = :ts") + .updateExpression("SET sequenceNumber = :seq, lastUpdateTimestamp = :ts") .expressionAttributeValues(Map.of( ":seq", AttributeValue.builder().s(checkpoint).build(), - ":subSeq", AttributeValue.builder().n("0").build(), ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) .build(); dynamoDbClient.updateItem(request); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java deleted file mode 100644 index 60292b2c1233..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardCheckpoint.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import java.math.BigInteger; - -/** - * Immutable checkpoint position within a Kinesis shard, composed of a sequence number - * and a sub-sequence number. The sub-sequence number is non-zero only for KPL-aggregated - * records and identifies the position within the aggregate. - * - * @param sequenceNumber the Kinesis record sequence number - * @param subSequenceNumber the sub-record index within a KPL aggregate (0 for non-aggregated) - */ -record ShardCheckpoint(BigInteger sequenceNumber, long subSequenceNumber) { - - /** - * Returns the higher of two checkpoints. Comparison is first by sequence number, - * then by sub-sequence number within the same aggregate. - */ - ShardCheckpoint max(final ShardCheckpoint other) { - final int comparison = this.sequenceNumber.compareTo(other.sequenceNumber); - if (comparison > 0) { - return this; - } - if (comparison < 0) { - return other; - } - return this.subSequenceNumber >= other.subSequenceNumber ? 
this : other; - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java index e72bd6289c97..507f79e0d508 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java @@ -28,8 +28,4 @@ BigInteger firstSequenceNumber() { BigInteger lastSequenceNumber() { return new BigInteger(records.getLast().sequenceNumber()); } - - long lastSubSequenceNumber() { - return records.getLast().subSequenceNumber(); - } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java index 1d60e386e279..31c900d701f7 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java @@ -78,9 +78,9 @@ void testCheckpointMonotonicity() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("30000"), 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("70000"), 0))); + 
manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("30000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("70000"))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(2)).updateItem(captor.capture()); @@ -101,11 +101,11 @@ void testCheckpointMonotonicityPerShard() { when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); manager.writeCheckpoints(Map.of( - "shard-1", new ShardCheckpoint(new BigInteger("50000"), 0), - "shard-2", new ShardCheckpoint(new BigInteger("20000"), 0))); + "shard-1", new BigInteger("50000"), + "shard-2", new BigInteger("20000"))); manager.writeCheckpoints(Map.of( - "shard-1", new ShardCheckpoint(new BigInteger("30000"), 0), - "shard-2", new ShardCheckpoint(new BigInteger("40000"), 0))); + "shard-1", new BigInteger("30000"), + "shard-2", new BigInteger("40000"))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(3)).updateItem(captor.capture()); @@ -132,9 +132,9 @@ void testCloseResetsCheckpointGuard() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); manager.close(); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("30000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("30000"))); verify(dynamoDb, times(2)).updateItem(any(UpdateItemRequest.class)); } @@ -183,8 +183,8 @@ void testWriteCheckpointHandlesLostLeaseGracefully() { final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); 
when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenThrow(lostLease).thenReturn(emptyResponse); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 0))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("70000"), 0))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("70000"))); final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); verify(dynamoDb, times(2)).updateItem(captor.capture()); @@ -194,47 +194,6 @@ void testWriteCheckpointHandlesLostLeaseGracefully() { "After a lost-lease failure, the next higher checkpoint must still be attempted"); } - /** - * Verifies that writeCheckpoints stores the subSequenceNumber alongside the sequenceNumber. - */ - @Test - void testCheckpointStoresSubSequenceNumber() { - final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); - when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 7))); - - final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); - verify(dynamoDb, times(1)).updateItem(captor.capture()); - - final UpdateItemRequest request = captor.getValue(); - assertEquals("50000", request.expressionAttributeValues().get(":seq").s()); - assertEquals("7", request.expressionAttributeValues().get(":subSeq").n(), - "subSequenceNumber must be persisted in the DynamoDB checkpoint"); - } - - /** - * Verifies that for the same sequence number, a higher sub-sequence number is written - * and a lower one is skipped. 
- */ - @Test - void testCheckpointMonotonicityWithSubSequenceNumber() { - final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); - when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); - - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 3))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 1))); - manager.writeCheckpoints(Map.of("shard-1", new ShardCheckpoint(new BigInteger("50000"), 5))); - - final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); - verify(dynamoDb, times(2)).updateItem(captor.capture()); - - final List requests = captor.getAllValues(); - assertEquals("3", requests.get(0).expressionAttributeValues().get(":subSeq").n()); - assertEquals("5", requests.get(1).expressionAttributeValues().get(":subSeq").n(), - "Only increasing sub-sequence checkpoints within the same sequence should be written"); - } - /** * Verifies that when no checkpoint table exists and no orphaned migration table exists, * a fresh table is created with the configured name (no suffix). 
From ecf53327688d47704b98d5e657ceacf67b7a2b55 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Tue, 10 Mar 2026 12:12:35 -0400 Subject: [PATCH 4/7] NIFI-15669: Addressed a couple of additional corner cases --- .../aws/kinesis/ConsumeKinesis.java | 9 +- .../aws/kinesis/EfoKinesisClient.java | 31 ++- .../aws/kinesis/KinesisConsumerClient.java | 6 + .../aws/kinesis/PollingKinesisClient.java | 102 +++++---- .../aws/kinesis/PollingKinesisClientTest.java | 201 +++++++++++++++--- 5 files changed, 268 insertions(+), 81 deletions(-) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index 9a9fd306d09b..78d67495c237 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -598,8 +598,9 @@ public void onTrigger(final ProcessContext context, final ProcessSession session consumerClient.logDiagnostics(ownedShards.size(), shardManager.getCachedShardCount()); final Set claimedShards = new HashSet<>(); + List consumed = List.of(); try { - final List consumed = consumeRecords(claimedShards); + consumed = consumeRecords(claimedShards); final List accepted = discardRelinquishedResults(consumed, claimedShards); if (accepted.isEmpty()) { @@ -635,6 +636,7 @@ public void onTrigger(final ProcessContext context, final ProcessSession session session.adjustCounter("EFO Deduplicated Events", dedupEvents, false); } + consumed = List.of(); session.commitAsync( () -> { try { @@ -656,6 +658,9 @@ public void onTrigger(final ProcessContext context, final ProcessSession session } }); } catch (final Exception e) { + if (!consumed.isEmpty()) { + 
consumerClient.rollbackResults(consumed); + } consumerClient.releaseShards(claimedShards); throw e; } @@ -1288,7 +1293,7 @@ private void closeQuietly(final AutoCloseable closeable) { try { closeable.close(); } catch (final Exception e) { - getLogger().warn("Failed to close Record Writer", e); + getLogger().warn("Failed to close resource", e); } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java index 7c9d2ab3b97b..4bfa114cc279 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java @@ -113,13 +113,13 @@ void startFetches(final List shards, final String streamName, final int b final BigInteger lastSeq = checkpoint != null ? 
new BigInteger(checkpoint) : null; final StartingPosition startingPosition = buildStartingPosition(lastSeq, initialStreamPosition); logger.info("Creating EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), lastSeq); - final ShardConsumer sc = new ShardConsumer(shardId, EfoKinesisClient.this::enqueueResult, pausedConsumers, logger); - final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, sc); + final ShardConsumer shardConsumer = new ShardConsumer(shardId, result -> enqueueIfActiveConsumer(shardId, result), pausedConsumers, logger); + final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, shardConsumer); if (prior == null) { try { - sc.subscribe(kinesisAsyncClient, consumerArn, startingPosition); + shardConsumer.subscribe(kinesisAsyncClient, consumerArn, startingPosition); } catch (final Exception e) { - shardConsumers.remove(shardId, sc); + shardConsumers.remove(shardId, shardConsumer); throw e; } } @@ -172,14 +172,27 @@ private void resumePausedConsumers() { } } + private void enqueueIfActiveConsumer(final String shardId, final ShardFetchResult result) { + synchronized (getShardLock(shardId)) { + if (shardConsumers.containsKey(shardId)) { + enqueueResult(result); + } + } + } + + private ShardConsumer drainAndRemoveConsumer(final String shardId) { + synchronized (getShardLock(shardId)) { + drainShardQueue(shardId); + return shardConsumers.remove(shardId); + } + } + @Override void rollbackResults(final List results) { for (final ShardFetchResult result : results) { - drainShardQueue(result.shardId()); - - final ShardConsumer sc = shardConsumers.remove(result.shardId()); - if (sc != null) { - sc.cancel(); + final ShardConsumer shardConsumer = drainAndRemoveConsumer(result.shardId()); + if (shardConsumer != null) { + shardConsumer.cancel(); } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java index 453214d00f83..bdcdafcc595a 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java @@ -49,6 +49,7 @@ abstract class KinesisConsumerClient { protected final KinesisClient kinesisClient; protected final ComponentLog logger; private final Map> shardQueues = new ConcurrentHashMap<>(); + private final Map shardLocks = new ConcurrentHashMap<>(); private final Semaphore resultNotification = new Semaphore(0); protected final Set shardsInFlight = ConcurrentHashMap.newKeySet(); @@ -75,8 +76,13 @@ abstract void startFetches(List shards, String streamName, int batchSize, abstract void logDiagnostics(int ownedCount, int cachedShardCount); + Object getShardLock(final String shardId) { + return shardLocks.computeIfAbsent(shardId, k -> new Object()); + } + void close() { shardQueues.clear(); + shardLocks.clear(); resultNotification.drainPermits(); shardsInFlight.clear(); } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java index 592f50db9558..aed1bb553347 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java @@ -127,7 +127,7 @@ void rollbackResults(final List results) { for (final ShardFetchResult result : results) { final PollingShardState state = 
pollingShardStates.get(result.shardId()); if (state != null) { - state.requestReset(); + resetAndDrainShard(result.shardId(), state); } } } @@ -267,8 +267,7 @@ private void runFetchLoop(final PollingShardState state, final String shardId, final List records = response.records(); if (!records.isEmpty()) { final long millisBehind = response.millisBehindLatest() != null ? response.millisBehindLatest() : -1; - enqueueResult(createFetchResult(shardId, records, millisBehind)); - queuePermitConsumed = true; + queuePermitConsumed = enqueueIfActive(shardId, state, createFetchResult(shardId, records, millisBehind)); } state.setIterator(response.nextShardIterator()); @@ -328,6 +327,26 @@ private GetRecordsResponse fetchRecords(final String shardId, final PollingShard } } + private boolean enqueueIfActive(final String shardId, final PollingShardState state, final ShardFetchResult result) { + synchronized (getShardLock(shardId)) { + if (state.isResetRequested()) { + return false; + } + enqueueResult(result); + return true; + } + } + + private void resetAndDrainShard(final String shardId, final PollingShardState state) { + synchronized (getShardLock(shardId)) { + state.requestReset(); + final int drained = drainShardQueue(shardId); + if (drained > 0) { + queuePermits.release(drained); + } + } + } + private static void sleepNanos(final long nanos) { try { TimeUnit.NANOSECONDS.sleep(nanos); @@ -338,57 +357,56 @@ private static void sleepNanos(final long nanos) { private String getShardIterator(final PollingShardState state, final String streamName, final String shardId, final String initialStreamPosition, final KinesisShardManager shardManager) { - final String lastSequenceNumber; try { - lastSequenceNumber = shardManager.readCheckpoint(shardId); - } catch (final Exception e) { - if (!state.isStopped()) { - logger.warn("Failed to read checkpoint for shard {}; will retry", shardId, e); - } + fetchPermits.acquire(); + } catch (final InterruptedException e) { + 
Thread.currentThread().interrupt(); return null; } - final ShardIteratorType iteratorType; - final String startingSequenceNumber; - final Instant timestamp; - if (lastSequenceNumber != null) { - iteratorType = ShardIteratorType.AFTER_SEQUENCE_NUMBER; - startingSequenceNumber = lastSequenceNumber; - timestamp = null; - } else { - iteratorType = ShardIteratorType.fromValue(initialStreamPosition); - startingSequenceNumber = null; - timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? timestampForInitialPosition : null; - } + try { + final String lastSequenceNumber; + try { + lastSequenceNumber = shardManager.readCheckpoint(shardId); + } catch (final Exception e) { + if (!state.isStopped()) { + logger.warn("Failed to read checkpoint for shard {}; will retry", shardId, e); + } + return null; + } - logger.debug("Getting shard iterator for shard {} with type={}, startingSeq={}, timestamp={}", - shardId, iteratorType, startingSequenceNumber, timestamp); + final ShardIteratorType iteratorType; + final String startingSequenceNumber; + final Instant timestamp; + if (lastSequenceNumber != null) { + iteratorType = ShardIteratorType.AFTER_SEQUENCE_NUMBER; + startingSequenceNumber = lastSequenceNumber; + timestamp = null; + } else { + iteratorType = ShardIteratorType.fromValue(initialStreamPosition); + startingSequenceNumber = null; + timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? 
timestampForInitialPosition : null; + } - final GetShardIteratorRequest.Builder iteratorRequestBuilder = GetShardIteratorRequest.builder() - .streamName(streamName) - .shardId(shardId) - .shardIteratorType(iteratorType); + logger.debug("Getting shard iterator for shard {} with type={}, startingSeq={}, timestamp={}", + shardId, iteratorType, startingSequenceNumber, timestamp); - if (startingSequenceNumber != null) { - iteratorRequestBuilder.startingSequenceNumber(startingSequenceNumber); - } - if (timestamp != null) { - iteratorRequestBuilder.timestamp(timestamp); - } + final GetShardIteratorRequest.Builder iteratorRequestBuilder = GetShardIteratorRequest.builder() + .streamName(streamName) + .shardId(shardId) + .shardIteratorType(iteratorType); - try { - fetchPermits.acquire(); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - return null; - } + if (startingSequenceNumber != null) { + iteratorRequestBuilder.startingSequenceNumber(startingSequenceNumber); + } + if (timestamp != null) { + iteratorRequestBuilder.timestamp(timestamp); + } - try { return kinesisClient.getShardIterator(iteratorRequestBuilder.build()).shardIterator(); } catch (final Exception e) { if (!state.isStopped()) { - logger.error("Failed to get shard iterator for shard {} (type={}, seq={})", - shardId, iteratorType, startingSequenceNumber, e); + logger.error("Failed to get shard iterator for shard {} in stream {}", shardId, streamName, e); } return null; } finally { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java index 129a8e3eca8b..5133adc26d48 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java @@ -38,12 +38,14 @@ import java.time.Instant; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -192,15 +194,15 @@ void testExpiredIteratorRecoveryDoesNotDeliverSameShardOutOfOrder() throws Excep when(mockShardManager.readCheckpoint(anyString())).thenReturn("100"); when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { - final int call = getShardIteratorCallCount.incrementAndGet(); - return GetShardIteratorResponse.builder().shardIterator("iter-" + call).build(); + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); }); when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { getRecordsCallCount.incrementAndGet(); - final GetRecordsRequest req = invocation.getArgument(0); + final GetRecordsRequest request = invocation.getArgument(0); - return switch (req.shardIterator()) { + return switch (request.shardIterator()) { case "iter-1" -> GetRecordsResponse.builder() .records(record("200", "A")) .nextShardIterator("iter-1a") @@ -219,7 +221,7 @@ void testExpiredIteratorRecoveryDoesNotDeliverSameShardOutOfOrder() throws Excep .build(); default -> GetRecordsResponse.builder() .records(List.of()) - .nextShardIterator(req.shardIterator()) + 
.nextShardIterator(request.shardIterator()) .millisBehindLatest(0L) .build(); }; @@ -338,17 +340,16 @@ void testHasPendingFetchesFalseWhenAllShardsExhausted() throws Exception { } /** - * Reproduces out-of-order delivery caused by stale results remaining in the per-shard - * queue after a rollback. The scenario for a single shard: + * Reproduces out-of-order delivery caused by stale results remaining in the per-shard queue after a rollback. + * The scenario for a single shard: * *
    - *
  1. Fetch loop enqueues R1 (seq 100-200) and R2 (seq 300-400).
  2. - *
  3. Consumer polls R1 only; R2 remains in the queue.
  4. - *
  5. Consumer calls rollbackResults on R1, which sets the reset flag. - * The fetch loop detects the flag, drains R2 from the queue, - * resets the shard iterator, and re-fetches.
  6. - *
  7. After the reset, the first result polled must come from the re-fetched - * sequence (seq 500), not the stale R2 (seq 300).
  8. + *
  9. Fetch loop enqueues result 1 (sequence 100-200) and result 2 (sequence 300-400).
  10. + *
  11. Consumer polls result 1 only; result 2 remains in the queue.
  12. + *
  13. Consumer calls rollbackResults on result 1, which drains the queue synchronously and sets the reset flag. + * The fetch loop detects the flag, drains any stragglers, resets the shard iterator, and re-fetches.
  14. + *
  15. After the reset, the first result polled must come from the re-fetched sequence (sequence 500), + * not the stale result 2 (sequence 300).
  16. *
*/ @Test @@ -358,28 +359,28 @@ void testRollbackDrainsStaleResultsFromQueue() throws Exception { when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { - final int call = getShardIteratorCallCount.incrementAndGet(); - return GetShardIteratorResponse.builder().shardIterator("iter-" + call).build(); + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); }); when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { getRecordsCallCount.incrementAndGet(); - final GetRecordsRequest req = invocation.getArgument(0); + final GetRecordsRequest request = invocation.getArgument(0); - if (req.shardIterator().equals("iter-1")) { + if (request.shardIterator().equals("iter-1")) { return GetRecordsResponse.builder() .records(record("100", "A"), record("200", "B")) .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); } - if (req.shardIterator().equals("iter-1a")) { + if (request.shardIterator().equals("iter-1a")) { return GetRecordsResponse.builder() .records(record("300", "C"), record("400", "D")) .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); } - if (req.shardIterator().startsWith("iter-1")) { + if (request.shardIterator().startsWith("iter-1")) { return GetRecordsResponse.builder() .records(List.of()) - .nextShardIterator(req.shardIterator()).millisBehindLatest(0L).build(); + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); } return GetRecordsResponse.builder() .records(record("500", "E"), record("600", "F")) @@ -388,17 +389,17 @@ void testRollbackDrainsStaleResultsFromQueue() throws Exception { consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); - final ShardFetchResult r1 = consumer.pollAnyResult(5, TimeUnit.SECONDS); - 
assertNotNull(r1, "R1 must be available"); - assertEquals(new BigInteger("100"), r1.firstSequenceNumber()); + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); - final long r2Deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); - while (System.nanoTime() < r2Deadline && getRecordsCallCount.get() < 2) { + final long enqueueDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < enqueueDeadline && getRecordsCallCount.get() < 2) { Thread.sleep(20); } Thread.sleep(50); - consumer.rollbackResults(List.of(r1)); + consumer.rollbackResults(List.of(firstResult)); final long resetDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); while (getShardIteratorCallCount.get() < 2 && System.nanoTime() < resetDeadline) { @@ -407,11 +408,155 @@ void testRollbackDrainsStaleResultsFromQueue() throws Exception { Thread.sleep(100); final ShardFetchResult firstAfterRollback = consumer.pollShardResult("shard-1"); - assertNotNull(firstAfterRollback, "Queue must contain results after rollback + re-fetch"); + assertNotNull(firstAfterRollback, "Queue must contain results after rollback and re-fetch"); assertEquals(new BigInteger("500"), firstAfterRollback.firstSequenceNumber(), "First result after rollback must be re-fetched data, not a stale pre-rollback result"); } + /** + * Verifies that rollbackResults drains the queue synchronously so that a concurrent consumer cannot poll stale + * pre-rollback results. The fetch loop is blocked inside a GetRecords call while rollback happens, proving the + * drain must occur in rollbackResults itself rather than being deferred to the fetch loop thread. 
+ */ + @Test + void testRollbackDrainsSynchronouslyPreventingConcurrentStaleRead() throws Exception { + final CountDownLatch fetchLoopBlocked = new CountDownLatch(1); + final CountDownLatch unblockFetchLoop = new CountDownLatch(1); + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))) + .thenReturn(GetShardIteratorResponse.builder().shardIterator("iter-1").build()); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + final int callNumber = getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + if (callNumber == 1) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (callNumber == 2) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (callNumber == 3) { + fetchLoopBlocked.countDown(); + unblockFetchLoop.await(10, TimeUnit.SECONDS); + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator("iter-1c").millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); + + fetchLoopBlocked.await(5, TimeUnit.SECONDS); + + consumer.rollbackResults(List.of(firstResult)); + + final ShardFetchResult staleResult = 
consumer.pollShardResult("shard-1"); + + unblockFetchLoop.countDown(); + + assertNull(staleResult, "After rollback, stale results must not be pollable — rollbackResults must drain synchronously"); + } + + /** + * Reproduces the race condition where the fetch loop returns records from GetRecords AFTER rollbackResults has + * already drained the queue. Without the per-shard lock, the fetch loop would enqueue the stale result after the + * drain, making it visible to the next consumer poll. With the lock, the enqueue is rejected because the fetch + * loop sees isResetRequested inside the synchronized block. + * + *

+     * Timeline:
+     * <ol>
+     *     <li>Fetch loop enqueues result 1 (sequence 100-200) and result 2 (sequence 300-400).</li>
+     *     <li>Consumer polls result 1.</li>
+     *     <li>GetRecords call 3 blocks, holding a response with records (sequence 500-600).</li>
+     *     <li>Consumer rolls back result 1 — this drains result 2 and sets the reset flag.</li>
+     *     <li>GetRecords call 3 unblocks — fetch loop receives the stale response.</li>
+     *     <li>Fetch loop calls enqueueIfActive under the shard lock, sees isResetRequested, and discards the result.</li>
+     *     <li>Fetch loop processes the reset and re-fetches from the checkpoint (sequence 800).</li>
+     *     <li>The first result available after rollback must be the re-fetched data (sequence 800), not the stale data (sequence 500).</li>
+     * </ol>
+ */ + @Test + void testLockPreventsStaleEnqueueDuringConcurrentRollback() throws Exception { + final CountDownLatch fetchLoopBlockedInsideGetRecords = new CountDownLatch(1); + final CountDownLatch unblockGetRecords = new CountDownLatch(1); + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + final int callNumber = getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + if (callNumber == 1) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (callNumber == 2) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (callNumber == 3) { + fetchLoopBlockedInsideGetRecords.countDown(); + unblockGetRecords.await(10, TimeUnit.SECONDS); + return GetRecordsResponse.builder() + .records(record("500", "E"), record("600", "F")) + .nextShardIterator("iter-1c").millisBehindLatest(0L).build(); + } + if (request.shardIterator().startsWith("iter-2")) { + return GetRecordsResponse.builder() + .records(record("800", "G"), record("900", "H")) + .nextShardIterator("iter-2a").millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), 
"test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); + + fetchLoopBlockedInsideGetRecords.await(5, TimeUnit.SECONDS); + + consumer.rollbackResults(List.of(firstResult)); + + unblockGetRecords.countDown(); + + final long resetDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (getShardIteratorCallCount.get() < 2 && System.nanoTime() < resetDeadline) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRollback = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRollback, "Queue must contain results after rollback and re-fetch"); + assertEquals(new BigInteger("800"), firstAfterRollback.firstSequenceNumber(), + "First result after rollback must be re-fetched data (800), not the stale data (500) that was returned from GetRecords during the rollback"); + } + private void drainAllResults() throws InterruptedException { ShardFetchResult discarded; do { From 07e3bd614021db51e8f6e6166f454f03afef1913 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Wed, 11 Mar 2026 14:20:33 -0400 Subject: [PATCH 5/7] NIFI-15669: Addressed review feedback --- .../nifi-aws-kinesis-nar/pom.xml | 2 +- .../nifi-aws-bundle/nifi-aws-kinesis/pom.xml | 8 +- .../aws/kinesis/CheckpointTableUtils.java | 20 ++--- .../aws/kinesis/ConsumeKinesis.java | 58 +++++++------- ...sClient.java => EnhancedFanOutClient.java} | 76 +++++++++---------- .../aws/kinesis/KinesisConsumerClient.java | 14 +++- .../aws/kinesis/KinesisRecordMetadata.java | 2 +- .../aws/kinesis/KinesisShardManager.java | 8 +- .../aws/kinesis/PollingKinesisClient.java | 14 +--- ....java => ProducerLibraryDeaggregator.java} | 63 ++++++++------- .../aws/kinesis/ShardFetchResult.java | 2 +- ...eaggregatedRecord.java => UserRecord.java} | 2 +- 
.../aws/kinesis/CheckpointTableUtilsTest.java | 40 ++++++---- .../aws/kinesis/ConsumeKinesisIT.java | 8 +- .../aws/kinesis/ConsumeKinesisTest.java | 34 ++++----- .../kinesis/KinesisConsumerClientTest.java | 36 ++++----- ...a => ProducerLibraryDeaggregatorTest.java} | 58 +++++++------- 17 files changed, 226 insertions(+), 219 deletions(-) rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/{EfoKinesisClient.java => EnhancedFanOutClient.java} (88%) rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/{KplDeaggregator.java => ProducerLibraryDeaggregator.java} (76%) rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/{DeaggregatedRecord.java => UserRecord.java} (98%) rename nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/{KplDeaggregatorTest.java => ProducerLibraryDeaggregatorTest.java} (84%) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml index 838404b604d5..2ab90ef73abf 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml @@ -98,7 +98,7 @@
software.amazon.awssdk - apache-client + apache5-client ${software.amazon.awssdk.version} provided diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml index 3f9caf4a3cea..3b03857ff839 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml @@ -53,13 +53,7 @@ software.amazon.awssdk - apache-client - - - commons-logging - commons-logging - - + apache5-client software.amazon.awssdk diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java index de038de782f3..633d9deb6dbc 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java @@ -46,6 +46,8 @@ */ final class CheckpointTableUtils { + static final String ATTR_STREAM_NAME = "streamName"; + static final String ATTR_SHARD_ID = "shardId"; static final String NODE_HEARTBEAT_PREFIX = "__node__#"; static final String MIGRATION_MARKER_SHARD_ID = "__migration__"; @@ -66,8 +68,8 @@ static TableSchema getTableSchema(final DynamoDbClient client, final String tabl final DescribeTableResponse describe = client.describeTable(DescribeTableRequest.builder().tableName(tableName).build()); final List keySchema = describe.table().keySchema(); if (keySchema.size() == 2 - && hasKey(keySchema, "streamName", KeyType.HASH) - && hasKey(keySchema, "shardId", KeyType.RANGE)) { + && hasKey(keySchema, ATTR_STREAM_NAME, KeyType.HASH) + && hasKey(keySchema, ATTR_SHARD_ID, KeyType.RANGE)) { return TableSchema.NEW; } @@ 
-96,11 +98,11 @@ static void createNewSchemaTable(final DynamoDbClient client, final ComponentLog final CreateTableRequest request = CreateTableRequest.builder() .tableName(tableName) .keySchema( - KeySchemaElement.builder().attributeName("streamName").keyType(KeyType.HASH).build(), - KeySchemaElement.builder().attributeName("shardId").keyType(KeyType.RANGE).build()) + KeySchemaElement.builder().attributeName(ATTR_STREAM_NAME).keyType(KeyType.HASH).build(), + KeySchemaElement.builder().attributeName(ATTR_SHARD_ID).keyType(KeyType.RANGE).build()) .attributeDefinitions( - AttributeDefinition.builder().attributeName("streamName").attributeType(ScalarAttributeType.S).build(), - AttributeDefinition.builder().attributeName("shardId").attributeType(ScalarAttributeType.S).build()) + AttributeDefinition.builder().attributeName(ATTR_STREAM_NAME).attributeType(ScalarAttributeType.S).build(), + AttributeDefinition.builder().attributeName(ATTR_SHARD_ID).attributeType(ScalarAttributeType.S).build()) .billingMode(BillingMode.PAY_PER_REQUEST) .build(); @@ -123,7 +125,7 @@ static void waitForTableActive(final DynamoDbClient client, final ComponentLog l Thread.sleep(TABLE_POLL_MILLIS); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); - throw new ProcessException("Interrupted while waiting for DynamoDB table to become ACTIVE", e); + throw new ProcessException("Interrupted while waiting for DynamoDB table [%s] to become ACTIVE".formatted(tableName), e); } } @@ -153,7 +155,7 @@ static void waitForTableDeleted(final DynamoDbClient client, final ComponentLog Thread.sleep(TABLE_POLL_MILLIS); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); - throw new ProcessException("Interrupted while waiting for DynamoDB table deletion", e); + throw new ProcessException("Interrupted while waiting for DynamoDB table [%s] deletion".formatted(tableName), e); } } @@ -173,7 +175,7 @@ static void copyCheckpointItems(final DynamoDbClient client, final 
ComponentLog final ScanResponse scanResponse = client.scan(scanRequest); for (final Map item : scanResponse.items()) { - final AttributeValue shardIdAttr = item.get("shardId"); + final AttributeValue shardIdAttr = item.get(ATTR_SHARD_ID); if (shardIdAttr != null) { final String shardId = shardIdAttr.s(); if (shardId.startsWith(NODE_HEARTBEAT_PREFIX) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index 78d67495c237..aab3b702453d 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -63,7 +63,7 @@ import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.http.Protocol; import software.amazon.awssdk.http.SdkHttpClient; -import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.http.apache5.Apache5HttpClient; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; @@ -433,11 +433,7 @@ public void onScheduled(final ProcessContext context) { final Instant timestampForPosition = resolveTimestampPosition(context); if (timestampForPosition != null) { - if (consumerClient instanceof PollingKinesisClient polling) { - polling.setTimestampForInitialPosition(timestampForPosition); - } else if (consumerClient instanceof EfoKinesisClient efo) { - efo.setTimestampForInitialPosition(timestampForPosition); - } + consumerClient.setTimestampForInitialPosition(timestampForPosition); } if (efoMode) { @@ -484,18 +480,18 @@ private static Instant 
resolveTimestampPosition(final ProcessContext context) { } /** - * Builds an {@link ApacheHttpClient} with the given connection pool size and optional proxy + * Builds an {@link Apache5HttpClient} with the given connection pool size and optional proxy * configuration. Each AWS service client (Kinesis, DynamoDB) should receive its own HTTP client * so their connection pools are isolated and cannot starve each other under high shard counts. */ private static SdkHttpClient buildApacheHttpClient(final ProxyConfiguration proxyConfig, final int maxConnections) { - final ApacheHttpClient.Builder builder = ApacheHttpClient.builder() + final Apache5HttpClient.Builder builder = Apache5HttpClient.builder() .maxConnections(maxConnections); if (Proxy.Type.HTTP.equals(proxyConfig.getProxyType())) { final URI proxyEndpoint = URI.create(String.format("http://%s:%s", proxyConfig.getProxyServerHost(), proxyConfig.getProxyServerPort())); - final software.amazon.awssdk.http.apache.ProxyConfiguration.Builder proxyBuilder = - software.amazon.awssdk.http.apache.ProxyConfiguration.builder().endpoint(proxyEndpoint); + final software.amazon.awssdk.http.apache5.ProxyConfiguration.Builder proxyBuilder = + software.amazon.awssdk.http.apache5.ProxyConfiguration.builder().endpoint(proxyEndpoint); if (proxyConfig.hasCredential()) { proxyBuilder.username(proxyConfig.getProxyUserName()); @@ -516,7 +512,7 @@ public void onStopped() { shardManager = null; } - if (consumerClient instanceof EfoKinesisClient efo) { + if (consumerClient instanceof EnhancedFanOutClient efo) { efoConsumerArn = efo.getConsumerArn(); } if (consumerClient != null) { @@ -776,7 +772,7 @@ private WriteResult writeResults(final ProcessSession session, final ProcessCont for (final List shardResults : resultsByShard.values()) { for (final ShardFetchResult result : shardResults) { batch.updateMillisBehind(result.millisBehindLatest()); - for (final DeaggregatedRecord record : result.records()) { + for (final UserRecord record : 
result.records()) { batch.addBytes(record.data().length); } } @@ -792,7 +788,7 @@ private WriteResult writeResults(final ProcessSession session, final ProcessCont for (final ShardFetchResult result : entry.getValue()) { batch.updateMillisBehind(result.millisBehindLatest()); batch.updateSequenceRange(result); - for (final DeaggregatedRecord record : result.records()) { + for (final UserRecord record : result.records()) { batch.addBytes(record.data().length); batch.updateRecordRange(record); } @@ -828,7 +824,7 @@ private WriteResult writeResults(final ProcessSession session, final ProcessCont private void writeFlowFilePerRecord(final ProcessSession session, final List results, final String streamName, final BatchAccumulator batch, final List output) { for (final ShardFetchResult result : results) { - for (final DeaggregatedRecord record : result.records()) { + for (final UserRecord record : result.records()) { final byte[] recordBytes = record.data(); FlowFile flowFile = session.create(); try { @@ -872,7 +868,7 @@ private void writeDelimited(final ProcessSession session, final List output, final List parseFailureOutput) { - final List allRecords = new ArrayList<>(); + final List allRecords = new ArrayList<>(); for (final ShardFetchResult result : results) { allRecords.addAll(result.records()); } @@ -966,7 +962,7 @@ private void writeRecordOriented(final ProcessSession session, final ProcessCont private void writeRecordBatch(final ProcessSession session, final RecordReaderFactory readerFactory, final RecordSetWriterFactory writerFactory, final OutputStrategy outputStrategy, - final List records, + final List records, final String streamName, final BatchAccumulator batch, final List output) { FlowFile flowFile = session.create(); @@ -996,7 +992,7 @@ public void process(final OutputStream out) throws IOException { int recordIndex = 0; org.apache.nifi.serialization.record.Record nifiRecord; while ((nifiRecord = reader.nextRecord()) != null) { - final DeaggregatedRecord 
record = records.get(recordIndex++); + final UserRecord record = records.get(recordIndex++); nifiRecord = decorateRecord(nifiRecord, record, record.shardId(), streamName, outputStrategy, writeSchema); writer.write(nifiRecord); @@ -1046,7 +1042,7 @@ public void process(final OutputStream out) throws IOException { * @return a {@link RecordBatchResult} containing the successfully written FlowFiles and any parse-failure FlowFiles */ private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session, final ProcessContext context, - final List records, + final List records, final String streamName, final BatchAccumulator batch) { final RecordReaderFactory readerFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); @@ -1063,7 +1059,7 @@ private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session RecordSchema currentWriteSchema = null; try { - for (final DeaggregatedRecord record : records) { + for (final UserRecord record : records) { if (record.data().length == 0) { unparseable.add(new ParseFailureRecord(record, "Record content is empty")); @@ -1192,7 +1188,7 @@ yield new SimpleRecordSchema(List.of( */ private static org.apache.nifi.serialization.record.Record decorateRecord( final org.apache.nifi.serialization.record.Record nifiRecord, - final DeaggregatedRecord kinesisRecord, final String shardId, + final UserRecord kinesisRecord, final String shardId, final String streamName, final OutputStrategy outputStrategy, final RecordSchema writeSchema) { return switch (outputStrategy) { @@ -1217,7 +1213,7 @@ private void writeParseFailures(final ProcessSession session, final List parseFailureOutput) { for (final ParseFailureRecord parseFailureRecord : unparseable) { - final DeaggregatedRecord record = parseFailureRecord.record(); + final UserRecord record = parseFailureRecord.record(); FlowFile flowFile = session.create(); try { final byte[] rawBytes = record.data(); @@ -1249,7 +1245,7 @@ private void 
writeParseFailures(final ProcessSession session, final List output, List parseFailures) { } - private record ParseFailureRecord(DeaggregatedRecord record, String reason) { + private record ParseFailureRecord(UserRecord record, String reason) { } private static final class KinesisRecordInputStream extends InputStream { @@ -1329,9 +1325,9 @@ private static final class KinesisRecordInputStream extends InputStream { private int markChunkIndex = -1; private int markPositionInChunk; - KinesisRecordInputStream(final List records) { + KinesisRecordInputStream(final List records) { this.chunks = new ArrayList<>(records.size()); - for (final DeaggregatedRecord record : records) { + for (final UserRecord record : records) { final byte[] data = record.data(); if (data.length > 0) { chunks.add(data); @@ -1441,11 +1437,11 @@ long getMaxMillisBehind() { } String getMinSequenceNumber() { - return minSequenceNumber != null ? minSequenceNumber.toString() : null; + return minSequenceNumber == null ? null : minSequenceNumber.toString(); } String getMaxSequenceNumber() { - return maxSequenceNumber != null ? maxSequenceNumber.toString() : null; + return maxSequenceNumber == null ? 
null : maxSequenceNumber.toString(); } long getMinSubSequenceNumber() { @@ -1499,7 +1495,7 @@ void updateSequenceRange(final ShardFetchResult result) { } } - void updateRecordRange(final DeaggregatedRecord record) { + void updateRecordRange(final UserRecord record) { updateSequenceFromRecord(record); final long subSeq = record.subSequenceNumber(); if (subSeq < minSubSequenceNumber) { @@ -1515,7 +1511,7 @@ void updateRecordRange(final DeaggregatedRecord record) { } } - void updateSequenceFromRecord(final DeaggregatedRecord record) { + void updateSequenceFromRecord(final UserRecord record) { final BigInteger seqNum = new BigInteger(record.sequenceNumber()); if (minSequenceNumber == null || seqNum.compareTo(minSequenceNumber) < 0) { minSequenceNumber = seqNum; diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java similarity index 88% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java index 4bfa114cc279..50b7397b33fe 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EfoKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java @@ -44,7 +44,6 @@ import java.io.IOException; import java.math.BigInteger; -import java.time.Instant; import java.util.List; import java.util.Map; import java.util.Queue; @@ -64,7 +63,7 @@ * per shard via HTTP/2. Uses Reactive Streams demand-driven backpressure to control the * rate of event delivery. 
*/ -final class EfoKinesisClient extends KinesisConsumerClient { +final class EnhancedFanOutClient extends KinesisConsumerClient { private static final long SUBSCRIBE_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(5); private static final long CONSUMER_REGISTRATION_POLL_MILLIS = 1_000; @@ -75,31 +74,19 @@ final class EfoKinesisClient extends KinesisConsumerClient { private final Queue pausedConsumers = new ConcurrentLinkedQueue<>(); private volatile KinesisAsyncClient kinesisAsyncClient; private volatile String consumerArn; - private volatile Instant timestampForInitialPosition; + private volatile String streamName; - EfoKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger) { + EnhancedFanOutClient(final KinesisClient kinesisClient, final ComponentLog logger) { super(kinesisClient, logger); } - void setTimestampForInitialPosition(final Instant timestamp) { - this.timestampForInitialPosition = timestamp; - } - @Override void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { this.kinesisAsyncClient = asyncClient; + this.streamName = streamName; registerEfoConsumer(streamName, consumerName); } - void initializeForTest(final KinesisAsyncClient asyncClient, final String theConsumerArn) { - this.kinesisAsyncClient = asyncClient; - this.consumerArn = theConsumerArn; - } - - ShardConsumer getShardConsumer(final String shardId) { - return shardConsumers.get(shardId); - } - @Override void startFetches(final List shards, final String streamName, final int batchSize, final String initialStreamPosition, final KinesisShardManager shardManager) { final long now = System.nanoTime(); @@ -110,9 +97,9 @@ void startFetches(final List shards, final String streamName, final int b if (existing == null) { final String checkpoint = shardManager.readCheckpoint(shardId); - final BigInteger lastSeq = checkpoint != null ? new BigInteger(checkpoint) : null; + final BigInteger lastSeq = checkpoint == null ? 
null : new BigInteger(checkpoint); final StartingPosition startingPosition = buildStartingPosition(lastSeq, initialStreamPosition); - logger.info("Creating EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), lastSeq); + logger.info("Creating Enhanced Fan-Out subscription for stream [{}] shard [{}] type [{}] seq [{}]", streamName, shardId, startingPosition.type(), lastSeq); final ShardConsumer shardConsumer = new ShardConsumer(shardId, result -> enqueueIfActiveConsumer(shardId, result), pausedConsumers, logger); final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, shardConsumer); if (prior == null) { @@ -134,7 +121,7 @@ void startFetches(final List shards, final String streamName, final int b final BigInteger lastQueued = existing.getLastQueuedSequenceNumber(); final BigInteger resumeSeq = maxSequenceNumber(lastQueued, checkpointSeq); final StartingPosition startingPosition = buildStartingPosition(resumeSeq, initialStreamPosition); - logger.debug("Renewing expired EFO subscription for shard {} with type={}, seq={}", shardId, startingPosition.type(), resumeSeq); + logger.debug("Renewing expired Enhanced Fan-Out subscription for stream [{}] shard [{}] type [{}] seq [{}]", streamName, shardId, startingPosition.type(), resumeSeq); existing.subscribe(kinesisAsyncClient, consumerArn, startingPosition); } } @@ -231,7 +218,7 @@ void logDiagnostics(final int ownedCount, final int cachedShardCount) { } final int queueDepth = totalQueuedResults(); - logger.debug("Kinesis EFO diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, activeSubscriptions={}, expiredSubscriptions={}, backedOff={}", + logger.debug("Kinesis Enhanced Fan-Out diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, activeSubscriptions={}, expiredSubscriptions={}, backedOff={}", cachedShardCount, ownedCount, queueDepth, MAX_QUEUED_RESULTS, shardConsumers.size(), activeSubscriptions, 
expiredSubscriptions, backedOff); } @@ -250,6 +237,15 @@ void close() { super.close(); } + void initializeForTest(final KinesisAsyncClient asyncClient, final String consumerArn) { + this.kinesisAsyncClient = asyncClient; + this.consumerArn = consumerArn; + } + + ShardConsumer getShardConsumer(final String shardId) { + return shardConsumers.get(shardId); + } + String getConsumerArn() { return consumerArn; } @@ -267,7 +263,7 @@ private void registerEfoConsumer(final String streamName, final String consumerN final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerStatus(); if (status == ConsumerStatus.ACTIVE) { consumerArn = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerARN(); - logger.info("EFO consumer [{}] already registered and ACTIVE", consumerName); + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] already registered and ACTIVE", consumerName, streamName); return; } } catch (final ResourceNotFoundException ignored) { @@ -280,29 +276,29 @@ private void registerEfoConsumer(final String streamName, final String consumerN .build(); final RegisterStreamConsumerResponse registerResponse = kinesisClient.registerStreamConsumer(registerRequest); consumerArn = registerResponse.consumer().consumerARN(); - logger.info("Registered EFO consumer [{}], waiting for ACTIVE status", consumerName); + logger.info("Registered Enhanced Fan-Out consumer [{}] for stream [{}], waiting for ACTIVE status", consumerName, streamName); } catch (final ResourceInUseException e) { final DescribeStreamConsumerRequest fallbackRequest = DescribeStreamConsumerRequest.builder() .streamARN(arn) .consumerName(consumerName) .build(); consumerArn = kinesisClient.describeStreamConsumer(fallbackRequest).consumerDescription().consumerARN(); - logger.info("EFO consumer [{}] already being registered", consumerName); + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] already being 
registered", consumerName, streamName); } waitForConsumerActive(arn, consumerName); } - private void waitForConsumerActive(final String theStreamArn, final String consumerName) { + private void waitForConsumerActive(final String streamArn, final String consumerName) { final DescribeStreamConsumerRequest describeConsumerRequest = DescribeStreamConsumerRequest.builder() - .streamARN(theStreamArn) + .streamARN(streamArn) .consumerName(consumerName) .build(); for (int i = 0; i < CONSUMER_REGISTRATION_MAX_ATTEMPTS; i++) { final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerRequest).consumerDescription().consumerStatus(); if (status == ConsumerStatus.ACTIVE) { - logger.info("EFO consumer [{}] is now ACTIVE", consumerName); + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] is now ACTIVE", consumerName, streamName); return; } @@ -310,11 +306,11 @@ private void waitForConsumerActive(final String theStreamArn, final String consu Thread.sleep(CONSUMER_REGISTRATION_POLL_MILLIS); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); - throw new ProcessException("Interrupted while waiting for EFO consumer registration", e); + throw new ProcessException("Interrupted while waiting for Enhanced Fan-Out consumer [%s] registration for stream [%s]".formatted(consumerName, streamName), e); } } - throw new ProcessException("EFO consumer [%s] did not become ACTIVE within %d seconds".formatted(consumerName, CONSUMER_REGISTRATION_MAX_ATTEMPTS)); + throw new ProcessException("Enhanced Fan-Out consumer [%s] for stream [%s] did not become ACTIVE within %d seconds".formatted(consumerName, streamName, CONSUMER_REGISTRATION_MAX_ATTEMPTS)); } private static BigInteger maxSequenceNumber(final BigInteger a, final BigInteger b) { @@ -336,8 +332,8 @@ private StartingPosition buildStartingPosition(final BigInteger sequenceNumber, } final ShardIteratorType iteratorType = ShardIteratorType.fromValue(initialStreamPosition); final 
StartingPosition.Builder builder = StartingPosition.builder().type(iteratorType); - if (iteratorType == ShardIteratorType.AT_TIMESTAMP && timestampForInitialPosition != null) { - builder.timestamp(timestampForInitialPosition); + if (iteratorType == ShardIteratorType.AT_TIMESTAMP && getTimestampForInitialPosition() != null) { + builder.timestamp(getTimestampForInitialPosition()); } return builder.build(); } @@ -363,7 +359,7 @@ static final class ShardConsumer { this.consumerLogger = consumerLogger; } - void subscribe(final KinesisAsyncClient asyncClient, final String theConsumerArn, final StartingPosition startingPosition) { + void subscribe(final KinesisAsyncClient asyncClient, final String consumerArn, final StartingPosition startingPosition) { if (!subscribing.compareAndSet(false, true)) { return; } @@ -372,7 +368,7 @@ void subscribe(final KinesisAsyncClient asyncClient, final String theConsumerArn try { final SubscribeToShardRequest request = SubscribeToShardRequest.builder() - .consumerARN(theConsumerArn) + .consumerARN(consumerArn) .shardId(shardId) .startingPosition(startingPosition) .build(); @@ -469,13 +465,13 @@ void pause() { private void logSubscriptionError(final Throwable t) { if (isCancellation(t)) { - consumerLogger.debug("EFO subscription cancelled for shard {}", shardId); + consumerLogger.debug("Enhanced Fan-Out subscription cancelled for shard [{}]", shardId); } else if (isRetryableSubscriptionError(t)) { - consumerLogger.info("EFO subscription temporarily rejected for shard {}; will retry after backoff", shardId); + consumerLogger.info("Enhanced Fan-Out subscription temporarily rejected for shard [{}]; will retry after backoff", shardId); } else if (isRetryableStreamDisconnect(t)) { - consumerLogger.info("EFO subscription disconnected for shard {}; will retry", shardId); + consumerLogger.info("Enhanced Fan-Out subscription disconnected for shard [{}]; will retry", shardId); } else { - consumerLogger.error("EFO subscription error for shard {}", 
shardId, t); + consumerLogger.error("Enhanced Fan-Out subscription error for shard [{}]", shardId, t); } } @@ -540,9 +536,9 @@ private List deduplicateRecords(final List records) { final int kept = records.size() - firstNewIndex; deduplicatedEvents.incrementAndGet(); if (kept == 0) { - consumerLogger.debug("Skipped re-delivered EFO event for shard {} ({} records already seen)", shardId, records.size()); + consumerLogger.debug("Skipped re-delivered Enhanced Fan-Out event for shard [{}] ({} records already seen)", shardId, records.size()); } else { - consumerLogger.debug("Filtered {} duplicate record(s) from EFO event for shard {} (kept {})", firstNewIndex, shardId, kept); + consumerLogger.debug("Filtered {} duplicate record(s) from Enhanced Fan-Out event for shard [{}] (kept {})", firstNewIndex, shardId, kept); } return records.subList(firstNewIndex, records.size()); @@ -597,7 +593,7 @@ public void onError(final Throwable t) { @Override public void onComplete() { - consumerLogger.debug("EFO subscription completed normally for shard {}", shardId); + consumerLogger.debug("Enhanced Fan-Out subscription completed normally for shard [{}]", shardId); endSubscriptionIfCurrent(generation); } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java index bdcdafcc595a..ff1e8a0e4fa6 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java @@ -22,6 +22,7 @@ import software.amazon.awssdk.services.kinesis.model.Record; import software.amazon.awssdk.services.kinesis.model.Shard; +import java.time.Instant; import 
java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -54,12 +55,21 @@ abstract class KinesisConsumerClient { protected final Set shardsInFlight = ConcurrentHashMap.newKeySet(); private volatile long lastDiagnosticLogNanos; + private volatile Instant timestampForInitialPosition; KinesisConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger) { this.kinesisClient = kinesisClient; this.logger = logger; } + void setTimestampForInitialPosition(final Instant timestamp) { + this.timestampForInitialPosition = timestamp; + } + + Instant getTimestampForInitialPosition() { + return timestampForInitialPosition; + } + void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { } @@ -94,7 +104,7 @@ void enqueueResult(final ShardFetchResult result) { ShardFetchResult pollShardResult(final String shardId) { final Queue queue = shardQueues.get(shardId); - final ShardFetchResult result = queue != null ? queue.poll() : null; + final ShardFetchResult result = queue == null ? 
null : queue.poll(); if (result != null) { onResultPolled(); } @@ -177,7 +187,7 @@ void releaseShards(final Collection shardIds) { } protected static ShardFetchResult createFetchResult(final String shardId, final List records, final long millisBehindLatest) { - return new ShardFetchResult(shardId, KplDeaggregator.deaggregate(shardId, records), millisBehindLatest); + return new ShardFetchResult(shardId, ProducerLibraryDeaggregator.deaggregate(shardId, records), millisBehindLatest); } long drainDeduplicatedEventCount() { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java index 9d9243d6dcc0..bdbeb697ec5e 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java @@ -58,7 +58,7 @@ public final class KinesisRecordMetadata { public static final RecordField FIELD_METADATA = new RecordField(METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)); - public static Record composeMetadataObject(final DeaggregatedRecord record, final String streamName, final String shardId) { + public static Record composeMetadataObject(final UserRecord record, final String streamName, final String shardId) { final Map metadata = new HashMap<>(7, 1.0f); metadata.put(STREAM, streamName); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java index 752e2fec62d3..53fbfd41bd3c 100644 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java @@ -219,7 +219,7 @@ void refreshLeasesIfNecessary(final int clusterMemberCount) { if (item != null && item.containsKey("leaseOwner")) { final String owner = item.get("leaseOwner").s(); final AttributeValue expiryAttr = item.get("leaseExpiry"); - final long expiry = expiryAttr != null ? Long.parseLong(expiryAttr.n()) : 0; + final long expiry = expiryAttr == null ? 0 : Long.parseLong(expiryAttr.n()); if (expiry >= now) { ownerToShards.computeIfAbsent(owner, k -> new ArrayList<>()).add(shardId); } else { @@ -446,7 +446,7 @@ private Map> queryAllLeaseItems() { : queryBuilder.exclusiveStartKey(exclusiveStartKey).build(); final QueryResponse queryResponse = dynamoDbClient.query(queryRequest); for (final Map item : queryResponse.items()) { - final AttributeValue shardIdAttr = item.get("shardId"); + final AttributeValue shardIdAttr = item.get(CheckpointTableUtils.ATTR_SHARD_ID); if (shardIdAttr == null) { continue; } @@ -545,8 +545,8 @@ private BigInteger persistIfHigher(final String shardId, final BigInteger checkp private Map checkpointKey(final String shardId) { return Map.of( - "streamName", AttributeValue.builder().s(streamName).build(), - "shardId", AttributeValue.builder().s(shardId).build()); + CheckpointTableUtils.ATTR_STREAM_NAME, AttributeValue.builder().s(streamName).build(), + CheckpointTableUtils.ATTR_SHARD_ID, AttributeValue.builder().s(shardId).build()); } private void releaseLease(final String shardId) { diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java 
index aed1bb553347..e21dd3c7ccb2 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java @@ -65,8 +65,6 @@ final class PollingKinesisClient extends KinesisConsumerClient { private final Semaphore queuePermits = new Semaphore(MAX_QUEUED_RESULTS, true); private final long emptyShardBackoffNanos; private final long errorBackoffNanos; - private volatile Instant timestampForInitialPosition; - PollingKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger) { this(kinesisClient, logger, DEFAULT_EMPTY_SHARD_BACKOFF_NANOS, DEFAULT_ERROR_BACKOFF_NANOS); } @@ -78,10 +76,6 @@ final class PollingKinesisClient extends KinesisConsumerClient { this.errorBackoffNanos = errorBackoffNanos; } - void setTimestampForInitialPosition(final Instant timestamp) { - this.timestampForInitialPosition = timestamp; - } - @Override void startFetches(final List shards, final String streamName, final int batchSize, final String initialStreamPosition, final KinesisShardManager shardManager) { @@ -99,7 +93,7 @@ void startFetches(final List shards, final String streamName, final int b } } else if (!existing.isExhausted() && !existing.isStopped() && !existing.isLoopRunning() && existing.tryStartLoop()) { - logger.warn("Restarting dead fetch loop for shard {}", shardId); + logger.warn("Restarting dead fetch loop for stream [{}] shard [{}]", streamName, shardId); launchFetchLoop(existing, shardId, streamName, batchSize, initialStreamPosition, shardManager); } } @@ -205,7 +199,7 @@ private void launchFetchLoop(final PollingShardState state, final String shardId }); } catch (final RejectedExecutionException e) { state.markLoopStopped(); - logger.debug("Executor shut down; cannot start fetch loop for shard {}", shardId); + logger.debug("Executor shut down; 
cannot start fetch loop for stream [{}] shard [{}]", streamName, shardId); } } @@ -286,7 +280,7 @@ private void runFetchLoop(final PollingShardState state, final String shardId, } } catch (final Exception e) { if (!state.isStopped()) { - logger.error("Unexpected error in fetch loop for shard {}; will retry", shardId, e); + logger.warn("Unexpected error in fetch loop for shard [{}]; will retry", shardId, e); state.setIterator(null); sleepNanos(errorBackoffNanos); } @@ -385,7 +379,7 @@ private String getShardIterator(final PollingShardState state, final String stre } else { iteratorType = ShardIteratorType.fromValue(initialStreamPosition); startingSequenceNumber = null; - timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? timestampForInitialPosition : null; + timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? getTimestampForInitialPosition() : null; } logger.debug("Getting shard iterator for shard {} with type={}, startingSeq={}, timestamp={}", diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregator.java similarity index 76% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregator.java index c5ca7fe2a447..db4b580a4798 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregator.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregator.java @@ -32,7 +32,7 @@ * *

KPL aggregation packs multiple user records into a single Kinesis record using a protobuf * envelope with a 4-byte magic header and a 16-byte MD5 trailer. Non-aggregated records - * pass through unchanged as a single {@link DeaggregatedRecord} with {@code subSequenceNumber=0}. + * pass through unchanged as a single {@link UserRecord} with {@code subSequenceNumber=0}. * *

If a record has the magic header but fails MD5 verification or protobuf parsing, it falls * back to passthrough to avoid data loss. @@ -44,7 +44,7 @@ * * @see KPL Aggregation Format */ -final class KplDeaggregator { +final class ProducerLibraryDeaggregator { static final byte[] KPL_MAGIC = {(byte) 0xF3, (byte) 0x89, (byte) 0x9A, (byte) 0xC2}; private static final int MD5_DIGEST_LENGTH = 16; @@ -58,7 +58,7 @@ final class KplDeaggregator { private static final int RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX = 2; private static final int RECORD_FIELD_DATA = 3; - private KplDeaggregator() { + private ProducerLibraryDeaggregator() { } /** @@ -69,15 +69,15 @@ private KplDeaggregator() { * @param records raw Kinesis records from the API * @return list of deaggregated records preserving original order */ - static List deaggregate(final String shardId, final List records) { - final List result = new ArrayList<>(); + static List deaggregate(final String shardId, final List records) { + final List result = new ArrayList<>(); for (final Record record : records) { deaggregateRecord(shardId, record, result); } return result; } - private static void deaggregateRecord(final String shardId, final Record record, final List out) { + private static void deaggregateRecord(final String shardId, final Record record, final List out) { final byte[] data = record.data().asByteArrayUnsafe(); if (!isAggregated(data)) { @@ -111,23 +111,27 @@ static boolean isAggregated(final byte[] data) { } private static boolean verifyMd5(final byte[] data, final int protobufOffset, final int protobufLength) { + final MessageDigest md5 = getMd5Digest(); + md5.update(data, protobufOffset, protobufLength); + final byte[] computed = md5.digest(); + final int md5Offset = protobufOffset + protobufLength; + return Arrays.equals(computed, 0, MD5_DIGEST_LENGTH, data, md5Offset, md5Offset + MD5_DIGEST_LENGTH); + } + + private static MessageDigest getMd5Digest() { try { - final MessageDigest md5 = 
MessageDigest.getInstance("MD5"); - md5.update(data, protobufOffset, protobufLength); - final byte[] computed = md5.digest(); - final int md5Offset = protobufOffset + protobufLength; - return Arrays.equals(computed, 0, MD5_DIGEST_LENGTH, data, md5Offset, md5Offset + MD5_DIGEST_LENGTH); + return MessageDigest.getInstance("MD5"); } catch (final NoSuchAlgorithmException e) { - return false; + throw new IllegalStateException("MD5 algorithm not available", e); } } private static void parseAggregatedRecord(final String shardId, final Record kinesisRecord, - final byte[] data, final int protobufOffset, final int protobufLength, final List out) throws Exception { + final byte[] data, final int protobufOffset, final int protobufLength, final List out) throws Exception { final List partitionKeyTable = new ArrayList<>(); - final List subRecordDataList = new ArrayList<>(); - final List subRecordPkIndexList = new ArrayList<>(); + final List subRecordData = new ArrayList<>(); + final List subRecordPartitionKeyIndexes = new ArrayList<>(); final CodedInputStream input = CodedInputStream.newInstance(data, protobufOffset, protobufLength); while (!input.isAtEnd()) { @@ -143,20 +147,20 @@ private static void parseAggregatedRecord(final String shardId, final Record kin case FIELD_RECORDS: final int length = input.readRawVarint32(); final int oldLimit = input.pushLimit(length); - int pkIndex = 0; - byte[] subData = new byte[0]; + int partitionKeyIndex = 0; + byte[] subRecordPayload = new byte[0]; while (!input.isAtEnd()) { final int innerTag = input.readTag(); final int innerField = WireFormat.getTagFieldNumber(innerTag); switch (innerField) { case RECORD_FIELD_PARTITION_KEY_INDEX: - pkIndex = (int) input.readUInt64(); + partitionKeyIndex = (int) input.readUInt64(); break; case RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX: input.readUInt64(); break; case RECORD_FIELD_DATA: - subData = input.readByteArray(); + subRecordPayload = input.readByteArray(); break; default: input.skipField(innerTag); 
@@ -164,8 +168,8 @@ private static void parseAggregatedRecord(final String shardId, final Record kin } } input.popLimit(oldLimit); - subRecordDataList.add(subData); - subRecordPkIndexList.add(pkIndex); + subRecordData.add(subRecordPayload); + subRecordPartitionKeyIndexes.add(partitionKeyIndex); break; default: input.skipField(tag); @@ -177,17 +181,18 @@ private static void parseAggregatedRecord(final String shardId, final Record kin final Instant arrival = kinesisRecord.approximateArrivalTimestamp(); final String fallbackPartitionKey = kinesisRecord.partitionKey(); - for (int i = 0; i < subRecordDataList.size(); i++) { - final int pkIdx = subRecordPkIndexList.get(i); - final String partitionKey = pkIdx < partitionKeyTable.size() - ? partitionKeyTable.get(pkIdx) - : fallbackPartitionKey; - out.add(new DeaggregatedRecord(shardId, sequenceNumber, i, partitionKey, subRecordDataList.get(i), arrival)); + for (int i = 0; i < subRecordData.size(); i++) { + final int partitionKeyTableIndex = subRecordPartitionKeyIndexes.get(i); + final String partitionKey = partitionKeyTableIndex < partitionKeyTable.size() + ? 
partitionKeyTable.get(partitionKeyTableIndex) : fallbackPartitionKey; + + final UserRecord record = new UserRecord(shardId, sequenceNumber, i, partitionKey, subRecordData.get(i), arrival); + out.add(record); } } - private static DeaggregatedRecord passthrough(final String shardId, final Record record, final byte[] data) { - return new DeaggregatedRecord( + private static UserRecord passthrough(final String shardId, final Record record, final byte[] data) { + return new UserRecord( shardId, record.sequenceNumber(), 0, diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java index 507f79e0d508..9b88d71a9a4a 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java @@ -19,7 +19,7 @@ import java.math.BigInteger; import java.util.List; -record ShardFetchResult(String shardId, List records, long millisBehindLatest) { +record ShardFetchResult(String shardId, List records, long millisBehindLatest) { BigInteger firstSequenceNumber() { return new BigInteger(records.getFirst().sequenceNumber()); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java similarity index 98% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java rename to 
nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java index 95290725881e..e44de3c506f1 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/DeaggregatedRecord.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java @@ -31,7 +31,7 @@ * @param data the user payload bytes * @param approximateArrivalTimestamp approximate time the enclosing record arrived at Kinesis */ -record DeaggregatedRecord( +record UserRecord( String shardId, String sequenceNumber, long subSequenceNumber, diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java index 5a3bd1279f00..22fc1505299b 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java @@ -39,19 +39,29 @@ class CheckpointTableUtilsTest { + private static final String STREAM_NAME = "my-stream"; + private static final String SHARD_ID_1 = "shardId-0001"; + private static final String SHARD_ID_2 = "shardId-0002"; + private static final String SOURCE_TABLE = "source-table"; + private static final String DEST_TABLE = "dest-table"; + + private static AttributeValue str(final String value) { + return AttributeValue.builder().s(value).build(); + } + @Test void testCopyCheckpointItemsCopiesShardItems() { final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); final ComponentLog logger = mock(ComponentLog.class); final Map item = Map.of( - "streamName", 
AttributeValue.builder().s("my-stream").build(), - "shardId", AttributeValue.builder().s("shardId-0001").build(), - "sequenceNumber", AttributeValue.builder().s("12345").build()); + "streamName", str(STREAM_NAME), + "shardId", str(SHARD_ID_1), + "sequenceNumber", str("12345")); when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(item).build()); when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); - CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); verify(dynamoDb, times(1)).putItem(putCaptor.capture()); @@ -64,21 +74,21 @@ void testCopyCheckpointItemsSkipsNodeAndMigrationMarkers() { final ComponentLog logger = mock(ComponentLog.class); final Map nodeItem = Map.of( - "streamName", AttributeValue.builder().s("my-stream").build(), - "shardId", AttributeValue.builder().s("__node__#node-a").build()); + "streamName", str(STREAM_NAME), + "shardId", str("__node__#node-a")); final Map migrationMarkerItem = Map.of( - "streamName", AttributeValue.builder().s("my-stream").build(), - "shardId", AttributeValue.builder().s("__migration__").build()); + "streamName", str(STREAM_NAME), + "shardId", str("__migration__")); final Map shardItem = Map.of( - "streamName", AttributeValue.builder().s("my-stream").build(), - "shardId", AttributeValue.builder().s("shardId-0002").build(), - "sequenceNumber", AttributeValue.builder().s("67890").build()); + "streamName", str(STREAM_NAME), + "shardId", str(SHARD_ID_2), + "sequenceNumber", str("67890")); when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( ScanResponse.builder().items(List.of(nodeItem, migrationMarkerItem, shardItem)).build()); when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); - 
CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); verify(dynamoDb, times(1)).putItem(putCaptor.capture()); @@ -91,13 +101,13 @@ void testCopyCheckpointItemsSkipsAllMarkers() { final ComponentLog logger = mock(ComponentLog.class); final Map nodeItem = Map.of( - "streamName", AttributeValue.builder().s("my-stream").build(), - "shardId", AttributeValue.builder().s("__node__#node-b").build()); + "streamName", str(STREAM_NAME), + "shardId", str("__node__#node-b")); when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( ScanResponse.builder().items(List.of(nodeItem)).build()); - CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, "source-table", "dest-table"); + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); verify(dynamoDb, never()).putItem(any(PutItemRequest.class)); } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java index 900da3bc101c..28e833222083 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java @@ -615,7 +615,7 @@ void testKplMultipleAggregatedRecords() throws Exception { runUntilOutput(runner); final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); - assertEquals(15, flowFiles.size(), "3 aggregated records x 5 sub-records each"); + assertEquals(15, flowFiles.size()); final Set contents = new HashSet<>(); for (final MockFlowFile ff : 
flowFiles) { @@ -837,12 +837,12 @@ private TestRunner createConfiguredRunner(final String streamName, final String /** * Runs the processor with retries until at least one FlowFile appears on any output relationship - * (success or parse failure), or a 10-second deadline is reached. This guards against + * (success or parse failure), or a 30-second deadline is reached. This guards against * timing-sensitive tests on slow systems where LocalStack may not propagate records within a * single 200ms batch window. */ private void runUntilOutput(final TestRunner testRunner) throws InterruptedException { - final long deadline = System.currentTimeMillis() + 10_000; + final long deadline = System.currentTimeMillis() + 30_000; testRunner.run(1, false, true); while (!hasOutput(testRunner) && System.currentTimeMillis() < deadline) { Thread.sleep(200); @@ -913,7 +913,7 @@ private void publishAggregatedRecord(final String streamName, final String outer try { final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); final ByteArrayOutputStream out = new ByteArrayOutputStream(); - out.write(KplDeaggregator.KPL_MAGIC); + out.write(ProducerLibraryDeaggregator.KPL_MAGIC); out.write(protobufBytes); out.write(md5); payload = out.toByteArray(); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java index 773a96389904..351f161ff6c0 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java @@ -96,7 +96,7 @@ void testProcessingStrategyValidation() throws Exception { @Test void testAllValidRecordsRoutedToSuccess() throws Exception 
{ - final List records = List.of( + final List records = List.of( testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "{\"name\":\"Charlie\"}")); @@ -121,7 +121,7 @@ void testSingleInvalidRecordRoutedToParseFailure() throws Exception { @Test void testMultipleInvalidRecordsInBatch() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "BAD FIRST"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "BAD THIRD"), @@ -148,7 +148,7 @@ void testMultipleInvalidRecordsInBatch() throws Exception { @Test void testAllInvalidRecordsRoutedToParseFailure() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "BAD1"), testRecord("2", "BAD2"), testRecord("3", "BAD3")); @@ -161,7 +161,7 @@ void testAllInvalidRecordsRoutedToParseFailure() throws Exception { @Test void testFlowFilePerRecordDeliversAllRecords() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "record-one"), testRecord("2", "record-two"), testRecord("3", "record-three")); @@ -190,7 +190,7 @@ void testFlowFilePerRecordDeliversAllRecords() throws Exception { @Test void testDemarcatorDeliversAllRecords() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "line-one"), testRecord("2", "line-two"), testRecord("3", "line-three")); @@ -209,7 +209,7 @@ void testDelimitedUsesLatestArrivalTimestamp() throws Exception { final Instant firstArrival = Instant.parse("2025-01-15T00:00:00Z"); final Instant secondArrival = Instant.parse("2025-01-15T00:00:05Z"); final Instant thirdArrival = Instant.parse("2025-01-15T00:00:03Z"); - final List records = List.of( + final List records = List.of( testRecord("1", "line-one", firstArrival), testRecord("2", "line-two", secondArrival), testRecord("3", "line-three", thirdArrival)); @@ -246,7 +246,7 @@ void testMultipleShardsNoDataLoss() throws Exception { @Test 
void testRecordMetadataInjectionPreservesRecordCount() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "{\"name\":\"Charlie\"}")); @@ -269,7 +269,7 @@ void testRecordMetadataInjectionPreservesRecordCount() throws Exception { @Test void testUseWrapperOutputStrategy() throws Exception { - final List records = List.of( + final List records = List.of( testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "{\"name\":\"Bob\"}")); @@ -336,9 +336,9 @@ void testDynamicRelationships() throws Exception { @Test void testEmptyRecordDoesNotCauseStuckState() throws Exception { - final DeaggregatedRecord emptyRecord = new DeaggregatedRecord("shardId-000000000001", "2", 0, "pk-2", new byte[0], Instant.now()); + final UserRecord emptyRecord = new UserRecord("shardId-000000000001", "2", 0, "pk-2", new byte[0], Instant.now()); - final List records = List.of( + final List records = List.of( testRecord("1", "{\"name\":\"Alice\"}"), emptyRecord, testRecord("3", "{\"name\":\"Charlie\"}")); @@ -351,7 +351,7 @@ void testEmptyRecordDoesNotCauseStuckState() throws Exception { success.assertAttributeEquals("record.count", "2"); } - private void triggerWithRecords(final List records) throws Exception { + private void triggerWithRecords(final List records) throws Exception { final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); @@ -382,7 +382,7 @@ private Set collectRelationshipNames() { } private void assertInvalidRecordAtPosition(final String expectedFailureSequence, final String expectedFailureContent, - final DeaggregatedRecord... records) throws Exception { + final UserRecord... 
records) throws Exception { triggerWithRecords(List.of(records)); runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); @@ -397,7 +397,7 @@ private void assertInvalidRecordAtPosition(final String expectedFailureSequence, assertNotNull(failure.getAttribute(ConsumeKinesis.ATTR_RECORD_ERROR_MESSAGE)); } - private void triggerWithOutputStrategy(final List records, final String outputStrategy) throws Exception { + private void triggerWithOutputStrategy(final List records, final String outputStrategy) throws Exception { final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); @@ -420,7 +420,7 @@ private void triggerWithOutputStrategy(final List records, f runner.run(); } - private void triggerWithStrategy(final List records, final String processingStrategy, + private void triggerWithStrategy(final List records, final String processingStrategy, final String shardId) throws Exception { final KinesisShardManager mockShardManager = buildShardManager(shardId); final ShardFetchResult fetchResult = new ShardFetchResult(shardId, records, 0L); @@ -471,12 +471,12 @@ private static KinesisShardManager buildShardManager(final String... 
shardIds) { return mockShardManager; } - private static DeaggregatedRecord testRecord(final String sequenceNumber, final String data) { + private static UserRecord testRecord(final String sequenceNumber, final String data) { return testRecord(sequenceNumber, data, Instant.now()); } - private static DeaggregatedRecord testRecord(final String sequenceNumber, final String data, final Instant arrivalTimestamp) { - return new DeaggregatedRecord( + private static UserRecord testRecord(final String sequenceNumber, final String data, final Instant arrivalTimestamp) { + return new UserRecord( "shardId-000000000001", sequenceNumber, 0, diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java index a773b88f8ee2..b7c9e102c0df 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java @@ -60,7 +60,7 @@ void testSubscriptionRenewalUsesLastAcknowledgedSequenceNumber() throws Exceptio when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("11111"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); @@ -71,7 +71,7 @@ void testSubscriptionRenewalUsesLastAcknowledgedSequenceNumber() throws Exceptio assertEquals("11111", capturedRequests.get(0).startingPosition().sequenceNumber(), "Initial subscription should use the DynamoDB checkpoint"); - final 
EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); consumer.setLastQueuedSequenceNumber(new BigInteger("99999")); consumer.resetForRenewal(); @@ -95,13 +95,13 @@ void testSubscriptionRenewalFallsBackToCheckpointWhenNoQueuedData() throws Excep when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("55555"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); consumer.resetForRenewal(); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); @@ -123,12 +123,12 @@ void testSubscriptionRenewalUsesLastQueuedSequence() throws Exception { when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("10000"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); consumer.setLastQueuedSequenceNumber(new BigInteger("20000")); consumer.resetForRenewal(); @@ -152,7 +152,7 @@ void 
testSubscriptionRenewalAlwaysUsesMaxSequence() throws Exception { when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); @@ -184,7 +184,7 @@ void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws E when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); @@ -213,11 +213,11 @@ void testAcknowledgeResultsRequestsNextOncePerShard() throws Exception { when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); final List capturedRequests = new ArrayList<>(); - final EfoKinesisClient client = createEfoClient(capturedRequests); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); - final EfoKinesisClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); final Subscription subscription = mock(Subscription.class); consumer.setSubscription(subscription); consumer.pause(); @@ -248,7 +248,7 @@ void 
testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws return new CompletableFuture<>(); }); - final EfoKinesisClient client = new EfoKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + final EnhancedFanOutClient client = new EnhancedFanOutClient(mock(KinesisClient.class), mock(ComponentLog.class)); client.initializeForTest(mockAsyncClient, "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); @@ -266,7 +266,7 @@ void testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws "Concurrent startup should create only one initial SubscribeToShard request per shard"); } - private static EfoKinesisClient createEfoClient(final List capturedRequests) { + private static EnhancedFanOutClient createEfoClient(final List capturedRequests) { final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) @@ -275,16 +275,16 @@ private static EfoKinesisClient createEfoClient(final List { }, new ConcurrentLinkedQueue<>(), mockLogger); + final EnhancedFanOutClient.ShardConsumer consumer = + new EnhancedFanOutClient.ShardConsumer("shardId-000000000001", result -> { }, new ConcurrentLinkedQueue<>(), mockLogger); final StartingPosition pos = StartingPosition.builder() .type(ShardIteratorType.TRIM_HORIZON) @@ -351,7 +351,7 @@ void testStaleErrorCallbackDoesNotCorruptNewSubscription() throws Exception { } private static ShardFetchResult shardFetchResult(final String shardId, final String sequenceNumber) { - final DeaggregatedRecord record = new DeaggregatedRecord(shardId, sequenceNumber, 0, "pk", "{}".getBytes(), null); + final UserRecord record = new UserRecord(shardId, sequenceNumber, 0, "pk", "{}".getBytes(), null); return new ShardFetchResult(shardId, List.of(record), 0L); } diff --git 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java similarity index 84% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java index 9d3d976ad01d..5e12d358a9b5 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KplDeaggregatorTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java @@ -37,7 +37,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -class KplDeaggregatorTest { +class ProducerLibraryDeaggregatorTest { private static final Instant ARRIVAL = Instant.parse("2025-06-15T12:00:00Z"); private static final String TEST_SHARD_ID = "shardId-000000000000"; @@ -47,10 +47,10 @@ void testNonAggregatedPassthrough() { final byte[] payload = "hello".getBytes(StandardCharsets.UTF_8); final Record record = buildKinesisRecord("seq-001", "pk-1", payload); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); assertEquals(1, result.size()); - final DeaggregatedRecord dr = result.getFirst(); + final UserRecord dr = result.getFirst(); assertEquals(TEST_SHARD_ID, dr.shardId()); assertEquals("seq-001", dr.sequenceNumber()); assertEquals(0, dr.subSequenceNumber()); @@ -66,10 +66,10 @@ void testSingleSubRecord() throws Exception { List.of(new 
SubRecord(0, "data-A".getBytes(StandardCharsets.UTF_8)))); final Record record = buildKinesisRecord("seq-100", "agg-pk", aggregated); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); assertEquals(1, result.size()); - final DeaggregatedRecord dr = result.getFirst(); + final UserRecord dr = result.getFirst(); assertEquals("seq-100", dr.sequenceNumber()); assertEquals(0, dr.subSequenceNumber()); assertEquals("pk-A", dr.partitionKey()); @@ -86,7 +86,7 @@ void testMultipleSubRecords() throws Exception { new SubRecord(0, "third".getBytes(StandardCharsets.UTF_8)))); final Record record = buildKinesisRecord("seq-200", "agg-pk", aggregated); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); assertEquals(3, result.size()); @@ -102,7 +102,7 @@ void testMultipleSubRecords() throws Exception { assertEquals(2, result.get(2).subSequenceNumber()); assertArrayEquals("third".getBytes(StandardCharsets.UTF_8), result.get(2).data()); - for (final DeaggregatedRecord dr : result) { + for (final UserRecord dr : result) { assertEquals("seq-200", dr.sequenceNumber()); assertEquals(ARRIVAL, dr.approximateArrivalTimestamp()); } @@ -118,7 +118,7 @@ void testMixedAggregatedAndNonAggregated() throws Exception { List.of(new SubRecord(0, "agg-data".getBytes(StandardCharsets.UTF_8)))); final Record aggRecord = buildKinesisRecord("seq-002", "pk-outer", aggregated); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); assertEquals(2, result.size()); assertEquals("seq-001", result.get(0).sequenceNumber()); @@ -129,20 +129,20 @@ void testMixedAggregatedAndNonAggregated() throws Exception { @Test 
void testCorruptedProtobufFallsBackToPassthrough() { - final byte[] corrupted = new byte[KplDeaggregator.KPL_MAGIC.length + 20 + 16]; - System.arraycopy(KplDeaggregator.KPL_MAGIC, 0, corrupted, 0, KplDeaggregator.KPL_MAGIC.length); + final byte[] corrupted = new byte[ProducerLibraryDeaggregator.KPL_MAGIC.length + 20 + 16]; + System.arraycopy(ProducerLibraryDeaggregator.KPL_MAGIC, 0, corrupted, 0, ProducerLibraryDeaggregator.KPL_MAGIC.length); final byte[] protobufPart = new byte[20]; protobufPart[0] = (byte) 0xFF; - System.arraycopy(protobufPart, 0, corrupted, KplDeaggregator.KPL_MAGIC.length, 20); + System.arraycopy(protobufPart, 0, corrupted, ProducerLibraryDeaggregator.KPL_MAGIC.length, 20); try { final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufPart); - System.arraycopy(md5, 0, corrupted, KplDeaggregator.KPL_MAGIC.length + 20, 16); + System.arraycopy(md5, 0, corrupted, ProducerLibraryDeaggregator.KPL_MAGIC.length + 20, 16); } catch (final Exception e) { throw new RuntimeException(e); } final Record record = buildKinesisRecord("seq-bad", "pk-bad", corrupted); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); assertEquals(1, result.size()); assertEquals("seq-bad", result.get(0).sequenceNumber()); @@ -159,7 +159,7 @@ void testMd5MismatchFallsBackToPassthrough() throws Exception { aggregated[aggregated.length - 1] ^= 0xFF; final Record record = buildKinesisRecord("seq-md5", "pk-md5", aggregated); - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); assertEquals(1, result.size()); assertEquals(0, result.get(0).subSequenceNumber()); @@ -168,24 +168,24 @@ void testMd5MismatchFallsBackToPassthrough() throws Exception { @Test void testIsAggregatedDetection() { - 
assertFalse(KplDeaggregator.isAggregated(new byte[0])); - assertFalse(KplDeaggregator.isAggregated(new byte[]{0x01, 0x02})); - assertFalse(KplDeaggregator.isAggregated("regular data".getBytes(StandardCharsets.UTF_8))); + assertFalse(ProducerLibraryDeaggregator.isAggregated(new byte[0])); + assertFalse(ProducerLibraryDeaggregator.isAggregated(new byte[]{0x01, 0x02})); + assertFalse(ProducerLibraryDeaggregator.isAggregated("regular data".getBytes(StandardCharsets.UTF_8))); - final byte[] withMagic = new byte[KplDeaggregator.KPL_MAGIC.length + 16 + 1]; - System.arraycopy(KplDeaggregator.KPL_MAGIC, 0, withMagic, 0, KplDeaggregator.KPL_MAGIC.length); - assertTrue(KplDeaggregator.isAggregated(withMagic)); + final byte[] withMagic = new byte[ProducerLibraryDeaggregator.KPL_MAGIC.length + 16 + 1]; + System.arraycopy(ProducerLibraryDeaggregator.KPL_MAGIC, 0, withMagic, 0, ProducerLibraryDeaggregator.KPL_MAGIC.length); + assertTrue(ProducerLibraryDeaggregator.isAggregated(withMagic)); } @Test void testEmptyRecordList() { - final List result = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of()); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of()); assertTrue(result.isEmpty()); } // ---- KCL cross-validation tests ---- // These tests use the KCL's own protobuf Messages class to create aggregated records, - // then verify that our KplDeaggregator produces the same results as the KCL's AggregatorUtil. + // then verify that our ProducerLibraryDeaggregator produces the same results as the KCL's AggregatorUtil. 
@Test void testKclAggregatedSingleRecord() { @@ -198,7 +198,7 @@ void testKclAggregatedSingleRecord() { final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); final Record kinesisRecord = buildKinesisRecord("seq-kcl-1", "outer-pk", payload); - final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); final List kclResult = deaggregateViaKcl(kinesisRecord); assertEquals(1, ourResult.size()); @@ -227,7 +227,7 @@ void testKclAggregatedMultipleRecords() { final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); final Record kinesisRecord = buildKinesisRecord("seq-kcl-multi", "outer-pk", payload); - final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); final List kclResult = deaggregateViaKcl(kinesisRecord); assertEquals(3, ourResult.size()); @@ -264,7 +264,7 @@ void testKclAggregatedMixedWithPlainRecords() { final Record aggRecord = buildKinesisRecord("seq-agg", "outer-pk", wrapAsKplPayload(aggProto.toByteArray())); - final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); final List kclPlain = deaggregateViaKcl(plainRecord); final List kclAgg = deaggregateViaKcl(aggRecord); @@ -290,7 +290,7 @@ void testKclAggregatedWithExplicitHashKeys() { final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); final Record kinesisRecord = buildKinesisRecord("seq-ehk", "outer-pk", payload); - final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); final List kclResult = 
deaggregateViaKcl(kinesisRecord); assertEquals(1, ourResult.size()); @@ -311,7 +311,7 @@ void testKclAggregatedLargeBatch() { final byte[] payload = wrapAsKplPayload(builder.build().toByteArray()); final Record kinesisRecord = buildKinesisRecord("seq-batch", "outer-pk", payload); - final List ourResult = KplDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); final List kclResult = deaggregateViaKcl(kinesisRecord); assertEquals(100, ourResult.size()); @@ -347,7 +347,7 @@ private static byte[] wrapAsKplPayload(final byte[] protobufBytes) { try { final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); final ByteArrayOutputStream result = new ByteArrayOutputStream(); - result.write(KplDeaggregator.KPL_MAGIC); + result.write(ProducerLibraryDeaggregator.KPL_MAGIC); result.write(protobufBytes); result.write(md5); return result.toByteArray(); @@ -367,7 +367,7 @@ private static List deaggregateViaKcl(final Record record) return new AggregatorUtil().deaggregate(List.of(kcr)); } - private static void assertDeaggregatedMatchesKcl(final DeaggregatedRecord ours, final KinesisClientRecord kcl) { + private static void assertDeaggregatedMatchesKcl(final UserRecord ours, final KinesisClientRecord kcl) { assertEquals(kcl.sequenceNumber(), ours.sequenceNumber(), "sequence number mismatch"); assertEquals(kcl.subSequenceNumber(), ours.subSequenceNumber(), "sub-sequence number mismatch"); assertEquals(kcl.partitionKey(), ours.partitionKey(), "partition key mismatch"); From 8812febc98773525c3e04cc8ab63e7ad1f5ebb24 Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Fri, 13 Mar 2026 15:41:57 -0400 Subject: [PATCH 6/7] NIFI-15669: Added apache5-client to api-nar; some very minor code and logging cleanup --- .../aws/kinesis/EnhancedFanOutClient.java | 13 +++++++------ .../nifi-aws-service-api-nar/pom.xml | 7 +++++++ 2 files changed, 14 insertions(+), 6 
deletions(-) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java index 50b7397b33fe..70047be3ea8b 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java @@ -205,10 +205,10 @@ void logDiagnostics(final int ownedCount, final int cachedShardCount) { int activeSubscriptions = 0; int expiredSubscriptions = 0; int backedOff = 0; - for (final ShardConsumer sc : shardConsumers.values()) { - if (sc.isSubscriptionExpired()) { + for (final ShardConsumer shardConsumer : shardConsumers.values()) { + if (shardConsumer.isSubscriptionExpired()) { expiredSubscriptions++; - final long lastAttempt = sc.getLastSubscribeAttemptNanos(); + final long lastAttempt = shardConsumer.getLastSubscribeAttemptNanos(); if (lastAttempt > 0 && now < lastAttempt + SUBSCRIBE_BACKOFF_NANOS) { backedOff++; } @@ -218,7 +218,7 @@ void logDiagnostics(final int ownedCount, final int cachedShardCount) { } final int queueDepth = totalQueuedResults(); - logger.debug("Kinesis Enhanced Fan-Out diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, activeSubscriptions={}, expiredSubscriptions={}, backedOff={}", + logger.debug("Kinesis Enhanced Fan-Out diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, active={}, expired={}, backedOff={}", cachedShardCount, ownedCount, queueDepth, MAX_QUEUED_RESULTS, shardConsumers.size(), activeSubscriptions, expiredSubscriptions, backedOff); } @@ -467,9 +467,9 @@ private void logSubscriptionError(final Throwable t) { if (isCancellation(t)) { 
consumerLogger.debug("Enhanced Fan-Out subscription cancelled for shard [{}]", shardId); } else if (isRetryableSubscriptionError(t)) { - consumerLogger.info("Enhanced Fan-Out subscription temporarily rejected for shard [{}]; will retry after backoff", shardId); + consumerLogger.warn("Enhanced Fan-Out subscription temporarily rejected for shard [{}]; will retry after backoff", shardId, t); } else if (isRetryableStreamDisconnect(t)) { - consumerLogger.info("Enhanced Fan-Out subscription disconnected for shard [{}]; will retry", shardId); + consumerLogger.warn("Enhanced Fan-Out subscription disconnected for shard [{}]; will retry", shardId, t); } else { consumerLogger.error("Enhanced Fan-Out subscription error for shard [{}]", shardId, t); } @@ -555,6 +555,7 @@ private class DemandDrivenSubscriber implements Subscribernifi-aws-service-api compile + + + software.amazon.awssdk + apache5-client + ${software.amazon.awssdk.version} + compile + software.amazon.awssdk From 001227cf903e542b25bcf47842c3fcb866109bad Mon Sep 17 00:00:00 2001 From: Mark Payne Date: Mon, 16 Mar 2026 14:23:07 -0400 Subject: [PATCH 7/7] NIFI-15669: Added an explicit exclusion for apache httpclient version 4 --- .../nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml index 2ab90ef73abf..0ca22a89dffd 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml @@ -241,6 +241,12 @@ org.apache.nifi nifi-aws-kinesis ${project.version} + + + software.amazon.awssdk + apache-client + + org.apache.nifi