diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml index 838404b604d5..0ca22a89dffd 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis-nar/pom.xml @@ -98,7 +98,7 @@ software.amazon.awssdk - apache-client + apache5-client ${software.amazon.awssdk.version} provided @@ -241,6 +241,12 @@ org.apache.nifi nifi-aws-kinesis ${project.version} + + + software.amazon.awssdk + apache-client + + org.apache.nifi diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml index 97e4918a0b80..3b03857ff839 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/pom.xml @@ -37,10 +37,6 @@ org.apache.nifi nifi-aws-service-api - - org.apache.nifi - nifi-proxy-configuration-api - org.apache.nifi nifi-migration-utils @@ -50,11 +46,44 @@ org.apache.nifi nifi-record-serialization-service-api + + org.apache.nifi + nifi-proxy-configuration-api + + + software.amazon.awssdk + apache5-client + + + software.amazon.awssdk + kinesis + + + commons-logging + commons-logging + + + + + software.amazon.awssdk + dynamodb + + + software.amazon.awssdk + netty-nio-client + + + com.google.protobuf + protobuf-java + + + software.amazon.kinesis amazon-kinesis-client 3.4.1 + test com.google.code.findbugs @@ -78,12 +107,6 @@ - - software.amazon.awssdk - netty-nio-client - - - org.apache.nifi nifi-aws-processors @@ -116,6 +139,16 @@ nifi-ssl-context-service-api test + + org.mockito + mockito-core + test + + + org.testcontainers + testcontainers-junit-jupiter + test + org.testcontainers testcontainers-localstack diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java new file mode 100644 index 000000000000..633d9deb6dbc --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtils.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.BillingMode; +import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.ResourceInUseException; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.TableStatus; + +import java.util.List; +import java.util.Map; + +/** + * Shared DynamoDB table lifecycle operations for checkpoint tables. Used by both + * {@link KinesisShardManager} for runtime table management and + * {@link LegacyCheckpointMigrator} for migration and rename operations. 
+ */ +final class CheckpointTableUtils { + + static final String ATTR_STREAM_NAME = "streamName"; + static final String ATTR_SHARD_ID = "shardId"; + static final String NODE_HEARTBEAT_PREFIX = "__node__#"; + static final String MIGRATION_MARKER_SHARD_ID = "__migration__"; + + private static final long TABLE_POLL_MILLIS = 1_000; + private static final int TABLE_POLL_MAX_ATTEMPTS = 60; + + private CheckpointTableUtils() { } + + enum TableSchema { + NEW, + LEGACY, + UNKNOWN, + NOT_FOUND + } + + static TableSchema getTableSchema(final DynamoDbClient client, final String tableName) { + try { + final DescribeTableResponse describe = client.describeTable(DescribeTableRequest.builder().tableName(tableName).build()); + final List keySchema = describe.table().keySchema(); + if (keySchema.size() == 2 + && hasKey(keySchema, ATTR_STREAM_NAME, KeyType.HASH) + && hasKey(keySchema, ATTR_SHARD_ID, KeyType.RANGE)) { + return TableSchema.NEW; + } + + if (keySchema.size() == 1 && hasKey(keySchema, "leaseKey", KeyType.HASH)) { + return TableSchema.LEGACY; + } + + return TableSchema.UNKNOWN; + } catch (final ResourceNotFoundException notFound) { + return TableSchema.NOT_FOUND; + } + } + + static void createNewSchemaTable(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final TableSchema tableSchema = getTableSchema(client, tableName); + if (tableSchema == TableSchema.NEW) { + logger.info("DynamoDB checkpoint table [{}] already exists", tableName); + return; + } + if (tableSchema == TableSchema.LEGACY || tableSchema == TableSchema.UNKNOWN) { + throw new ProcessException("Checkpoint table [%s] exists but does not match expected schema".formatted(tableName)); + } + + logger.info("Creating DynamoDB checkpoint table [{}]", tableName); + try { + final CreateTableRequest request = CreateTableRequest.builder() + .tableName(tableName) + .keySchema( + KeySchemaElement.builder().attributeName(ATTR_STREAM_NAME).keyType(KeyType.HASH).build(), + 
KeySchemaElement.builder().attributeName(ATTR_SHARD_ID).keyType(KeyType.RANGE).build()) + .attributeDefinitions( + AttributeDefinition.builder().attributeName(ATTR_STREAM_NAME).attributeType(ScalarAttributeType.S).build(), + AttributeDefinition.builder().attributeName(ATTR_SHARD_ID).attributeType(ScalarAttributeType.S).build()) + .billingMode(BillingMode.PAY_PER_REQUEST) + .build(); + + client.createTable(request); + } catch (final ResourceInUseException alreadyCreating) { + logger.info("DynamoDB checkpoint table [{}] is already being created by another node", tableName); + } + } + + static void waitForTableActive(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final DescribeTableRequest request = DescribeTableRequest.builder().tableName(tableName).build(); + for (int i = 0; i < TABLE_POLL_MAX_ATTEMPTS; i++) { + final TableStatus status = client.describeTable(request).table().tableStatus(); + if (status == TableStatus.ACTIVE) { + logger.info("DynamoDB checkpoint table [{}] is now ACTIVE", tableName); + return; + } + + try { + Thread.sleep(TABLE_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for DynamoDB table [%s] to become ACTIVE".formatted(tableName), e); + } + } + + throw new ProcessException("DynamoDB checkpoint table [%s] did not become ACTIVE within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + } + + static void deleteTable(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + try { + client.deleteTable(DeleteTableRequest.builder().tableName(tableName).build()); + logger.info("Initiated deletion of DynamoDB table [{}]", tableName); + } catch (final ResourceNotFoundException e) { + logger.debug("Table [{}] already deleted", tableName); + } + } + + static void waitForTableDeleted(final DynamoDbClient client, final ComponentLog logger, final String tableName) { + final 
DescribeTableRequest request = DescribeTableRequest.builder().tableName(tableName).build(); + for (int i = 0; i < TABLE_POLL_MAX_ATTEMPTS; i++) { + try { + client.describeTable(request); + } catch (final ResourceNotFoundException e) { + logger.info("DynamoDB table [{}] has been deleted", tableName); + return; + } + + try { + Thread.sleep(TABLE_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for DynamoDB table [%s] deletion".formatted(tableName), e); + } + } + + throw new ProcessException("DynamoDB table [%s] was not deleted within %d seconds".formatted(tableName, TABLE_POLL_MAX_ATTEMPTS)); + } + + static void copyCheckpointItems(final DynamoDbClient client, final ComponentLog logger, + final String sourceTableName, final String destTableName) { + logger.info("Copying checkpoint items from [{}] to [{}]", sourceTableName, destTableName); + + Map exclusiveStartKey = null; + int copied = 0; + do { + final ScanRequest scanRequest = exclusiveStartKey == null + ? 
ScanRequest.builder().tableName(sourceTableName).build() + : ScanRequest.builder().tableName(sourceTableName).exclusiveStartKey(exclusiveStartKey).build(); + final ScanResponse scanResponse = client.scan(scanRequest); + + for (final Map item : scanResponse.items()) { + final AttributeValue shardIdAttr = item.get(ATTR_SHARD_ID); + if (shardIdAttr != null) { + final String shardId = shardIdAttr.s(); + if (shardId.startsWith(NODE_HEARTBEAT_PREFIX) + || MIGRATION_MARKER_SHARD_ID.equals(shardId)) { + continue; + } + } + + client.putItem(PutItemRequest.builder() + .tableName(destTableName) + .item(item) + .build()); + copied++; + } + + exclusiveStartKey = scanResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + logger.info("Copied {} checkpoint item(s) from [{}] to [{}]", copied, sourceTableName, destTableName); + } + + private static boolean hasKey(final List keySchema, final String keyName, final KeyType keyType) { + for (final KeySchemaElement element : keySchema) { + if (keyName.equals(element.attributeName()) && keyType == element.keyType()) { + return true; + } + } + return false; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java index 52044ee0f478..aab3b702453d 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesis.java @@ -16,7 +16,6 @@ */ package org.apache.nifi.processors.aws.kinesis; -import jakarta.annotation.Nullable; import org.apache.nifi.annotation.behavior.InputRequirement; import org.apache.nifi.annotation.behavior.SystemResource; import 
org.apache.nifi.annotation.behavior.SystemResourceConsideration; @@ -25,6 +24,7 @@ import org.apache.nifi.annotation.configuration.DefaultSettings; import org.apache.nifi.annotation.documentation.CapabilityDescription; import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnRemoved; import org.apache.nifi.annotation.lifecycle.OnScheduled; import org.apache.nifi.annotation.lifecycle.OnStopped; import org.apache.nifi.components.DescribedValue; @@ -32,6 +32,7 @@ import org.apache.nifi.components.Validator; import org.apache.nifi.controller.NodeTypeProvider; import org.apache.nifi.flowfile.FlowFile; +import org.apache.nifi.flowfile.attributes.CoreAttributes; import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.migration.PropertyConfiguration; import org.apache.nifi.migration.ProxyServiceMigration; @@ -41,143 +42,134 @@ import org.apache.nifi.processor.ProcessSession; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; +import org.apache.nifi.processor.io.OutputStreamCallback; import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.processors.aws.credentials.provider.AwsCredentialsProviderService; -import org.apache.nifi.processors.aws.kinesis.MemoryBoundRecordBuffer.Lease; -import org.apache.nifi.processors.aws.kinesis.ReaderRecordProcessor.ProcessingResult; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.apache.nifi.processors.aws.kinesis.converter.InjectMetadataRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.ValueRecordConverter; -import org.apache.nifi.processors.aws.kinesis.converter.WrapperRecordConverter; import org.apache.nifi.processors.aws.region.RegionUtil; import org.apache.nifi.proxy.ProxyConfiguration; -import org.apache.nifi.proxy.ProxyConfigurationService; import 
org.apache.nifi.proxy.ProxySpec; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.MalformedRecordException; +import org.apache.nifi.serialization.RecordReader; import org.apache.nifi.serialization.RecordReaderFactory; +import org.apache.nifi.serialization.RecordSetWriter; import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.SimpleRecordSchema; +import org.apache.nifi.serialization.record.MapRecord; +import org.apache.nifi.serialization.record.RecordField; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.RecordSchema; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.http.Protocol; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.apache5.Apache5HttpClient; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; -import software.amazon.awssdk.http.nio.netty.Http2Configuration; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; -import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbClientBuilder; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder; -import software.amazon.kinesis.common.ConfigsBuilder; -import software.amazon.kinesis.common.InitialPositionInStream; -import software.amazon.kinesis.common.InitialPositionInStreamExtended; -import software.amazon.kinesis.coordinator.Scheduler; -import software.amazon.kinesis.coordinator.WorkerStateChangeListener; -import 
software.amazon.kinesis.lifecycle.events.InitializationInput; -import software.amazon.kinesis.lifecycle.events.LeaseLostInput; -import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; -import software.amazon.kinesis.lifecycle.events.ShardEndedInput; -import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput; -import software.amazon.kinesis.metrics.LogMetricsFactory; -import software.amazon.kinesis.metrics.MetricsFactory; -import software.amazon.kinesis.metrics.NullMetricsFactory; -import software.amazon.kinesis.processor.ShardRecordProcessor; -import software.amazon.kinesis.processor.ShardRecordProcessorFactory; -import software.amazon.kinesis.processor.SingleStreamTracker; -import software.amazon.kinesis.retrieval.KinesisClientRecord; -import software.amazon.kinesis.retrieval.RetrievalSpecificConfig; -import software.amazon.kinesis.retrieval.fanout.FanOutConfig; -import software.amazon.kinesis.retrieval.polling.PollingConfig; - +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.KinesisClientBuilder; +import software.amazon.awssdk.services.kinesis.model.DeregisterStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.math.BigInteger; +import java.net.Proxy; import java.net.URI; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; -import java.util.Date; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import 
java.util.concurrent.ExecutionException; -import java.util.concurrent.Future; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.NANOSECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.APPROXIMATE_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.PARTITION_KEY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.SHARD_ID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + import static org.apache.nifi.processors.aws.region.RegionUtil.CUSTOM_REGION; import static org.apache.nifi.processors.aws.region.RegionUtil.REGION; @InputRequirement(InputRequirement.Requirement.INPUT_FORBIDDEN) @Tags({"amazon", "aws", "kinesis", "consume", "stream", "record"}) @CapabilityDescription(""" - Consumes data from the specified AWS Kinesis stream and outputs a FlowFile for every processed Record (raw) - or a FlowFile for a batch of processed records if a Record Reader and Record Writer are configured. 
- The processor may take a few minutes on the first start and several seconds on subsequent starts - to initialize before starting to fetch data. - Uses DynamoDB for check pointing and coordination, and (optional) CloudWatch for metrics. - """) + Consumes records from an Amazon Kinesis Data Stream. Uses \ + DynamoDB-based checkpointing for reliable resumption after restarts. + + Note: when a shard is split or multiple shards are merged, this processor will consume from \ + child and parent shards concurrently. It does not wait for parent shards to be fully consumed \ + before reading child shards, so record ordering is not guaranteed across a split or merge \ + boundary.""") @WritesAttributes({ - @WritesAttribute(attribute = ConsumeKinesisAttributes.STREAM_NAME, - description = "The name of the Kinesis Stream from which all Kinesis Records in the FlowFile were read"), - @WritesAttribute(attribute = SHARD_ID, - description = "Shard ID from which all Kinesis Records in the FlowFile were read"), - @WritesAttribute(attribute = PARTITION_KEY, + @WritesAttribute(attribute = "aws.kinesis.stream.name", + description = "The name of the Kinesis Stream from which records were read"), + @WritesAttribute(attribute = "aws.kinesis.shard.id", + description = "Shard ID from which records were read"), + @WritesAttribute(attribute = "aws.kinesis.partition.key", description = "Partition key of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = FIRST_SEQUENCE_NUMBER, - description = "A Sequence Number of the first Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = FIRST_SUB_SEQUENCE_NUMBER, - description = "A SubSequence Number of the first Kinesis Record in the FlowFile. 
Generated by KPL when aggregating records into a single Kinesis Record"), - @WritesAttribute(attribute = LAST_SEQUENCE_NUMBER, - description = "A Sequence Number of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = LAST_SUB_SEQUENCE_NUMBER, - description = "A SubSequence Number of the last Kinesis Record in the FlowFile. Generated by KPL when aggregating records into a single Kinesis Record"), - @WritesAttribute(attribute = APPROXIMATE_ARRIVAL_TIMESTAMP, - description = "Approximate arrival timestamp of the last Kinesis Record in the FlowFile"), - @WritesAttribute(attribute = MIME_TYPE, + @WritesAttribute(attribute = "aws.kinesis.first.sequence.number", + description = "Sequence Number of the first Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.first.subsequence.number", + description = "Sub-Sequence Number of the first Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.last.sequence.number", + description = "Sequence Number of the last Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.last.subsequence.number", + description = "Sub-Sequence Number of the last Kinesis Record in the FlowFile"), + @WritesAttribute(attribute = "aws.kinesis.approximate.arrival.timestamp.ms", + description = "Approximate arrival timestamp associated with the Kinesis Record or records in the FlowFile"), + @WritesAttribute(attribute = "mime.type", description = "Sets the mime.type attribute to the MIME Type specified by the Record Writer (if configured)"), - @WritesAttribute(attribute = RECORD_COUNT, - description = "Number of records written to the FlowFiles by the Record Writer (if configured)"), - @WritesAttribute(attribute = RECORD_ERROR_MESSAGE, - description = "This attribute provides on failure the error message encountered by the Record Reader or Record Writer (if configured)") + @WritesAttribute(attribute = "record.count", + description = "Number of records written to the 
FlowFile"), + @WritesAttribute(attribute = "record.error.message", + description = "Error message encountered by the Record Reader or Record Writer (if configured)"), + @WritesAttribute(attribute = "kinesis.millis.behind", + description = "How far behind the stream tail we are, in milliseconds") }) @DefaultSettings(yieldDuration = "100 millis") @SystemResourceConsideration(resource = SystemResource.CPU, description = """ - The processor uses additional CPU resources when consuming data from Kinesis. - The consumption is started immediately after this Processor is scheduled. The consumption ends only when the Processor is stopped.""") + The processor uses additional CPU resources when consuming data from Kinesis.""") @SystemResourceConsideration(resource = SystemResource.NETWORK, description = """ - The processor will continually poll for new Records, - requesting up to a maximum number of Records/bytes per call. This can result in sustained network usage.""") + The processor will continually poll for new Records.""") @SystemResourceConsideration(resource = SystemResource.MEMORY, description = """ - ConsumeKinesis buffers Kinesis Records in memory until they can be processed. - The maximum size of the buffer is controlled by the 'Max Bytes to Buffer' property. - In addition, the processor may cache some amount of data for each shard when the processor's buffer is full.""") + Records are fetched from Kinesis in the background and buffered in memory until they \ + can be written to FlowFiles. Up to 200 fetch responses may be buffered per shard for \ + both Shared Throughput and Enhanced Fan-Out. Each Shared Throughput response can \ + contain up to the value of 'Max Records Per Request' records (default 100) and up \ + to 10 MB, so the theoretical maximum is 20,000 records or approximately 2 GB per \ + shard at default settings. Each Enhanced Fan-Out push event can hold up to roughly \ + 2 MB, for a theoretical maximum of approximately 400 MB per shard. 
In practice the \ + buffer is typically much smaller because fetch threads block when the queue is \ + full and most responses are well below the maximum size. + """) public class ConsumeKinesis extends AbstractProcessor { - private static final Duration HTTP_CLIENTS_CONNECTION_TIMEOUT = Duration.ofSeconds(30); - private static final Duration HTTP_CLIENTS_READ_TIMEOUT = Duration.ofMinutes(3); - - private static final int KINESIS_HTTP_CLIENT_WINDOW_SIZE_BYTES = 512 * 1024; // 512 KiB - private static final Duration KINESIS_HTTP_HEALTH_CHECK_PERIOD = Duration.ofMinutes(1); - - /** - * How long to wait for a Scheduler initialization to complete in the OnScheduled method. - * If the initialization takes longer than this, the processor will continue initialization checks in the onTrigger method. - */ - private static final Duration KINESIS_SCHEDULER_ON_SCHEDULED_INITIALIZATION_TIMEOUT = Duration.ofSeconds(30); - private static final Duration KINESIS_SCHEDULER_GRACEFUL_SHUTDOWN_TIMEOUT = Duration.ofMinutes(3); + static final String ATTR_STREAM_NAME = "aws.kinesis.stream.name"; + static final String ATTR_SHARD_ID = "aws.kinesis.shard.id"; + static final String ATTR_FIRST_SEQUENCE = "aws.kinesis.first.sequence.number"; + static final String ATTR_LAST_SEQUENCE = "aws.kinesis.last.sequence.number"; + static final String ATTR_FIRST_SUBSEQUENCE = "aws.kinesis.first.subsequence.number"; + static final String ATTR_LAST_SUBSEQUENCE = "aws.kinesis.last.subsequence.number"; + static final String ATTR_PARTITION_KEY = "aws.kinesis.partition.key"; + static final String ATTR_ARRIVAL_TIMESTAMP = "aws.kinesis.approximate.arrival.timestamp.ms"; + static final String ATTR_MILLIS_BEHIND = "kinesis.millis.behind"; + static final String ATTR_RECORD_ERROR_MESSAGE = "record.error.message"; + + private static final long QUEUE_POLL_TIMEOUT_MILLIS = 100; + private static final Duration API_CALL_TIMEOUT = Duration.ofSeconds(30); + private static final Duration API_CALL_ATTEMPT_TIMEOUT = 
Duration.ofSeconds(10); + private static final byte[] NEWLINE_DELIMITER = new byte[] {'\n'}; + private static final String WRAPPER_VALUE_FIELD = "value"; static final PropertyDescriptor STREAM_NAME = new PropertyDescriptor.Builder() .name("Stream Name") @@ -188,17 +180,17 @@ public class ConsumeKinesis extends AbstractProcessor { static final PropertyDescriptor APPLICATION_NAME = new PropertyDescriptor.Builder() .name("Application Name") - .description("The name of the Kinesis application. This is used for DynamoDB table naming and worker coordination.") + .description(""" + The name of the Kinesis application. Used as the DynamoDB table name for checkpoint storage. \ + When Consumer Type is Enhanced Fan-Out, this value is also used as the registered consumer \ + name. This value should be unique for the stream.""") .required(true) .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) .build(); static final PropertyDescriptor AWS_CREDENTIALS_PROVIDER_SERVICE = new PropertyDescriptor.Builder() .name("AWS Credentials Provider Service") - .description(""" - The Controller Service that is used to obtain AWS credentials provider. - Ensure that the credentials provided have access to Kinesis, DynamoDB and (optional) CloudWatch. - """) + .description("The Controller Service used to obtain AWS credentials provider.") .required(true) .identifiesControllerService(AwsCredentialsProviderService.class) .build(); @@ -221,12 +213,7 @@ Ensure that the credentials provided have access to Kinesis, DynamoDB and (optio static final PropertyDescriptor RECORD_READER = new PropertyDescriptor.Builder() .name("Record Reader") - .description(""" - The Record Reader to use for parsing the data received from Kinesis. - - The Record Reader is responsible for providing schemas for the records. If the schemas change frequently, - it might hinder performance of the processor. 
- """) + .description("The Record Reader to use for parsing the data received from Kinesis.") .required(true) .dependsOn(PROCESSING_STRATEGY, ProcessingStrategy.RECORD) .identifiesControllerService(RecordReaderFactory.class) @@ -252,12 +239,9 @@ Ensure that the credentials provided have access to Kinesis, DynamoDB and (optio static final PropertyDescriptor MESSAGE_DEMARCATOR = new PropertyDescriptor.Builder() .name("Message Demarcator") .description(""" - Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kinesis messages - within a single FlowFile. If not specified, the content of the messages will be concatenated - without any delimiter. - To enter special character such as 'new line' use CTRL+Enter or Shift+Enter depending on the OS. - """) - .required(false) + Specifies the string (interpreted as UTF-8) used to separate multiple Kinesis messages \ + within a single FlowFile when Processing Strategy is DEMARCATOR.""") + .required(true) .addValidator(Validator.VALID) .dependsOn(PROCESSING_STRATEGY, ProcessingStrategy.DEMARCATOR) .build(); @@ -272,47 +256,50 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine static final PropertyDescriptor STREAM_POSITION_TIMESTAMP = new PropertyDescriptor.Builder() .name("Stream Position Timestamp") - .description("Timestamp position in stream from which to start reading Kinesis Records. The timestamp must be in ISO 8601 format.") + .description("Timestamp position in stream from which to start reading Kinesis Records. Must be in ISO 8601 format.") .required(true) .addValidator(StandardValidators.ISO8601_INSTANT_VALIDATOR) .dependsOn(INITIAL_STREAM_POSITION, InitialPosition.AT_TIMESTAMP) .build(); - static final PropertyDescriptor MAX_BYTES_TO_BUFFER = new PropertyDescriptor.Builder() - .name("Max Bytes to Buffer") - .description(""" - The maximum size of Kinesis Records that can be buffered in memory before being processed by NiFi. 
- If the buffer size exceeds the limit, the processor will stop consuming new records until free space is available. - - Using a larger value may increase the throughput, but will do so at the expense of using more memory. - """) + static final PropertyDescriptor MAX_RECORDS_PER_REQUEST = new PropertyDescriptor.Builder() + .name("Max Records Per Request") + .description("The maximum number of records to retrieve per GetRecords call. Maximum is 10,000.") .required(true) - .addValidator(StandardValidators.DATA_SIZE_VALIDATOR) - .defaultValue("100 MB") + .defaultValue("100") + .addValidator(StandardValidators.createLongValidator(1, 10000, true)) + .dependsOn(CONSUMER_TYPE, ConsumerType.SHARED_THROUGHPUT) .build(); - static final PropertyDescriptor CHECKPOINT_INTERVAL = new PropertyDescriptor.Builder() - .name("Checkpoint Interval") + static final PropertyDescriptor MAX_BATCH_DURATION = new PropertyDescriptor.Builder() + .name("Max Batch Duration") .description(""" - Interval between checkpointing consumed Kinesis records. To checkpoint records each time the Processor is run, set this value to 0 seconds. - - More frequent checkpoint may reduce performance and increase DynamoDB costs, - but less frequent checkpointing may result in duplicates when a Shard lease is lost or NiFi is restarted. 
- """) + The maximum amount of time to spend consuming records in a single invocation before \ + committing the session and checkpointing.""") .required(true) - .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) .defaultValue("5 sec") + .addValidator(StandardValidators.TIME_PERIOD_VALIDATOR) .build(); - static final PropertyDescriptor METRICS_PUBLISHING = new PropertyDescriptor.Builder() - .name("Metrics Publishing") - .description("Specifies where Kinesis usage metrics are published to.") + static final PropertyDescriptor MAX_BATCH_SIZE = new PropertyDescriptor.Builder() + .name("Max Batch Size") + .description(""" + The maximum amount of data to consume in a single invocation before committing the \ + session and checkpointing.""") .required(true) - .allowableValues(MetricsPublishing.class) - .defaultValue(MetricsPublishing.DISABLED) + .defaultValue("10 MB") + .addValidator(StandardValidators.DATA_SIZE_VALIDATOR) + .build(); + + static final PropertyDescriptor ENDPOINT_OVERRIDE = new PropertyDescriptor.Builder() + .name("Endpoint Override URL") + .description("An optional endpoint override URL for both the Kinesis and DynamoDB clients.") + .required(false) + .addValidator(StandardValidators.URL_VALIDATOR) .build(); - static final PropertyDescriptor PROXY_CONFIGURATION_SERVICE = ProxyConfiguration.createProxyConfigPropertyDescriptor(ProxySpec.HTTP, ProxySpec.HTTP_AUTH); + static final PropertyDescriptor PROXY_CONFIGURATION_SERVICE = + ProxyConfiguration.createProxyConfigPropertyDescriptor(ProxySpec.HTTP, ProxySpec.HTTP_AUTH); private static final List PROPERTY_DESCRIPTORS = List.of( STREAM_NAME, @@ -328,10 +315,11 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine MESSAGE_DEMARCATOR, INITIAL_STREAM_POSITION, STREAM_POSITION_TIMESTAMP, - MAX_BYTES_TO_BUFFER, - CHECKPOINT_INTERVAL, - PROXY_CONFIGURATION_SERVICE, - METRICS_PUBLISHING + MAX_RECORDS_PER_REQUEST, + MAX_BATCH_DURATION, + MAX_BATCH_SIZE, + ENDPOINT_OVERRIDE, + 
PROXY_CONFIGURATION_SERVICE ); static final Relationship REL_SUCCESS = new Relationship.Builder() @@ -347,38 +335,32 @@ Specifies the string (interpreted as UTF-8) to use for demarcating multiple Kine private static final Set RAW_FILE_RELATIONSHIPS = Set.of(REL_SUCCESS); private static final Set RECORD_FILE_RELATIONSHIPS = Set.of(REL_SUCCESS, REL_PARSE_FAILURE); - private volatile DynamoDbAsyncClient dynamoDbClient; - private volatile CloudWatchAsyncClient cloudWatchClient; - private volatile KinesisAsyncClient kinesisClient; - private volatile Scheduler kinesisScheduler; - + private volatile SdkHttpClient kinesisHttpClient; + private volatile SdkHttpClient dynamoHttpClient; + private volatile KinesisClient kinesisClient; + private volatile DynamoDbClient dynamoDbClient; + private volatile SdkAsyncHttpClient asyncHttpClient; + private volatile KinesisShardManager shardManager; + private volatile KinesisConsumerClient consumerClient; private volatile String streamName; - private volatile RecordBuffer.ForProcessor recordBuffer; - - private volatile @Nullable ReaderRecordProcessor readerRecordProcessor; - private volatile @Nullable byte[] demarcatorValue; + private volatile int maxRecordsPerRequest; + private volatile String initialStreamPosition; + private volatile long maxBatchNanos; + private volatile long maxBatchBytes; - private volatile Future initializationResultFuture; - private final AtomicBoolean initialized = new AtomicBoolean(); - - // An instance filed, so that it can be read in getRelationships. 
- private volatile ProcessingStrategy processingStrategy = ProcessingStrategy.from( - PROCESSING_STRATEGY.getDefaultValue()); + private volatile ProcessingStrategy processingStrategy = ProcessingStrategy.valueOf(PROCESSING_STRATEGY.getDefaultValue()); + private volatile String efoConsumerArn; + private final AtomicLong shardRoundRobinCounter = new AtomicLong(); @Override protected List getSupportedPropertyDescriptors() { return PROPERTY_DESCRIPTORS; } - @Override - public void migrateProperties(final PropertyConfiguration config) { - ProxyServiceMigration.renameProxyConfigurationServiceProperty(config); - } - @Override public Set getRelationships() { return switch (processingStrategy) { - case FLOW_FILE, DEMARCATOR -> RAW_FILE_RELATIONSHIPS; + case FLOW_FILE, LINE_DELIMITED, DEMARCATOR -> RAW_FILE_RELATIONSHIPS; case RECORD -> RECORD_FILE_RELATIONSHIPS; }; } @@ -386,506 +368,1166 @@ public Set getRelationships() { @Override public void onPropertyModified(final PropertyDescriptor descriptor, final String oldValue, final String newValue) { if (descriptor.equals(PROCESSING_STRATEGY)) { - processingStrategy = ProcessingStrategy.from(newValue); + processingStrategy = ProcessingStrategy.valueOf(newValue); } } - @OnScheduled - public void setup(final ProcessContext context) { - readerRecordProcessor = switch (processingStrategy) { - case FLOW_FILE, DEMARCATOR -> null; - case RECORD -> createReaderRecordProcessor(context); - }; - demarcatorValue = switch (processingStrategy) { - case FLOW_FILE, RECORD -> null; - case DEMARCATOR -> { - final String demarcatorValue = context.getProperty(MESSAGE_DEMARCATOR).getValue(); - yield demarcatorValue != null ? 
demarcatorValue.getBytes(UTF_8) : new byte[0]; - } - }; + @Override + public void migrateProperties(final PropertyConfiguration config) { + ProxyServiceMigration.renameProxyConfigurationServiceProperty(config); + config.renameProperty("Max Bytes to Buffer", "Max Batch Size"); + config.removeProperty("Checkpoint Interval"); + config.removeProperty("Metrics Publishing"); + } + @OnScheduled + public void onScheduled(final ProcessContext context) { final Region region = RegionUtil.getRegion(context); final AwsCredentialsProvider credentialsProvider = context.getProperty(AWS_CREDENTIALS_PROVIDER_SERVICE) .asControllerService(AwsCredentialsProviderService.class).getAwsCredentialsProvider(); + final String endpointOverride = context.getProperty(ENDPOINT_OVERRIDE).getValue(); - kinesisClient = KinesisAsyncClient.builder() - .region(region) - .credentialsProvider(credentialsProvider) - .endpointOverride(getKinesisEndpointOverride()) - .httpClient(createKinesisHttpClient(context)) + final ClientOverrideConfiguration clientConfig = ClientOverrideConfiguration.builder() + .apiCallTimeout(API_CALL_TIMEOUT) + .apiCallAttemptTimeout(API_CALL_ATTEMPT_TIMEOUT) .build(); - dynamoDbClient = DynamoDbAsyncClient.builder() + final KinesisClientBuilder kinesisBuilder = KinesisClient.builder() .region(region) .credentialsProvider(credentialsProvider) - .endpointOverride(getDynamoDbEndpointOverride()) - .httpClient(createHttpClientBuilder(context).build()) - .build(); + .overrideConfiguration(clientConfig); - cloudWatchClient = CloudWatchAsyncClient.builder() + final DynamoDbClientBuilder dynamoBuilder = DynamoDbClient.builder() .region(region) .credentialsProvider(credentialsProvider) - .endpointOverride(getCloudwatchEndpointOverride()) - .httpClient(createHttpClientBuilder(context).build()) - .build(); + .overrideConfiguration(clientConfig); + + if (endpointOverride != null && !endpointOverride.isEmpty()) { + final URI endpointUri = URI.create(endpointOverride); + 
kinesisBuilder.endpointOverride(endpointUri); + dynamoBuilder.endpointOverride(endpointUri); + } + + final ProxyConfiguration proxyConfig = ProxyConfiguration.getConfiguration(context); + kinesisHttpClient = buildApacheHttpClient(proxyConfig, PollingKinesisClient.MAX_CONCURRENT_FETCHES + 10); + dynamoHttpClient = buildApacheHttpClient(proxyConfig, 50); + kinesisBuilder.httpClient(kinesisHttpClient); + dynamoBuilder.httpClient(dynamoHttpClient); + + kinesisClient = kinesisBuilder.build(); + dynamoDbClient = dynamoBuilder.build(); + + final String checkpointTableName = context.getProperty(APPLICATION_NAME).getValue(); streamName = context.getProperty(STREAM_NAME).getValue(); - final InitialPositionInStreamExtended initialPositionExtended = getInitialPosition(context); - final SingleStreamTracker streamTracker = new SingleStreamTracker(streamName, initialPositionExtended); - - final long maxBytesToBuffer = context.getProperty(MAX_BYTES_TO_BUFFER).asDataSize(DataUnit.B).longValue(); - final Duration checkpointInterval = context.getProperty(CHECKPOINT_INTERVAL).asDuration(); - final MemoryBoundRecordBuffer memoryBoundRecordBuffer = new MemoryBoundRecordBuffer(getLogger(), maxBytesToBuffer, checkpointInterval); - recordBuffer = memoryBoundRecordBuffer; - final ShardRecordProcessorFactory recordProcessorFactory = () -> new ConsumeKinesisRecordProcessor(memoryBoundRecordBuffer); - - final String applicationName = context.getProperty(APPLICATION_NAME).getValue(); - final String workerId = generateWorkerId(); - final ConfigsBuilder configsBuilder = new ConfigsBuilder(streamTracker, applicationName, kinesisClient, dynamoDbClient, cloudWatchClient, workerId, recordProcessorFactory); - - final MetricsFactory metricsFactory = configureMetricsFactory(context); - final RetrievalSpecificConfig retrievalSpecificConfig = configureRetrievalSpecificConfig(context, kinesisClient, streamName, applicationName); - - final InitializationStateChangeListener initializationListener = new 
InitializationStateChangeListener(getLogger()); - initialized.set(false); - initializationResultFuture = initializationListener.result(); - - kinesisScheduler = new Scheduler( - configsBuilder.checkpointConfig(), - configsBuilder.coordinatorConfig().workerStateChangeListener(initializationListener), - configsBuilder.leaseManagementConfig(), - configsBuilder.lifecycleConfig(), - configsBuilder.metricsConfig().metricsFactory(metricsFactory), - configsBuilder.processorConfig(), - configsBuilder.retrievalConfig().retrievalSpecificConfig(retrievalSpecificConfig) - ); - - final String schedulerThreadName = "%s-Scheduler-%s".formatted(getClass().getSimpleName(), getIdentifier()); - final Thread schedulerThread = new Thread(kinesisScheduler, schedulerThreadName); - schedulerThread.setDaemon(true); - schedulerThread.start(); - // The thread is stopped when kinesisScheduler is shutdown in the onStopped method. + initialStreamPosition = context.getProperty(INITIAL_STREAM_POSITION).getValue(); + maxBatchNanos = context.getProperty(MAX_BATCH_DURATION).asTimePeriod(TimeUnit.NANOSECONDS); + maxBatchBytes = context.getProperty(MAX_BATCH_SIZE).asDataSize(DataUnit.B).longValue(); - try { - final InitializationResult result = initializationResultFuture.get( - KINESIS_SCHEDULER_ON_SCHEDULED_INITIALIZATION_TIMEOUT.getSeconds(), SECONDS); - checkInitializationResult(result); - } catch (final TimeoutException e) { - // During a first run the processor will take more time to initialize. We return from OnSchedule and continue waiting in the onTrigger method. 
- getLogger().warn("Kinesis Scheduler initialization may take up to 10 minutes on a first run, which is caused by AWS resources initialization"); - } catch (final InterruptedException | ExecutionException e) { - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - cleanUpState(); - throw new ProcessException("Initialization failed for stream [%s]".formatted(streamName), e); - } - } + final boolean efoMode = ConsumerType.ENHANCED_FAN_OUT.equals(context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class)); + maxRecordsPerRequest = efoMode ? 0 : context.getProperty(MAX_RECORDS_PER_REQUEST).asInteger(); - /** - * Creating Kinesis HTTP client, as per - * {@link software.amazon.kinesis.common.KinesisClientUtil#adjustKinesisClientBuilder(KinesisAsyncClientBuilder)}. - */ - private static SdkAsyncHttpClient createKinesisHttpClient(final ProcessContext context) { - return createHttpClientBuilder(context) - .protocol(Protocol.HTTP2) - // Since we're using HTTP/2, multiple concurrent requests will reuse the same HTTP connection. - // Therefore, the number of real connections is going to be relatively small. 
- .maxConcurrency(Integer.MAX_VALUE) - .http2Configuration(Http2Configuration.builder() - .initialWindowSize(KINESIS_HTTP_CLIENT_WINDOW_SIZE_BYTES) - .healthCheckPingPeriod(KINESIS_HTTP_HEALTH_CHECK_PERIOD) - .build()) - .build(); - } + shardManager = createShardManager(kinesisClient, dynamoDbClient, getLogger(), checkpointTableName, streamName); + shardManager.ensureCheckpointTableExists(); + consumerClient = createConsumerClient(kinesisClient, getLogger(), efoMode); - private static NettyNioAsyncHttpClient.Builder createHttpClientBuilder(final ProcessContext context) { - final NettyNioAsyncHttpClient.Builder builder = NettyNioAsyncHttpClient.builder() - .connectionTimeout(HTTP_CLIENTS_CONNECTION_TIMEOUT) - .readTimeout(HTTP_CLIENTS_READ_TIMEOUT); + final Instant timestampForPosition = resolveTimestampPosition(context); + if (timestampForPosition != null) { + consumerClient.setTimestampForInitialPosition(timestampForPosition); + } - final ProxyConfigurationService proxyConfigService = context.getProperty(PROXY_CONFIGURATION_SERVICE).asControllerService(ProxyConfigurationService.class); - if (proxyConfigService != null) { - final ProxyConfiguration proxyConfig = proxyConfigService.getConfiguration(); + if (efoMode) { + final NettyNioAsyncHttpClient.Builder nettyBuilder = NettyNioAsyncHttpClient.builder() + .protocol(Protocol.HTTP2) + .maxConcurrency(500) + .connectionAcquisitionTimeout(Duration.ofSeconds(60)); - final software.amazon.awssdk.http.nio.netty.ProxyConfiguration.Builder proxyConfigBuilder = software.amazon.awssdk.http.nio.netty.ProxyConfiguration.builder() + if (Proxy.Type.HTTP.equals(proxyConfig.getProxyType())) { + final software.amazon.awssdk.http.nio.netty.ProxyConfiguration.Builder nettyProxyBuilder = software.amazon.awssdk.http.nio.netty.ProxyConfiguration.builder() .host(proxyConfig.getProxyServerHost()) .port(proxyConfig.getProxyServerPort()); - if (proxyConfig.hasCredential()) { - proxyConfigBuilder.username(proxyConfig.getProxyUserName()); - 
proxyConfigBuilder.password(proxyConfig.getProxyUserPassword()); - } + if (proxyConfig.hasCredential()) { + nettyProxyBuilder.username(proxyConfig.getProxyUserName()); + nettyProxyBuilder.password(proxyConfig.getProxyUserPassword()); + } - builder.proxyConfiguration(proxyConfigBuilder.build()); - } + nettyBuilder.proxyConfiguration(nettyProxyBuilder.build()); + } - return builder; - } + asyncHttpClient = nettyBuilder.build(); - private ReaderRecordProcessor createReaderRecordProcessor(final ProcessContext context) { - final RecordReaderFactory recordReaderFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); - final RecordSetWriterFactory recordWriterFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + final KinesisAsyncClientBuilder asyncBuilder = KinesisAsyncClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .httpClient(asyncHttpClient); - final OutputStrategy outputStrategy = context.getProperty(OUTPUT_STRATEGY).asAllowableValue(OutputStrategy.class); - final KinesisRecordConverter converter = switch (outputStrategy) { - case USE_VALUE -> new ValueRecordConverter(); - case USE_WRAPPER -> new WrapperRecordConverter(); - case INJECT_METADATA -> new InjectMetadataRecordConverter(); - }; + if (endpointOverride != null && !endpointOverride.isEmpty()) { + asyncBuilder.endpointOverride(URI.create(endpointOverride)); + } - return new ReaderRecordProcessor(recordReaderFactory, converter, recordWriterFactory, getLogger()); + final String consumerName = context.getProperty(APPLICATION_NAME).getValue(); + consumerClient.initialize(asyncBuilder.build(), streamName, consumerName); + } } - private static InitialPositionInStreamExtended getInitialPosition(final ProcessContext context) { - final InitialPosition initialPosition = context.getProperty(INITIAL_STREAM_POSITION).asAllowableValue(InitialPosition.class); - return switch (initialPosition) { - case 
TRIM_HORIZON -> - InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.TRIM_HORIZON); - case LATEST -> InitialPositionInStreamExtended.newInitialPosition(InitialPositionInStream.LATEST); - case AT_TIMESTAMP -> { - final String timestampValue = context.getProperty(STREAM_POSITION_TIMESTAMP).getValue(); - final Instant timestamp = Instant.parse(timestampValue); - yield InitialPositionInStreamExtended.newInitialPositionAtTimestamp(Date.from(timestamp)); - } - }; + private static Instant resolveTimestampPosition(final ProcessContext context) { + final InitialPosition position = context.getProperty(INITIAL_STREAM_POSITION).asAllowableValue(InitialPosition.class); + if (position == InitialPosition.AT_TIMESTAMP) { + return Instant.parse(context.getProperty(STREAM_POSITION_TIMESTAMP).getValue()); + } + return null; } - private String generateWorkerId() { - final String processorId = getIdentifier(); - final NodeTypeProvider nodeTypeProvider = getNodeTypeProvider(); - - final String workerId; + /** + * Builds an {@link Apache5HttpClient} with the given connection pool size and optional proxy + * configuration. Each AWS service client (Kinesis, DynamoDB) should receive its own HTTP client + * so their connection pools are isolated and cannot starve each other under high shard counts. + */ + private static SdkHttpClient buildApacheHttpClient(final ProxyConfiguration proxyConfig, final int maxConnections) { + final Apache5HttpClient.Builder builder = Apache5HttpClient.builder() + .maxConnections(maxConnections); - if (nodeTypeProvider.isClustered()) { - // If a node id is not available for some reason, generating a random UUID helps to avoid collisions. 
- final String nodeId = nodeTypeProvider.getCurrentNode().orElse(UUID.randomUUID().toString()); - workerId = "%s@%s".formatted(processorId, nodeId); - } else { - workerId = processorId; - } + if (Proxy.Type.HTTP.equals(proxyConfig.getProxyType())) { + final URI proxyEndpoint = URI.create(String.format("http://%s:%s", proxyConfig.getProxyServerHost(), proxyConfig.getProxyServerPort())); + final software.amazon.awssdk.http.apache5.ProxyConfiguration.Builder proxyBuilder = + software.amazon.awssdk.http.apache5.ProxyConfiguration.builder().endpoint(proxyEndpoint); - return workerId; - } + if (proxyConfig.hasCredential()) { + proxyBuilder.username(proxyConfig.getProxyUserName()); + proxyBuilder.password(proxyConfig.getProxyUserPassword()); + } - private static @Nullable MetricsFactory configureMetricsFactory(final ProcessContext context) { - final MetricsPublishing metricsPublishing = context.getProperty(METRICS_PUBLISHING).asAllowableValue(MetricsPublishing.class); - return switch (metricsPublishing) { - case DISABLED -> new NullMetricsFactory(); - case LOGS -> new LogMetricsFactory(); - case CLOUDWATCH -> null; // If no metrics factory was provided, CloudWatch metrics factory is used by default. 
- }; - } + builder.proxyConfiguration(proxyBuilder.build()); + } - private static RetrievalSpecificConfig configureRetrievalSpecificConfig( - final ProcessContext context, - final KinesisAsyncClient kinesisClient, - final String streamName, - final String applicationName) { - final ConsumerType consumerType = context.getProperty(CONSUMER_TYPE).asAllowableValue(ConsumerType.class); - return switch (consumerType) { - case SHARED_THROUGHPUT -> new PollingConfig(kinesisClient).streamName(streamName); - case ENHANCED_FAN_OUT -> new FanOutConfig(kinesisClient).streamName(streamName).applicationName(applicationName); - }; + return builder.build(); } @OnStopped public void onStopped() { - cleanUpState(); + if (shardManager != null) { + shardManager.releaseAllLeases(); + shardManager.close(); + shardManager = null; + } - initialized.set(false); - initializationResultFuture = null; - } + if (consumerClient instanceof EnhancedFanOutClient efo) { + efoConsumerArn = efo.getConsumerArn(); + } + if (consumerClient != null) { + consumerClient.close(); + consumerClient = null; + } - private void cleanUpState() { - if (kinesisScheduler != null) { - shutdownScheduler(); - kinesisScheduler = null; + if (asyncHttpClient != null) { + asyncHttpClient.close(); + asyncHttpClient = null; } if (kinesisClient != null) { kinesisClient.close(); kinesisClient = null; } + if (dynamoDbClient != null) { dynamoDbClient.close(); dynamoDbClient = null; } - if (cloudWatchClient != null) { - cloudWatchClient.close(); - cloudWatchClient = null; - } - recordBuffer = null; - readerRecordProcessor = null; - demarcatorValue = null; + closeQuietly(kinesisHttpClient); + kinesisHttpClient = null; + closeQuietly(dynamoHttpClient); + dynamoHttpClient = null; } - private void shutdownScheduler() { - if (kinesisScheduler.shutdownComplete()) { + @OnRemoved + public void onRemoved(final ProcessContext context) { + final String arn = efoConsumerArn; + efoConsumerArn = null; + if (arn == null) { return; } - final long 
start = System.nanoTime(); - getLogger().debug("Shutting down Kinesis Scheduler"); + final Region region = RegionUtil.getRegion(context); + final AwsCredentialsProvider credentialsProvider = context.getProperty(AWS_CREDENTIALS_PROVIDER_SERVICE) + .asControllerService(AwsCredentialsProviderService.class).getAwsCredentialsProvider(); + final String endpointOverride = context.getProperty(ENDPOINT_OVERRIDE).getValue(); - boolean gracefulShutdownSucceeded; - try { - gracefulShutdownSucceeded = kinesisScheduler.startGracefulShutdown().get(KINESIS_SCHEDULER_GRACEFUL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); - if (!gracefulShutdownSucceeded) { - getLogger().warn("Failed to shutdown Kinesis Scheduler gracefully. See the logs for more details"); - } - } catch (final RuntimeException | InterruptedException | ExecutionException | TimeoutException e) { - if (e instanceof TimeoutException) { - getLogger().warn("Failed to shutdown Kinesis Scheduler gracefully after {} seconds", KINESIS_SCHEDULER_GRACEFUL_SHUTDOWN_TIMEOUT.getSeconds(), e); - } else { - getLogger().warn("Failed to shutdown Kinesis Scheduler gracefully", e); - } - gracefulShutdownSucceeded = false; - } + final KinesisClientBuilder builder = KinesisClient.builder() + .region(region) + .credentialsProvider(credentialsProvider); - if (!gracefulShutdownSucceeded) { - kinesisScheduler.shutdown(); + if (endpointOverride != null && !endpointOverride.isEmpty()) { + builder.endpointOverride(URI.create(endpointOverride)); } - final long finish = System.nanoTime(); - getLogger().debug("Kinesis Scheduler shutdown finished after {} seconds", NANOSECONDS.toSeconds(finish - start)); + try (final KinesisClient tempClient = builder.build()) { + tempClient.deregisterStreamConsumer(DeregisterStreamConsumerRequest.builder() + .consumerARN(arn) + .build()); + getLogger().info("Deregistered EFO consumer [{}]", arn); + } catch (final Exception e) { + getLogger().warn("Failed to deregister EFO consumer [{}]; manual cleanup may be required", 
arn, e); + } } @Override public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { - if (!initialized.get()) { - if (!initializationResultFuture.isDone()) { - getLogger().debug("Waiting for Kinesis Scheduler to finish initialization"); + final NodeTypeProvider nodeTypeProvider = getNodeTypeProvider(); + final int clusterMemberCount = nodeTypeProvider.isClustered() ? 0 : Math.max(1, nodeTypeProvider.getClusterMembers().size()); + shardManager.refreshLeasesIfNecessary(clusterMemberCount); + final List ownedShards = shardManager.getOwnedShards(); + + if (ownedShards.isEmpty()) { + context.yield(); + return; + } + + final Set ownedShardIds = new HashSet<>(); + for (final Shard shard : ownedShards) { + ownedShardIds.add(shard.shardId()); + } + + consumerClient.removeUnownedShards(ownedShardIds); + consumerClient.startFetches(ownedShards, streamName, maxRecordsPerRequest, initialStreamPosition, shardManager); + consumerClient.logDiagnostics(ownedShards.size(), shardManager.getCachedShardCount()); + + final Set claimedShards = new HashSet<>(); + List consumed = List.of(); + try { + consumed = consumeRecords(claimedShards); + final List accepted = discardRelinquishedResults(consumed, claimedShards); + + if (accepted.isEmpty()) { + consumerClient.releaseShards(claimedShards); + context.yield(); + return; + } + + final PartitionedBatch batch = partitionByShardAndCheckpoint(accepted); + + final WriteResult output; + try { + output = writeResults(session, context, batch.resultsByShard()); + } catch (final Exception e) { + handleWriteFailure(e, accepted, claimedShards, context); + return; + } + + if (output.produced().isEmpty() && output.parseFailures().isEmpty()) { + consumerClient.releaseShards(claimedShards); context.yield(); return; } - checkInitializationResult(initializationResultFuture.resultNow()); + session.transfer(output.produced(), REL_SUCCESS); + if (!output.parseFailures().isEmpty()) { + 
session.transfer(output.parseFailures(), REL_PARSE_FAILURE); + session.adjustCounter("Records Parse Failure", output.parseFailures().size(), false); + } + session.adjustCounter("Records Consumed", output.totalRecordCount(), false); + final long dedupEvents = consumerClient.drainDeduplicatedEventCount(); + if (dedupEvents > 0) { + session.adjustCounter("EFO Deduplicated Events", dedupEvents, false); + } + + consumed = List.of(); + session.commitAsync( + () -> { + try { + shardManager.writeCheckpoints(batch.checkpoints()); + } finally { + try { + consumerClient.acknowledgeResults(accepted); + } finally { + consumerClient.releaseShards(claimedShards); + } + } + }, + failure -> { + try { + getLogger().error("Session commit failed; resetting shard iterators for re-consumption", failure); + consumerClient.rollbackResults(accepted); + } finally { + consumerClient.releaseShards(claimedShards); + } + }); + } catch (final Exception e) { + if (!consumed.isEmpty()) { + consumerClient.rollbackResults(consumed); + } + consumerClient.releaseShards(claimedShards); + throw e; } + } - final Optional leaseAcquired = recordBuffer.acquireBufferLease(); + private List discardRelinquishedResults(final List consumedResults, final Set claimedShards) { + final List accepted = new ArrayList<>(); + final List discarded = new ArrayList<>(); + for (final ShardFetchResult result : consumedResults) { + if (shardManager.shouldProcessFetchedResult(result.shardId())) { + accepted.add(result); + } else { + discarded.add(result); + } + } + + if (!discarded.isEmpty()) { + getLogger().debug("Discarding {} fetched shard result(s) for relinquished shards", discarded.size()); + consumerClient.rollbackResults(discarded); + for (final ShardFetchResult result : discarded) { + claimedShards.remove(result.shardId()); + } + consumerClient.releaseShards(discarded.stream().map(ShardFetchResult::shardId).toList()); + } + + return accepted; + } + + private PartitionedBatch partitionByShardAndCheckpoint(final List 
accepted) { + final Map> resultsByShard = new LinkedHashMap<>(); + for (final ShardFetchResult result : accepted) { + resultsByShard.computeIfAbsent(result.shardId(), k -> new ArrayList<>()).add(result); + } + for (final List shardResults : resultsByShard.values()) { + shardResults.sort(Comparator.comparing(ShardFetchResult::firstSequenceNumber)); + } + + final Map checkpoints = new HashMap<>(); + for (final List shardResults : resultsByShard.values()) { + final ShardFetchResult last = shardResults.getLast(); + checkpoints.put(last.shardId(), last.lastSequenceNumber()); + } - leaseAcquired.ifPresentOrElse( - lease -> processRecordsFromBuffer(session, lease), - context::yield - ); + return new PartitionedBatch(resultsByShard, checkpoints); } - private void checkInitializationResult(final InitializationResult initializationResult) { - switch (initializationResult) { - case InitializationResult.Success ignored -> { - final boolean wasInitialized = initialized.getAndSet(true); - if (!wasInitialized) { - getLogger().info( - "Started Kinesis Scheduler for stream [{}] with application name [{}] and workerId [{}]", - streamName, kinesisScheduler.applicationName(), kinesisScheduler.leaseManagementConfig().workerIdentifier()); + private List consumeRecords(final Set claimedShards) { + final List results = new ArrayList<>(); + final long startNanos = System.nanoTime(); + long estimatedBytes = 0; + + while (System.nanoTime() < startNanos + maxBatchNanos && estimatedBytes < maxBatchBytes) { + final List readyShards = consumerClient.getShardIdsWithResults(); + if (readyShards.isEmpty()) { + if (!consumerClient.hasPendingFetches()) { + break; + } + + try { + consumerClient.awaitResults(QUEUE_POLL_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + break; } + continue; } - case InitializationResult.Failure failure -> { - cleanUpState(); - final ProcessException ex = failure.error() - .map(err -> new 
ProcessException("Initialization failed for stream [%s]".formatted(streamName), err)) - // This branch is active only when a scheduler was shutdown, but no initialization error was provided. - // This behavior isn't typical and wasn't observed. - .orElseGet(() -> new ProcessException("Initialization failed for stream [%s]".formatted(streamName))); + boolean foundAny = false; + final int shardCount = readyShards.size(); + final int startOffset = (int) (shardRoundRobinCounter.getAndIncrement() % shardCount); + for (int i = 0; i < shardCount && estimatedBytes < maxBatchBytes; i++) { + final String shardId = readyShards.get((startOffset + i) % shardCount); + if (!claimedShards.contains(shardId) && !consumerClient.claimShard(shardId)) { + continue; + } + claimedShards.add(shardId); - throw ex; + final ShardFetchResult result = consumerClient.pollShardResult(shardId); + if (result != null) { + results.add(result); + estimatedBytes += estimateResultBytes(result); + foundAny = true; + } + } + + if (!foundAny) { + break; } } + + return results; } - private void processRecordsFromBuffer(final ProcessSession session, final Lease lease) { + private void handleWriteFailure(final Exception cause, final List accepted, + final Set claimedShards, final ProcessContext context) { + getLogger().error("Failed to write consumed Kinesis records", cause); + consumerClient.rollbackResults(accepted); + consumerClient.releaseShards(claimedShards); + context.yield(); + } + + private WriteResult writeResults(final ProcessSession session, final ProcessContext context, + final Map> resultsByShard) { + final List produced = new ArrayList<>(); + final List parseFailures = new ArrayList<>(); + long totalRecordCount = 0; + long totalBytesConsumed = 0; + long maxMillisBehind = -1; + try { - final List records = recordBuffer.consumeRecords(lease); + if (processingStrategy == ProcessingStrategy.FLOW_FILE) { + final BatchAccumulator batch = new BatchAccumulator(); + for (final List shardResults : 
resultsByShard.values()) { + for (final ShardFetchResult result : shardResults) { + batch.updateMillisBehind(result.millisBehindLatest()); + for (final UserRecord record : result.records()) { + batch.addBytes(record.data().length); + } + } + writeFlowFilePerRecord(session, shardResults, streamName, batch, produced); + } + totalRecordCount = batch.getRecordCount(); + totalBytesConsumed = batch.getBytesConsumed(); + maxMillisBehind = batch.getMaxMillisBehind(); + } else { + for (final Map.Entry> entry : resultsByShard.entrySet()) { + final BatchAccumulator batch = new BatchAccumulator(); + batch.setLastShardId(entry.getKey()); + for (final ShardFetchResult result : entry.getValue()) { + batch.updateMillisBehind(result.millisBehindLatest()); + batch.updateSequenceRange(result); + for (final UserRecord record : result.records()) { + batch.addBytes(record.data().length); + batch.updateRecordRange(record); + } + } - if (records.isEmpty()) { - recordBuffer.returnBufferLease(lease); - return; - } + if (processingStrategy == ProcessingStrategy.LINE_DELIMITED || processingStrategy == ProcessingStrategy.DEMARCATOR) { + final byte[] delimiter; + if (processingStrategy == ProcessingStrategy.LINE_DELIMITED) { + delimiter = NEWLINE_DELIMITER; + } else { + final String demarcatorValue = context.getProperty(MESSAGE_DEMARCATOR).getValue(); + delimiter = demarcatorValue.getBytes(StandardCharsets.UTF_8); + } + writeDelimited(session, entry.getValue(), streamName, batch, delimiter, produced); + } else { + writeRecordOriented(session, context, entry.getValue(), streamName, batch, produced, parseFailures); + } - final String shardId = lease.shardId(); - switch (processingStrategy) { - case FLOW_FILE -> processRecordsAsRaw(session, shardId, records); - case RECORD -> processRecordsWithReader(session, shardId, records); - case DEMARCATOR -> processRecordsAsDemarcated(session, shardId, records); + totalRecordCount += batch.getRecordCount(); + totalBytesConsumed += batch.getBytesConsumed(); 
+ maxMillisBehind = Math.max(maxMillisBehind, batch.getMaxMillisBehind()); + } } + } catch (final Exception e) { + session.remove(produced); + session.remove(parseFailures); + throw e; + } - session.adjustCounter("Records Processed", records.size(), false); + return new WriteResult(produced, parseFailures, totalRecordCount, totalBytesConsumed, maxMillisBehind); + } - session.commitAsync( - () -> commitRecords(lease), - __ -> rollbackRecords(lease) - ); - } catch (final RuntimeException e) { - rollbackRecords(lease); - throw e; + private void writeFlowFilePerRecord(final ProcessSession session, final List results, + final String streamName, final BatchAccumulator batch, final List output) { + for (final ShardFetchResult result : results) { + for (final UserRecord record : result.records()) { + final byte[] recordBytes = record.data(); + FlowFile flowFile = session.create(); + try { + flowFile = session.write(flowFile, out -> out.write(recordBytes)); + + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put(ATTR_SHARD_ID, result.shardId()); + attributes.put(ATTR_FIRST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_LAST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_PARTITION_KEY, record.partitionKey()); + if (record.approximateArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(record.approximateArrivalTimestamp().toEpochMilli())); + } + attributes.put("record.count", "1"); + if (result.millisBehindLatest() >= 0) { + attributes.put(ATTR_MILLIS_BEHIND, String.valueOf(result.millisBehindLatest())); + } + + flowFile = session.putAllAttributes(flowFile, attributes); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, result.shardId())); + output.add(flowFile); + 
batch.incrementRecordCount(); + } catch (final Exception e) { + session.remove(flowFile); + throw e; + } + } } } - private void commitRecords(final Lease lease) { + private void writeDelimited(final ProcessSession session, final List results, + final String streamName, final BatchAccumulator batch, final byte[] delimiter, + final List output) { + FlowFile flowFile = session.create(); try { - recordBuffer.commitConsumedRecords(lease); - } finally { - recordBuffer.returnBufferLease(lease); + flowFile = session.write(flowFile, new OutputStreamCallback() { + @Override + public void process(final OutputStream out) throws IOException { + boolean first = true; + for (final ShardFetchResult result : results) { + for (final UserRecord record : result.records()) { + if (!first) { + out.write(delimiter); + } + out.write(record.data()); + first = false; + batch.incrementRecordCount(); + } + } + } + }); + + flowFile = session.putAllAttributes(flowFile, createFlowFileAttributes(streamName, batch)); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; } } - private void rollbackRecords(final Lease lease) { + /** + * Writes Kinesis records as NiFi records using the configured Record Reader and Record Writer. + * + *

This method may appear unnecessarily complex, but it is intended to address specific requirements:

+ *
    + *
  • Keep records ordered in the same order they are received from Kinesis
  • + *
  • Create as few FlowFiles as necessary, keeping many records together in larger FlowFiles for performance reasons.
  • + *
+ * + *

Alternative options have been considered, as well:

+ *
    + *
  • Read each Record one at a time with a separate RecordReader. If its schema is different than the previous + * record, create a new FlowFile. However, when the stream is filled with JSON and many fields are nullable, this + * can look like a different schema for each Record when inference is used, thus creating many tiny FlowFiles.
  • + *
  • Read each Record one at a time with a separate RecordReader. Map the RecordSchema to the existing RecordWriter + * for that schema, if one exists, and write to that writer; if none exists, create a new one. This results in better + * grouping in many cases, but it results in the output being reordered, as we may write records 1, 2, 3 to writers + * A, B, A.
  • + *
  • Create a single InputStream and RecordReader for the entire batch. Create a single Writer for the entire batch. + * This way, we infer a single schema for the entire batch that is appropriate for all records. This bundles all records + * in the batch into a single FlowFile, which is ideal. However, this approach fails when we are not inferring the schema + * and the records do not all have the same schema. In that case, we can fail when attempting to read the records or when + * we attempt to write the records due to schema incompatibility.
  • + *
+ * + *

+ * Additionally, the existing RecordSchema API does not tell us whether or not a schema was inferred, + * so we cannot easily make a decision based on that knowledge. Therefore, we have taken an approach that + * attempts to process data using our preferred method, falling back as necessary to other options. + *

+ * + *

+ * The primary path ({@link #writeRecordBatch}) combines all records into a single InputStream + * via {@link KinesisRecordInputStream} and creates one RecordReader. This is optimal for formats + * like JSON where the schema is inferred from the data: a single InputStream lets the reader see + * all records and produce a unified schema for the writer. + *

+ * + *

+ * However, this approach fails when records carry incompatible embedded schemas (e.g. Avro + * containers with different field sets). The single reader sees only the first schema and cannot + * parse subsequent records that differ from it. When this happens, the method falls back to + * {@link #writeRecordBatchPerRecord}, which processes each record individually and splits output + * across multiple FlowFiles when schemas change. + *

+ */ + private void writeRecordOriented(final ProcessSession session, final ProcessContext context, + final List results, final String streamName, + final BatchAccumulator batch, final List output, + final List parseFailureOutput) { + + final List allRecords = new ArrayList<>(); + for (final ShardFetchResult result : results) { + allRecords.addAll(result.records()); + } + try { - recordBuffer.rollbackConsumedRecords(lease); - } finally { - recordBuffer.returnBufferLease(lease); + final RecordReaderFactory readerFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); + final RecordSetWriterFactory writerFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + final OutputStrategy outputStrategy = context.getProperty(OUTPUT_STRATEGY).asAllowableValue(OutputStrategy.class); + writeRecordBatch(session, readerFactory, writerFactory, outputStrategy, + allRecords, streamName, batch, output); + } catch (final Exception e) { + getLogger().debug("Combined-stream record processing failed; falling back to per-record processing", e); + batch.resetRecordCount(); + final RecordBatchResult result = writeRecordBatchPerRecord(session, context, allRecords, streamName, batch); + output.addAll(result.output()); + parseFailureOutput.addAll(result.parseFailures()); } } - private void processRecordsAsRaw(final ProcessSession session, final String shardId, final List records) { - for (final KinesisClientRecord record : records) { - FlowFile flowFile = session.create(); - flowFile = session.putAllAttributes(flowFile, ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, record, record)); + private void writeRecordBatch(final ProcessSession session, final RecordReaderFactory readerFactory, + final RecordSetWriterFactory writerFactory, final OutputStrategy outputStrategy, + final List records, + final String streamName, final BatchAccumulator batch, final List output) { - flowFile = 
session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - channel.write(record.data()); + FlowFile flowFile = session.create(); + final Map attributes = new HashMap<>(); + + try { + flowFile = session.write(flowFile, new OutputStreamCallback() { + @Override + public void process(final OutputStream out) throws IOException { + try (final InputStream kinesisInput = new KinesisRecordInputStream(records); + final RecordReader reader = readerFactory.createRecordReader(Map.of(), kinesisInput, -1, getLogger())) { + + RecordSchema writeSchema = reader.getSchema(); + if (outputStrategy == OutputStrategy.INJECT_METADATA) { + final List fields = new ArrayList<>(writeSchema.getFields()); + fields.add(KinesisRecordMetadata.FIELD_METADATA); + writeSchema = new SimpleRecordSchema(fields); + } else if (outputStrategy == OutputStrategy.USE_WRAPPER) { + writeSchema = new SimpleRecordSchema(List.of( + KinesisRecordMetadata.FIELD_METADATA, + new RecordField(WRAPPER_VALUE_FIELD, RecordFieldType.RECORD.getRecordDataType(writeSchema)))); + } + + try (final RecordSetWriter writer = writerFactory.createWriter(getLogger(), writeSchema, out, Map.of())) { + writer.beginRecordSet(); + + int recordIndex = 0; + org.apache.nifi.serialization.record.Record nifiRecord; + while ((nifiRecord = reader.nextRecord()) != null) { + final UserRecord record = records.get(recordIndex++); + nifiRecord = decorateRecord(nifiRecord, record, record.shardId(), streamName, outputStrategy, writeSchema); + + writer.write(nifiRecord); + batch.incrementRecordCount(); + } + + final org.apache.nifi.serialization.WriteResult writeResult = writer.finishRecordSet(); + attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), writer.getMimeType()); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + } + } catch (final MalformedRecordException | SchemaNotFoundException e) { + throw new IOException(e); + } } }); 
- session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - session.transfer(flowFile, REL_SUCCESS); + attributes.putAll(createFlowFileAttributes(streamName, batch)); + flowFile = session.putAllAttributes(flowFile, attributes); + session.getProvenanceReporter().receive(flowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; } } - private void processRecordsWithReader(final ProcessSession session, final String shardId, final List records) { - final ReaderRecordProcessor recordProcessor = readerRecordProcessor; - if (recordProcessor == null) { - throw new IllegalStateException("RecordProcessor has not been initialized"); - } + /** + * Fallback path that processes each Kinesis record individually, splitting output across multiple + * FlowFiles when the record schema changes between consecutive records. + * + *

This is invoked when the combined-stream approach ({@link #writeRecordBatch}) fails, which + * typically happens when the batch contains records with incompatible embedded schemas (e.g. Avro + * containers whose field sets differ). Rather than grouping or buffering records up front, this + * method makes a single pass: for each record it creates a RecordReader, compares the schema to + * the current writer's schema, and either continues writing to the same FlowFile or finalizes the + * current FlowFile and starts a new one. This preserves record ordering without demultiplexing.

+ * + *

Records that cannot be parsed (empty data, malformed content, missing schema) are collected + * and routed to the parse-failure relationship at the end.

+ * + * @param session the current process session + * @param context the current process context (used to resolve Record Reader, Record Writer, and Output Strategy) + * @param records the Kinesis records to process, in order + * @param streamName the Kinesis stream name, used for FlowFile attributes + * @param batch accumulator for batch-level attributes and record counting + * @return a {@link RecordBatchResult} containing the successfully written FlowFiles and any parse-failure FlowFiles + */ + private RecordBatchResult writeRecordBatchPerRecord(final ProcessSession session, final ProcessContext context, + final List records, + final String streamName, final BatchAccumulator batch) { - final ProcessingResult result = recordProcessor.processRecords(session, streamName, shardId, records); + final RecordReaderFactory readerFactory = context.getProperty(RECORD_READER).asControllerService(RecordReaderFactory.class); + final RecordSetWriterFactory writerFactory = context.getProperty(RECORD_WRITER).asControllerService(RecordSetWriterFactory.class); + final OutputStrategy outputStrategy = context.getProperty(OUTPUT_STRATEGY).asAllowableValue(OutputStrategy.class); - session.transfer(result.successFlowFiles(), REL_SUCCESS); - session.transfer(result.parseFailureFlowFiles(), REL_PARSE_FAILURE); - } + final List output = new ArrayList<>(); + final List parseFailureOutput = new ArrayList<>(); + final List unparseable = new ArrayList<>(); + FlowFile currentFlowFile = null; + OutputStream currentOut = null; + RecordSetWriter currentWriter = null; + RecordSchema currentReadSchema = null; + RecordSchema currentWriteSchema = null; - private void processRecordsAsDemarcated(final ProcessSession session, final String shardId, final List records) { - final byte[] demarcator = demarcatorValue; - if (demarcator == null) { - throw new IllegalStateException("Demarcator has not been initialized"); - } + try { + for (final UserRecord record : records) { - FlowFile flowFile = 
session.create(); + if (record.data().length == 0) { + unparseable.add(new ParseFailureRecord(record, "Record content is empty")); + continue; + } + + RecordSchema readSchema = null; + final List parsedRecords = new ArrayList<>(); + RecordReader reader = null; + try { + reader = readerFactory.createRecordReader(Map.of(), new ByteArrayInputStream(record.data()), record.data().length, getLogger()); + readSchema = reader.getSchema(); + org.apache.nifi.serialization.record.Record nifiRecord; + while ((nifiRecord = reader.nextRecord()) != null) { + parsedRecords.add(nifiRecord); + } + } catch (final MalformedRecordException | SchemaNotFoundException | IOException e) { + getLogger().debug("Kinesis record seq {} classified as unparseable: {}", record.sequenceNumber(), e.getMessage()); + unparseable.add(new ParseFailureRecord(record, e.toString())); + continue; + } finally { + closeQuietly(reader); + } + + if (parsedRecords.isEmpty()) { + unparseable.add(new ParseFailureRecord(record, "Record content produced no parsed records")); + continue; + } - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, records.getFirst(), records.getLast()); - attributes.put(RECORD_COUNT, String.valueOf(records.size())); - flowFile = session.putAllAttributes(flowFile, attributes); + if (currentWriter == null || !readSchema.equals(currentReadSchema)) { + if (currentWriter != null) { + final org.apache.nifi.serialization.WriteResult writeResult = currentWriter.finishRecordSet(); - flowFile = session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - boolean writtenData = false; - for (final KinesisClientRecord record : records) { - if (writtenData) { - out.write(demarcator); + currentWriter.close(); + currentOut.close(); + + final Map attributes = createFlowFileAttributes(streamName, batch); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + 
attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), currentWriter.getMimeType()); + currentFlowFile = session.putAllAttributes(currentFlowFile, attributes); + + session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(currentFlowFile); + currentFlowFile = null; } - channel.write(record.data()); - writtenData = true; + + currentReadSchema = readSchema; + currentWriteSchema = buildWriteSchema(readSchema, outputStrategy); + currentFlowFile = session.create(); + currentOut = session.write(currentFlowFile); + currentWriter = writerFactory.createWriter(getLogger(), currentWriteSchema, currentOut, Map.of()); + currentWriter.beginRecordSet(); + batch.resetRanges(); + } + + batch.updateRecordRange(record); + + for (final org.apache.nifi.serialization.record.Record parsed : parsedRecords) { + final org.apache.nifi.serialization.record.Record decorated = + decorateRecord(parsed, record, record.shardId(), streamName, outputStrategy, currentWriteSchema); + currentWriter.write(decorated); + batch.incrementRecordCount(); } } - }); - session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); + if (currentWriter != null) { + final org.apache.nifi.serialization.WriteResult writeResult = currentWriter.finishRecordSet(); + currentWriter.close(); + currentOut.close(); + + final Map attributes = createFlowFileAttributes(streamName, batch); + attributes.put("record.count", String.valueOf(writeResult.getRecordCount())); + attributes.putAll(writeResult.getAttributes()); + attributes.put(CoreAttributes.MIME_TYPE.key(), currentWriter.getMimeType()); - session.transfer(flowFile, REL_SUCCESS); + currentFlowFile = session.putAllAttributes(currentFlowFile, attributes); + session.getProvenanceReporter().receive(currentFlowFile, buildTransitUri(streamName, batch.getLastShardId())); + output.add(currentFlowFile); + currentFlowFile = 
null; + } + } catch (final Exception e) { + closeQuietly(currentWriter); + closeQuietly(currentOut); + if (currentFlowFile != null) { + session.remove(currentFlowFile); + } + if (e instanceof RuntimeException re) { + throw re; + } + throw new ProcessException(e); + } + + if (!unparseable.isEmpty()) { + getLogger().warn("Encountered {} unparseable record(s) in shard {}; routing to parse failure", + unparseable.size(), batch.getLastShardId()); + writeParseFailures(session, unparseable, streamName, batch, parseFailureOutput); + } + + return new RecordBatchResult(output, parseFailureOutput); } /** - * An adapter between Kinesis Consumer Library and {@link RecordBuffer}. + * Adjusts a read schema to the write schema required by the configured OutputStrategy. For + * {@code INJECT_METADATA} the metadata field is appended; for {@code USE_WRAPPER} a two-field + * wrapper schema is created; for {@code USE_VALUE} the read schema is returned unchanged. */ - private static class ConsumeKinesisRecordProcessor implements ShardRecordProcessor { + private static RecordSchema buildWriteSchema(final RecordSchema readSchema, final OutputStrategy outputStrategy) { + return switch (outputStrategy) { + case INJECT_METADATA -> { + final List fields = new ArrayList<>(readSchema.getFields()); + fields.add(KinesisRecordMetadata.FIELD_METADATA); + yield new SimpleRecordSchema(fields); + } + case USE_WRAPPER -> { + yield new SimpleRecordSchema(List.of( + KinesisRecordMetadata.FIELD_METADATA, + new RecordField(WRAPPER_VALUE_FIELD, RecordFieldType.RECORD.getRecordDataType(readSchema)))); + } + case USE_VALUE -> readSchema; + }; + } - private final RecordBuffer.ForKinesisClientLibrary recordBuffer; - private volatile @Nullable ShardBufferId bufferId; + /** + * Attaches Kinesis metadata to a NiFi record according to the configured OutputStrategy. 
+ */ + private static org.apache.nifi.serialization.record.Record decorateRecord( + final org.apache.nifi.serialization.record.Record nifiRecord, + final UserRecord kinesisRecord, final String shardId, + final String streamName, final OutputStrategy outputStrategy, + final RecordSchema writeSchema) { + return switch (outputStrategy) { + case INJECT_METADATA -> { + final Map values = new HashMap<>(nifiRecord.toMap()); + values.put(KinesisRecordMetadata.METADATA, + KinesisRecordMetadata.composeMetadataObject(kinesisRecord, streamName, shardId)); + yield new MapRecord(writeSchema, values); + } + case USE_WRAPPER -> { + final Map wrapperValues = new HashMap<>(2, 1.0f); + wrapperValues.put(KinesisRecordMetadata.METADATA, + KinesisRecordMetadata.composeMetadataObject(kinesisRecord, streamName, shardId)); + wrapperValues.put(WRAPPER_VALUE_FIELD, nifiRecord); + yield new MapRecord(writeSchema, wrapperValues); + } + case USE_VALUE -> nifiRecord; + }; + } + + private void writeParseFailures(final ProcessSession session, final List unparseable, + final String streamName, final BatchAccumulator batch, final List parseFailureOutput) { + + for (final ParseFailureRecord parseFailureRecord : unparseable) { + final UserRecord record = parseFailureRecord.record(); + FlowFile flowFile = session.create(); + try { + final byte[] rawBytes = record.data(); + flowFile = session.write(flowFile, out -> out.write(rawBytes)); + + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put(ATTR_FIRST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_LAST_SEQUENCE, record.sequenceNumber()); + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(record.subSequenceNumber())); + attributes.put(ATTR_PARTITION_KEY, record.partitionKey()); + if (record.approximateArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, 
String.valueOf(record.approximateArrivalTimestamp().toEpochMilli())); + } + attributes.put("record.count", "1"); + attributes.put(ATTR_RECORD_ERROR_MESSAGE, parseFailureRecord.reason()); + if (batch.getLastShardId() != null) { + attributes.put(ATTR_SHARD_ID, batch.getLastShardId()); + } + flowFile = session.putAllAttributes(flowFile, attributes); + parseFailureOutput.add(flowFile); + } catch (final Exception e) { + session.remove(flowFile); + throw e; + } + } + } - ConsumeKinesisRecordProcessor(final MemoryBoundRecordBuffer recordBuffer) { - this.recordBuffer = recordBuffer; + private static long estimateResultBytes(final ShardFetchResult result) { + long bytes = 0; + for (final UserRecord record : result.records()) { + bytes += record.data().length; } + return bytes; + } - @Override - public void initialize(final InitializationInput initializationInput) { - bufferId = recordBuffer.createBuffer(initializationInput.shardId()); + private static Map createFlowFileAttributes(final String streamName, final BatchAccumulator batch) { + final Map attributes = new HashMap<>(); + attributes.put(ATTR_STREAM_NAME, streamName); + attributes.put("record.count", String.valueOf(batch.getRecordCount())); + + if (batch.getMaxMillisBehind() >= 0) { + attributes.put(ATTR_MILLIS_BEHIND, String.valueOf(batch.getMaxMillisBehind())); + } + if (batch.getLastShardId() != null) { + attributes.put(ATTR_SHARD_ID, batch.getLastShardId()); + } + if (batch.getMinSequenceNumber() != null) { + attributes.put(ATTR_FIRST_SEQUENCE, batch.getMinSequenceNumber()); + } + if (batch.getMaxSequenceNumber() != null) { + attributes.put(ATTR_LAST_SEQUENCE, batch.getMaxSequenceNumber()); + } + if (batch.getMinSubSequenceNumber() != Long.MAX_VALUE) { + attributes.put(ATTR_FIRST_SUBSEQUENCE, String.valueOf(batch.getMinSubSequenceNumber())); + } + if (batch.getMaxSubSequenceNumber() != Long.MIN_VALUE) { + attributes.put(ATTR_LAST_SUBSEQUENCE, String.valueOf(batch.getMaxSubSequenceNumber())); + } + if 
(batch.getLastPartitionKey() != null) { + attributes.put(ATTR_PARTITION_KEY, batch.getLastPartitionKey()); + } + if (batch.getLatestArrivalTimestamp() != null) { + attributes.put(ATTR_ARRIVAL_TIMESTAMP, String.valueOf(batch.getLatestArrivalTimestamp().toEpochMilli())); + } + + return attributes; + } + + private void closeQuietly(final AutoCloseable closeable) { + if (closeable != null) { + try { + closeable.close(); + } catch (final Exception e) { + getLogger().warn("Failed to close resource", e); + } + } + } + + private static String buildTransitUri(final String streamName, final String shardId) { + return "kinesis://" + streamName + "/" + shardId; + } + + // Exposed for testing to allow injection of mock Shard Manager + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return new KinesisShardManager(kinesisClient, dynamoDbClient, logger, checkpointTableName, streamName); + } + + // Exposed for testing to allow injection of a mock client + protected KinesisConsumerClient createConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger, final boolean efoMode) { + if (efoMode) { + return new EnhancedFanOutClient(kinesisClient, logger); + } + return new PollingKinesisClient(kinesisClient, logger); + } + + private record RecordBatchResult(List output, List parseFailures) { + } + + private record ParseFailureRecord(UserRecord record, String reason) { + } + + private static final class KinesisRecordInputStream extends InputStream { + private final List chunks; + private int chunkIndex; + private int positionInChunk; + private int markChunkIndex = -1; + private int markPositionInChunk; + + KinesisRecordInputStream(final List records) { + this.chunks = new ArrayList<>(records.size()); + for (final UserRecord record : records) { + final byte[] data = record.data(); + if (data.length > 0) { + 
chunks.add(data); + } + } } @Override - public void processRecords(final ProcessRecordsInput processRecordsInput) { - if (bufferId == null) { - throw new IllegalStateException("Buffer ID not found: Record Processor not initialized"); + public int read() { + while (chunkIndex < chunks.size()) { + final byte[] current = chunks.get(chunkIndex); + if (positionInChunk < current.length) { + return current[positionInChunk++] & 0xFF; + } + chunkIndex++; + positionInChunk = 0; } - recordBuffer.addRecords(bufferId, processRecordsInput.records(), processRecordsInput.checkpointer()); + return -1; } @Override - public void leaseLost(final LeaseLostInput leaseLostInput) { - if (bufferId != null) { - recordBuffer.consumerLeaseLost(bufferId); + public int read(final byte[] buffer, final int offset, final int length) { + if (chunkIndex >= chunks.size()) { + return -1; + } + if (length == 0) { + return 0; } + + int totalRead = 0; + while (totalRead < length && chunkIndex < chunks.size()) { + final byte[] current = chunks.get(chunkIndex); + final int remaining = current.length - positionInChunk; + if (remaining <= 0) { + chunkIndex++; + positionInChunk = 0; + continue; + } + + final int toRead = Math.min(length - totalRead, remaining); + System.arraycopy(current, positionInChunk, buffer, offset + totalRead, toRead); + positionInChunk += toRead; + totalRead += toRead; + } + + return totalRead == 0 ? 
-1 : totalRead; } @Override - public void shardEnded(final ShardEndedInput shardEndedInput) { - if (bufferId != null) { - recordBuffer.checkpointEndedShard(bufferId, shardEndedInput.checkpointer()); + public int available() { + if (chunkIndex >= chunks.size()) { + return 0; } + return chunks.get(chunkIndex).length - positionInChunk; } @Override - public void shutdownRequested(final ShutdownRequestedInput shutdownRequestedInput) { - if (bufferId != null) { - recordBuffer.shutdownShardConsumption(bufferId, shutdownRequestedInput.checkpointer()); + public boolean markSupported() { + return true; + } + + @Override + public void mark(final int readLimit) { + markChunkIndex = chunkIndex; + markPositionInChunk = positionInChunk; + } + + @Override + public void reset() throws IOException { + if (markChunkIndex < 0) { + throw new IOException("Stream not marked"); } + chunkIndex = markChunkIndex; + positionInChunk = markPositionInChunk; } } - private static final class InitializationStateChangeListener implements WorkerStateChangeListener { + private record PartitionedBatch(Map> resultsByShard, Map checkpoints) { + } - private final ComponentLog logger; + private record WriteResult(List produced, List parseFailures, + long totalRecordCount, long totalBytesConsumed, long maxMillisBehind) { + } - private final CompletableFuture resultFuture = new CompletableFuture<>(); + private static final class BatchAccumulator { + private long bytesConsumed; + private long recordCount; + private long maxMillisBehind = -1; + private BigInteger minSequenceNumber; + private BigInteger maxSequenceNumber; + private long minSubSequenceNumber = Long.MAX_VALUE; + private long maxSubSequenceNumber = Long.MIN_VALUE; + private String lastPartitionKey; + private Instant latestArrivalTimestamp; + private String lastShardId; + + long getBytesConsumed() { + return bytesConsumed; + } - private volatile @Nullable Throwable initializationFailure; + long getRecordCount() { + return recordCount; + } - 
InitializationStateChangeListener(final ComponentLog logger) { - this.logger = logger; + long getMaxMillisBehind() { + return maxMillisBehind; } - @Override - public void onWorkerStateChange(final WorkerState newState) { - logger.info("Worker state changed to [{}]", newState); + String getMinSequenceNumber() { + return minSequenceNumber == null ? null : minSequenceNumber.toString(); + } - if (newState == WorkerState.STARTED) { - resultFuture.complete(new InitializationResult.Success()); - } else if (newState == WorkerState.SHUT_DOWN) { - resultFuture.complete(new InitializationResult.Failure(Optional.ofNullable(initializationFailure))); - } + String getMaxSequenceNumber() { + return maxSequenceNumber == null ? null : maxSequenceNumber.toString(); } - @Override - public void onAllInitializationAttemptsFailed(final Throwable e) { - // This method is called before the SHUT_DOWN_STARTED phase. - // Memorizing the error until the Scheduler is SHUT_DOWN. - initializationFailure = e; + long getMinSubSequenceNumber() { + return minSubSequenceNumber; } - Future result() { - return resultFuture; + long getMaxSubSequenceNumber() { + return maxSubSequenceNumber; + } + + String getLastPartitionKey() { + return lastPartitionKey; + } + + Instant getLatestArrivalTimestamp() { + return latestArrivalTimestamp; + } + + String getLastShardId() { + return lastShardId; + } + + void setLastShardId(final String shardId) { + lastShardId = shardId; + } + + void addBytes(final long bytes) { + bytesConsumed += bytes; } - } - private sealed interface InitializationResult { - record Success() implements InitializationResult { + void incrementRecordCount() { + recordCount++; } - record Failure(Optional error) implements InitializationResult { + void resetRecordCount() { + recordCount = 0; + } + + void updateMillisBehind(final long millisBehindLatest) { + maxMillisBehind = Math.max(maxMillisBehind, millisBehindLatest); + } + + void updateSequenceRange(final ShardFetchResult result) { + final 
BigInteger firstSeq = result.firstSequenceNumber(); + final BigInteger lastSeq = result.lastSequenceNumber(); + if (minSequenceNumber == null || firstSeq.compareTo(minSequenceNumber) < 0) { + minSequenceNumber = firstSeq; + } + if (maxSequenceNumber == null || lastSeq.compareTo(maxSequenceNumber) > 0) { + maxSequenceNumber = lastSeq; + } + } + + void updateRecordRange(final UserRecord record) { + updateSequenceFromRecord(record); + final long subSeq = record.subSequenceNumber(); + if (subSeq < minSubSequenceNumber) { + minSubSequenceNumber = subSeq; + } + if (subSeq > maxSubSequenceNumber) { + maxSubSequenceNumber = subSeq; + } + lastPartitionKey = record.partitionKey(); + final Instant arrival = record.approximateArrivalTimestamp(); + if (arrival != null && (latestArrivalTimestamp == null || arrival.isAfter(latestArrivalTimestamp))) { + latestArrivalTimestamp = arrival; + } + } + + void updateSequenceFromRecord(final UserRecord record) { + final BigInteger seqNum = new BigInteger(record.sequenceNumber()); + if (minSequenceNumber == null || seqNum.compareTo(minSequenceNumber) < 0) { + minSequenceNumber = seqNum; + } + if (maxSequenceNumber == null || seqNum.compareTo(maxSequenceNumber) > 0) { + maxSequenceNumber = seqNum; + } + } + + void resetRanges() { + minSequenceNumber = null; + maxSequenceNumber = null; + minSubSequenceNumber = Long.MAX_VALUE; + maxSubSequenceNumber = Long.MIN_VALUE; + lastPartitionKey = null; + latestArrivalTimestamp = null; } } @@ -919,6 +1561,7 @@ public String getDescription() { enum ProcessingStrategy implements DescribedValue { FLOW_FILE("Write one FlowFile for each consumed Kinesis Record"), + LINE_DELIMITED("Write one FlowFile containing multiple consumed Kinesis Records separated by line delimiters"), RECORD("Write one FlowFile containing multiple consumed Kinesis Records processed with Record Reader and Record Writer"), DEMARCATOR("Write one FlowFile containing multiple consumed Kinesis Records separated by a configurable 
demarcator"); @@ -943,15 +1586,11 @@ public String getDescription() { return description; } - private static ProcessingStrategy from(final String name) { - // As long as getValue() returns name(), using valueOf is fine. - return ProcessingStrategy.valueOf(name); - } } enum InitialPosition implements DescribedValue { - TRIM_HORIZON("Trim Horizon", "Start reading at the last untrimmed record in the shard in the system, which is the oldest data record in the shard."), - LATEST("Latest", "Start reading just after the most recent record in the shard, so that you always read the most recent data in the shard."), + TRIM_HORIZON("Trim Horizon", "Start reading at the last untrimmed record in the shard."), + LATEST("Latest", "Start reading just after the most recent record in the shard."), AT_TIMESTAMP("At Timestamp", "Start reading at the record with the specified timestamp."); private final String displayName; @@ -978,39 +1617,10 @@ public String getDescription() { } } - enum MetricsPublishing implements DescribedValue { - DISABLED("Disabled", "No metrics are published"), - LOGS("Logs", "Metrics are published to application logs"), - CLOUDWATCH("CloudWatch", "Metrics are published to Amazon CloudWatch"); - - private final String displayName; - private final String description; - - MetricsPublishing(final String displayName, final String description) { - this.displayName = displayName; - this.description = description; - } - - @Override - public String getValue() { - return name(); - } - - @Override - public String getDisplayName() { - return displayName; - } - - @Override - public String getDescription() { - return description; - } - } - enum OutputStrategy implements DescribedValue { USE_VALUE("Use Content as Value", "Write only the Kinesis Record value to the FlowFile record."), USE_WRAPPER("Use Wrapper", "Write the Kinesis Record value and metadata into the FlowFile record."), - INJECT_METADATA("Inject Metadata", "Write the Kinesis Record value to the FlowFile record 
and add a sub-record to it with metadata."); + INJECT_METADATA("Inject Metadata", "Write the Kinesis Record value to the FlowFile record and add a sub-record with metadata."); private final String displayName; private final String description; @@ -1035,19 +1645,4 @@ public String getDescription() { return description; } } - - // Visible for tests only. - @Nullable URI getKinesisEndpointOverride() { - return null; - } - - // Visible for tests only. - @Nullable URI getDynamoDbEndpointOverride() { - return null; - } - - // Visible for tests only. - @Nullable URI getCloudwatchEndpointOverride() { - return null; - } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java deleted file mode 100644 index acfc91b74e29..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisAttributes.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.HashMap; -import java.util.Map; - -final class ConsumeKinesisAttributes { - - private static final String PREFIX = "aws.kinesis."; - - // AWS Kinesis attributes. - static final String STREAM_NAME = PREFIX + "stream.name"; - static final String SHARD_ID = PREFIX + "shard.id"; - static final String FIRST_SEQUENCE_NUMBER = PREFIX + "first.sequence.number"; - static final String FIRST_SUB_SEQUENCE_NUMBER = PREFIX + "first.subsequence.number"; - static final String LAST_SEQUENCE_NUMBER = PREFIX + "last.sequence.number"; - static final String LAST_SUB_SEQUENCE_NUMBER = PREFIX + "last.subsequence.number"; - - static final String PARTITION_KEY = PREFIX + "partition.key"; - static final String APPROXIMATE_ARRIVAL_TIMESTAMP = PREFIX + "approximate.arrival.timestamp.ms"; - - // Record attributes. - static final String MIME_TYPE = "mime.type"; - static final String RECORD_COUNT = "record.count"; - static final String RECORD_ERROR_MESSAGE = "record.error.message"; - - /** - * Creates a map of FlowFile attributes from the provided Kinesis records. - * - * @param streamName the name of the Kinesis stream the FileFile records came from. - * @param shardId the shard ID the FlowFile records came from. - * @param firstRecord the first Kinesis record in the FlowFile. - * @param lastRecord the last Kinesis record in the FlowFile. - * @return a mutable map with kinesis attributes. 
- */ - static Map fromKinesisRecords( - final String streamName, - final String shardId, - final KinesisClientRecord firstRecord, - final KinesisClientRecord lastRecord) { - final Map attributes = new HashMap<>(8); - - attributes.put(STREAM_NAME, streamName); - attributes.put(SHARD_ID, shardId); - - attributes.put(FIRST_SEQUENCE_NUMBER, firstRecord.sequenceNumber()); - attributes.put(FIRST_SUB_SEQUENCE_NUMBER, String.valueOf(firstRecord.subSequenceNumber())); - - attributes.put(LAST_SEQUENCE_NUMBER, lastRecord.sequenceNumber()); - attributes.put(LAST_SUB_SEQUENCE_NUMBER, String.valueOf(lastRecord.subSequenceNumber())); - - attributes.put(PARTITION_KEY, lastRecord.partitionKey()); - - if (lastRecord.approximateArrivalTimestamp() != null) { - attributes.put(APPROXIMATE_ARRIVAL_TIMESTAMP, String.valueOf(lastRecord.approximateArrivalTimestamp().toEpochMilli())); - } - - return attributes; - } - - private ConsumeKinesisAttributes() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java new file mode 100644 index 000000000000..70047be3ea8b --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/EnhancedFanOutClient.java @@ -0,0 +1,602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ConsumerStatus; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.LimitExceededException; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.RegisterStreamConsumerRequest; +import software.amazon.awssdk.services.kinesis.model.RegisterStreamConsumerResponse; +import software.amazon.awssdk.services.kinesis.model.ResourceInUseException; +import software.amazon.awssdk.services.kinesis.model.ResourceNotFoundException; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEvent; 
+import software.amazon.awssdk.services.kinesis.model.SubscribeToShardEventStream; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; + +import java.io.IOException; +import java.math.BigInteger; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +/** + * Enhanced Fan-Out Kinesis consumer that uses SubscribeToShard with dedicated throughput + * per shard via HTTP/2. Uses Reactive Streams demand-driven backpressure to control the + * rate of event delivery. 
+ */ +final class EnhancedFanOutClient extends KinesisConsumerClient { + + private static final long SUBSCRIBE_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(5); + private static final long CONSUMER_REGISTRATION_POLL_MILLIS = 1_000; + private static final int CONSUMER_REGISTRATION_MAX_ATTEMPTS = 60; + private static final int MAX_QUEUED_RESULTS = 200; + + private final Map shardConsumers = new ConcurrentHashMap<>(); + private final Queue pausedConsumers = new ConcurrentLinkedQueue<>(); + private volatile KinesisAsyncClient kinesisAsyncClient; + private volatile String consumerArn; + private volatile String streamName; + + EnhancedFanOutClient(final KinesisClient kinesisClient, final ComponentLog logger) { + super(kinesisClient, logger); + } + + @Override + void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { + this.kinesisAsyncClient = asyncClient; + this.streamName = streamName; + registerEfoConsumer(streamName, consumerName); + } + + @Override + void startFetches(final List shards, final String streamName, final int batchSize, final String initialStreamPosition, final KinesisShardManager shardManager) { + final long now = System.nanoTime(); + + for (final Shard shard : shards) { + final String shardId = shard.shardId(); + final ShardConsumer existing = shardConsumers.get(shardId); + + if (existing == null) { + final String checkpoint = shardManager.readCheckpoint(shardId); + final BigInteger lastSeq = checkpoint == null ? 
null : new BigInteger(checkpoint); + final StartingPosition startingPosition = buildStartingPosition(lastSeq, initialStreamPosition); + logger.info("Creating Enhanced Fan-Out subscription for stream [{}] shard [{}] type [{}] seq [{}]", streamName, shardId, startingPosition.type(), lastSeq); + final ShardConsumer shardConsumer = new ShardConsumer(shardId, result -> enqueueIfActiveConsumer(shardId, result), pausedConsumers, logger); + final ShardConsumer prior = shardConsumers.putIfAbsent(shardId, shardConsumer); + if (prior == null) { + try { + shardConsumer.subscribe(kinesisAsyncClient, consumerArn, startingPosition); + } catch (final Exception e) { + shardConsumers.remove(shardId, shardConsumer); + throw e; + } + } + } else if (existing.isSubscriptionExpired()) { + final long lastAttempt = existing.getLastSubscribeAttemptNanos(); + if (lastAttempt > 0 && now < lastAttempt + SUBSCRIBE_BACKOFF_NANOS) { + continue; + } + + final String checkpoint = shardManager.readCheckpoint(shardId); + final BigInteger checkpointSeq = checkpoint == null ? 
null : new BigInteger(checkpoint); + final BigInteger lastQueued = existing.getLastQueuedSequenceNumber(); + final BigInteger resumeSeq = maxSequenceNumber(lastQueued, checkpointSeq); + final StartingPosition startingPosition = buildStartingPosition(resumeSeq, initialStreamPosition); + logger.debug("Renewing expired Enhanced Fan-Out subscription for stream [{}] shard [{}] type [{}] seq [{}]", streamName, shardId, startingPosition.type(), resumeSeq); + existing.subscribe(kinesisAsyncClient, consumerArn, startingPosition); + } + } + + resumePausedConsumers(); + } + + @Override + boolean hasPendingFetches() { + return !shardConsumers.isEmpty(); + } + + @Override + long drainDeduplicatedEventCount() { + long total = 0; + for (final ShardConsumer sc : shardConsumers.values()) { + total += sc.drainDeduplicatedEventCount(); + } + return total; + } + + @Override + void acknowledgeResults(final List results) { + resumePausedConsumers(); + } + + private void resumePausedConsumers() { + int available = MAX_QUEUED_RESULTS - totalQueuedResults(); + int resumed = 0; + ShardConsumer consumer; + while (resumed < available && (consumer = pausedConsumers.poll()) != null) { + if (consumer.requestNextIfReady()) { + resumed++; + } + } + } + + private void enqueueIfActiveConsumer(final String shardId, final ShardFetchResult result) { + synchronized (getShardLock(shardId)) { + if (shardConsumers.containsKey(shardId)) { + enqueueResult(result); + } + } + } + + private ShardConsumer drainAndRemoveConsumer(final String shardId) { + synchronized (getShardLock(shardId)) { + drainShardQueue(shardId); + return shardConsumers.remove(shardId); + } + } + + @Override + void rollbackResults(final List results) { + for (final ShardFetchResult result : results) { + final ShardConsumer shardConsumer = drainAndRemoveConsumer(result.shardId()); + if (shardConsumer != null) { + shardConsumer.cancel(); + } + } + } + + @Override + void removeUnownedShards(final Set ownedShards) { + 
shardConsumers.entrySet().removeIf(entry -> { + if (!ownedShards.contains(entry.getKey())) { + entry.getValue().cancel(); + return true; + } + return false; + }); + } + + @Override + void logDiagnostics(final int ownedCount, final int cachedShardCount) { + if (!shouldLogDiagnostics()) { + return; + } + + final long now = System.nanoTime(); + int activeSubscriptions = 0; + int expiredSubscriptions = 0; + int backedOff = 0; + for (final ShardConsumer shardConsumer : shardConsumers.values()) { + if (shardConsumer.isSubscriptionExpired()) { + expiredSubscriptions++; + final long lastAttempt = shardConsumer.getLastSubscribeAttemptNanos(); + if (lastAttempt > 0 && now < lastAttempt + SUBSCRIBE_BACKOFF_NANOS) { + backedOff++; + } + } else { + activeSubscriptions++; + } + } + + final int queueDepth = totalQueuedResults(); + logger.debug("Kinesis Enhanced Fan-Out diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}/{}, shardConsumers={}, active={}, expired={}, backedOff={}", + cachedShardCount, ownedCount, queueDepth, MAX_QUEUED_RESULTS, shardConsumers.size(), activeSubscriptions, expiredSubscriptions, backedOff); + } + + @Override + void close() { + for (final ShardConsumer sc : shardConsumers.values()) { + sc.cancel(); + } + shardConsumers.clear(); + + if (kinesisAsyncClient != null) { + kinesisAsyncClient.close(); + kinesisAsyncClient = null; + } + + super.close(); + } + + void initializeForTest(final KinesisAsyncClient asyncClient, final String consumerArn) { + this.kinesisAsyncClient = asyncClient; + this.consumerArn = consumerArn; + } + + ShardConsumer getShardConsumer(final String shardId) { + return shardConsumers.get(shardId); + } + + String getConsumerArn() { + return consumerArn; + } + + private void registerEfoConsumer(final String streamName, final String consumerName) { + final DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder().streamName(streamName).build(); + final DescribeStreamResponse describeResponse = 
kinesisClient.describeStream(describeStreamRequest); + final String arn = describeResponse.streamDescription().streamARN(); + + try { + final DescribeStreamConsumerRequest describeConsumerReq = DescribeStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerStatus(); + if (status == ConsumerStatus.ACTIVE) { + consumerArn = kinesisClient.describeStreamConsumer(describeConsumerReq).consumerDescription().consumerARN(); + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] already registered and ACTIVE", consumerName, streamName); + return; + } + } catch (final ResourceNotFoundException ignored) { + } + + try { + final RegisterStreamConsumerRequest registerRequest = RegisterStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + final RegisterStreamConsumerResponse registerResponse = kinesisClient.registerStreamConsumer(registerRequest); + consumerArn = registerResponse.consumer().consumerARN(); + logger.info("Registered Enhanced Fan-Out consumer [{}] for stream [{}], waiting for ACTIVE status", consumerName, streamName); + } catch (final ResourceInUseException e) { + final DescribeStreamConsumerRequest fallbackRequest = DescribeStreamConsumerRequest.builder() + .streamARN(arn) + .consumerName(consumerName) + .build(); + consumerArn = kinesisClient.describeStreamConsumer(fallbackRequest).consumerDescription().consumerARN(); + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] already being registered", consumerName, streamName); + } + + waitForConsumerActive(arn, consumerName); + } + + private void waitForConsumerActive(final String streamArn, final String consumerName) { + final DescribeStreamConsumerRequest describeConsumerRequest = DescribeStreamConsumerRequest.builder() + .streamARN(streamArn) + .consumerName(consumerName) + .build(); + + for (int i = 0; i < 
CONSUMER_REGISTRATION_MAX_ATTEMPTS; i++) { + final ConsumerStatus status = kinesisClient.describeStreamConsumer(describeConsumerRequest).consumerDescription().consumerStatus(); + if (status == ConsumerStatus.ACTIVE) { + logger.info("Enhanced Fan-Out consumer [{}] for stream [{}] is now ACTIVE", consumerName, streamName); + return; + } + + try { + Thread.sleep(CONSUMER_REGISTRATION_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for Enhanced Fan-Out consumer [%s] registration for stream [%s]".formatted(consumerName, streamName), e); + } + } + + throw new ProcessException("Enhanced Fan-Out consumer [%s] for stream [%s] did not become ACTIVE within %d seconds".formatted(consumerName, streamName, CONSUMER_REGISTRATION_MAX_ATTEMPTS)); + } + + private static BigInteger maxSequenceNumber(final BigInteger a, final BigInteger b) { + if (a == null) { + return b; + } + if (b == null) { + return a; + } + return a.compareTo(b) >= 0 ? 
a : b; + } + + private StartingPosition buildStartingPosition(final BigInteger sequenceNumber, final String initialStreamPosition) { + if (sequenceNumber != null) { + return StartingPosition.builder() + .type(ShardIteratorType.AFTER_SEQUENCE_NUMBER) + .sequenceNumber(sequenceNumber.toString()) + .build(); + } + final ShardIteratorType iteratorType = ShardIteratorType.fromValue(initialStreamPosition); + final StartingPosition.Builder builder = StartingPosition.builder().type(iteratorType); + if (iteratorType == ShardIteratorType.AT_TIMESTAMP && getTimestampForInitialPosition() != null) { + builder.timestamp(getTimestampForInitialPosition()); + } + return builder.build(); + } + + static final class ShardConsumer { + private final String shardId; + private final Consumer resultSink; + private final Queue pausedConsumers; + private final ComponentLog consumerLogger; + private final AtomicBoolean subscribing = new AtomicBoolean(false); + private final AtomicBoolean paused = new AtomicBoolean(false); + private final AtomicInteger subscriptionGeneration = new AtomicInteger(); + private final AtomicLong deduplicatedEvents = new AtomicLong(); + private volatile Subscription subscription; + private volatile CompletableFuture subscriptionFuture; + private volatile long lastSubscribeAttemptNanos; + private volatile BigInteger lastQueuedSequenceNumber; + ShardConsumer(final String shardId, final Consumer resultSink, + final Queue pausedConsumers, final ComponentLog consumerLogger) { + this.shardId = shardId; + this.resultSink = resultSink; + this.pausedConsumers = pausedConsumers; + this.consumerLogger = consumerLogger; + } + + void subscribe(final KinesisAsyncClient asyncClient, final String consumerArn, final StartingPosition startingPosition) { + if (!subscribing.compareAndSet(false, true)) { + return; + } + + final int generation = subscriptionGeneration.incrementAndGet(); + + try { + final SubscribeToShardRequest request = SubscribeToShardRequest.builder() + 
.consumerARN(consumerArn) + .shardId(shardId) + .startingPosition(startingPosition) + .build(); + + final SubscribeToShardResponseHandler handler = SubscribeToShardResponseHandler.builder() + .subscriber(() -> new DemandDrivenSubscriber(generation)) + .onError(t -> { + logSubscriptionError(t); + endSubscriptionIfCurrent(generation); + }) + .build(); + + lastSubscribeAttemptNanos = System.nanoTime(); + subscriptionFuture = asyncClient.subscribeToShard(request, handler); + } catch (final Exception e) { + subscribing.set(false); + throw e; + } + } + + void requestNext() { + final Subscription sub = subscription; + if (sub != null) { + sub.request(1); + } + } + + boolean requestNextIfReady() { + if (paused.compareAndSet(true, false)) { + final Subscription sub = subscription; + if (sub != null) { + sub.request(1); + return true; + } + } + return false; + } + + boolean isSubscriptionExpired() { + final CompletableFuture future = subscriptionFuture; + return future == null || future.isDone(); + } + + long getLastSubscribeAttemptNanos() { + return lastSubscribeAttemptNanos; + } + + BigInteger getLastQueuedSequenceNumber() { + return lastQueuedSequenceNumber; + } + + long drainDeduplicatedEventCount() { + return deduplicatedEvents.getAndSet(0); + } + + void cancel() { + final CompletableFuture future = subscriptionFuture; + if (future != null) { + future.cancel(true); + } + subscription = null; + } + + void resetForRenewal() { + subscribing.set(false); + lastSubscribeAttemptNanos = 0; + } + + void setLastQueuedSequenceNumber(final BigInteger seq) { + lastQueuedSequenceNumber = seq; + } + + void setSubscription(final Subscription sub) { + subscription = sub; + } + + Subscription getSubscription() { + return subscription; + } + + int getSubscriptionGeneration() { + return subscriptionGeneration.get(); + } + + boolean isSubscribing() { + return subscribing.get(); + } + + void pause() { + if (paused.compareAndSet(false, true)) { + pausedConsumers.add(this); + } + } + + private 
void logSubscriptionError(final Throwable t) { + if (isCancellation(t)) { + consumerLogger.debug("Enhanced Fan-Out subscription cancelled for shard [{}]", shardId); + } else if (isRetryableSubscriptionError(t)) { + consumerLogger.warn("Enhanced Fan-Out subscription temporarily rejected for shard [{}]; will retry after backoff", shardId, t); + } else if (isRetryableStreamDisconnect(t)) { + consumerLogger.warn("Enhanced Fan-Out subscription disconnected for shard [{}]; will retry", shardId, t); + } else { + consumerLogger.error("Enhanced Fan-Out subscription error for shard [{}]", shardId, t); + } + } + + private static boolean isCancellation(final Throwable t) { + if (t instanceof CancellationException) { + return true; + } + if (t instanceof IOException && "Request cancelled".equals(t.getMessage())) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isCancellation(cause); + } + + private static boolean isRetryableSubscriptionError(final Throwable t) { + if (t instanceof LimitExceededException || t instanceof ResourceInUseException) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isRetryableSubscriptionError(cause); + } + + private static boolean isRetryableStreamDisconnect(final Throwable t) { + if (t instanceof IOException || t instanceof SdkClientException) { + return true; + } + if (t instanceof AwsServiceException ase && ase.statusCode() >= 500) { + return true; + } + final String className = t.getClass().getName(); + if (className.startsWith("io.netty.")) { + return true; + } + final Throwable cause = t.getCause(); + return cause != null && cause != t && isRetryableStreamDisconnect(cause); + } + + void endSubscriptionIfCurrent(final int generation) { + if (subscriptionGeneration.get() == generation) { + subscription = null; + subscribing.set(false); + } + } + + private List deduplicateRecords(final List records) { + final BigInteger threshold = 
lastQueuedSequenceNumber; + if (threshold == null) { + return records; + } + int firstNewIndex = records.size(); + for (int i = 0; i < records.size(); i++) { + if (new BigInteger(records.get(i).sequenceNumber()).compareTo(threshold) > 0) { + firstNewIndex = i; + break; + } + } + + if (firstNewIndex == 0) { + return records; + } + + final int kept = records.size() - firstNewIndex; + deduplicatedEvents.incrementAndGet(); + if (kept == 0) { + consumerLogger.debug("Skipped re-delivered Enhanced Fan-Out event for shard [{}] ({} records already seen)", shardId, records.size()); + } else { + consumerLogger.debug("Filtered {} duplicate record(s) from Enhanced Fan-Out event for shard [{}] (kept {})", firstNewIndex, shardId, kept); + } + + return records.subList(firstNewIndex, records.size()); + } + + private class DemandDrivenSubscriber implements Subscriber { + private final int generation; + + DemandDrivenSubscriber(final int generation) { + this.generation = generation; + } + + @Override + public void onSubscribe(final Subscription sub) { + subscription = sub; + paused.set(false); + consumerLogger.info("Enhanced Fan-Out subscription established for shard [{}] (HTTP/2 stream active)", shardId); + sub.request(1); + } + + @Override + public void onNext(final SubscribeToShardEventStream eventStream) { + if (!(eventStream instanceof SubscribeToShardEvent event)) { + requestNext(); + return; + } + + if (event.records().isEmpty()) { + requestNext(); + return; + } + + final long millisBehind = event.millisBehindLatest() != null ? 
event.millisBehindLatest() : -1; + final List records = deduplicateRecords(event.records()); + if (records.isEmpty()) { + requestNext(); + return; + } + + final ShardFetchResult result = createFetchResult(shardId, records, millisBehind); + lastQueuedSequenceNumber = result.lastSequenceNumber(); + resultSink.accept(result); + if (paused.compareAndSet(false, true)) { + pausedConsumers.add(ShardConsumer.this); + } + } + + @Override + public void onError(final Throwable t) { + logSubscriptionError(t); + endSubscriptionIfCurrent(generation); + } + + @Override + public void onComplete() { + consumerLogger.debug("Enhanced Fan-Out subscription completed normally for shard [{}]", shardId); + endSubscriptionIfCurrent(generation); + } + } + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java new file mode 100644 index 000000000000..ff1e8a0e4fa6 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClient.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base for Kinesis consumer clients. Provides per-shard result queues and + * shard-claim infrastructure used by both polling (shared-throughput) and EFO + * (enhanced-fan-out) implementations. + * + *

Results are stored in per-shard FIFO queues rather than a single shared queue. + * This guarantees that results for the same shard are always consumed in enqueue order, + * preventing out-of-order delivery when concurrent tasks cannot claim the same shard. + */ +abstract class KinesisConsumerClient { + + private static final long DIAGNOSTIC_INTERVAL_NANOS = TimeUnit.SECONDS.toNanos(30); + + protected final KinesisClient kinesisClient; + protected final ComponentLog logger; + private final Map> shardQueues = new ConcurrentHashMap<>(); + private final Map shardLocks = new ConcurrentHashMap<>(); + private final Semaphore resultNotification = new Semaphore(0); + protected final Set shardsInFlight = ConcurrentHashMap.newKeySet(); + + private volatile long lastDiagnosticLogNanos; + private volatile Instant timestampForInitialPosition; + + KinesisConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger) { + this.kinesisClient = kinesisClient; + this.logger = logger; + } + + void setTimestampForInitialPosition(final Instant timestamp) { + this.timestampForInitialPosition = timestamp; + } + + Instant getTimestampForInitialPosition() { + return timestampForInitialPosition; + } + + void initialize(final KinesisAsyncClient asyncClient, final String streamName, final String consumerName) { + } + + abstract void startFetches(List shards, String streamName, int batchSize, + String initialStreamPosition, KinesisShardManager shardManager); + + abstract boolean hasPendingFetches(); + + abstract void acknowledgeResults(List results); + + abstract void rollbackResults(List results); + + abstract void removeUnownedShards(Set ownedShards); + + abstract void logDiagnostics(int ownedCount, int cachedShardCount); + + Object getShardLock(final String shardId) { + return shardLocks.computeIfAbsent(shardId, k -> new Object()); + } + + void close() { + shardQueues.clear(); + shardLocks.clear(); + resultNotification.drainPermits(); + shardsInFlight.clear(); + } + + void 
enqueueResult(final ShardFetchResult result) { + shardQueues.computeIfAbsent(result.shardId(), k -> new ConcurrentLinkedQueue<>()).add(result); + resultNotification.release(); + } + + ShardFetchResult pollShardResult(final String shardId) { + final Queue queue = shardQueues.get(shardId); + final ShardFetchResult result = queue == null ? null : queue.poll(); + if (result != null) { + onResultPolled(); + } + return result; + } + + protected void onResultPolled() { + } + + int drainShardQueue(final String shardId) { + final Queue queue = shardQueues.get(shardId); + if (queue == null) { + return 0; + } + int drained = 0; + while (queue.poll() != null) { + drained++; + } + return drained; + } + + ShardFetchResult pollAnyResult(final long timeout, final TimeUnit unit) throws InterruptedException { + final long deadlineNanos = System.nanoTime() + unit.toNanos(timeout); + while (System.nanoTime() < deadlineNanos) { + for (final Queue queue : shardQueues.values()) { + final ShardFetchResult result = queue.poll(); + if (result != null) { + onResultPolled(); + return result; + } + } + final long remainingMs = TimeUnit.NANOSECONDS.toMillis(deadlineNanos - System.nanoTime()); + if (remainingMs <= 0) { + break; + } + resultNotification.drainPermits(); + resultNotification.tryAcquire(Math.min(remainingMs, 100), TimeUnit.MILLISECONDS); + } + return null; + } + + List getShardIdsWithResults() { + final List ids = new ArrayList<>(); + for (final Map.Entry> entry : shardQueues.entrySet()) { + if (!entry.getValue().isEmpty()) { + ids.add(entry.getKey()); + } + } + return ids; + } + + boolean awaitResults(final long timeout, final TimeUnit unit) throws InterruptedException { + resultNotification.drainPermits(); + return resultNotification.tryAcquire(timeout, unit); + } + + int totalQueuedResults() { + int total = 0; + for (final Queue queue : shardQueues.values()) { + total += queue.size(); + } + return total; + } + + boolean hasQueuedResults() { + for (final Queue queue : 
shardQueues.values()) { + if (!queue.isEmpty()) { + return true; + } + } + return false; + } + + boolean claimShard(final String shardId) { + return shardsInFlight.add(shardId); + } + + void releaseShards(final Collection shardIds) { + shardsInFlight.removeAll(shardIds); + } + + protected static ShardFetchResult createFetchResult(final String shardId, final List records, final long millisBehindLatest) { + return new ShardFetchResult(shardId, ProducerLibraryDeaggregator.deaggregate(shardId, records), millisBehindLatest); + } + + long drainDeduplicatedEventCount() { + return 0; + } + + protected boolean shouldLogDiagnostics() { + final long now = System.nanoTime(); + if (now < DIAGNOSTIC_INTERVAL_NANOS + lastDiagnosticLogNanos) { + return false; + } + lastDiagnosticLogNanos = now; + return true; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java similarity index 70% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java index 9fdc480727f1..bdbeb697ec5e 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordMetadata.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordMetadata.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.nifi.processors.aws.kinesis.converter; +package org.apache.nifi.processors.aws.kinesis; import org.apache.nifi.serialization.SimpleRecordSchema; import org.apache.nifi.serialization.record.MapRecord; @@ -22,16 +22,15 @@ import org.apache.nifi.serialization.record.RecordField; import org.apache.nifi.serialization.record.RecordFieldType; import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; import java.util.HashMap; import java.util.List; import java.util.Map; -final class KinesisRecordMetadata { +public final class KinesisRecordMetadata { - static final String METADATA = "kinesisMetadata"; - static final String APPROX_ARRIVAL_TIMESTAMP = "approximateArrival"; + public static final String METADATA = "kinesisMetadata"; + public static final String APPROX_ARRIVAL_TIMESTAMP = "approximateArrival"; private static final String STREAM = "stream"; private static final String SHARD_ID = "shardId"; @@ -49,28 +48,28 @@ final class KinesisRecordMetadata { private static final RecordField FIELD_APPROX_ARRIVAL_TIMESTAMP = new RecordField(APPROX_ARRIVAL_TIMESTAMP, RecordFieldType.TIMESTAMP.getDataType()); private static final RecordSchema SCHEMA_METADATA = new SimpleRecordSchema(List.of( - FIELD_STREAM, - FIELD_SHARD_ID, - FIELD_SEQUENCE_NUMBER, - FIELD_SUB_SEQUENCE_NUMBER, - FIELD_SHARDED_SEQUENCE_NUMBER, - FIELD_PARTITION_KEY, - FIELD_APPROX_ARRIVAL_TIMESTAMP)); + FIELD_STREAM, + FIELD_SHARD_ID, + FIELD_SEQUENCE_NUMBER, + FIELD_SUB_SEQUENCE_NUMBER, + FIELD_SHARDED_SEQUENCE_NUMBER, + FIELD_PARTITION_KEY, + FIELD_APPROX_ARRIVAL_TIMESTAMP)); - static final RecordField FIELD_METADATA = new RecordField(METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)); + public static final RecordField FIELD_METADATA = new RecordField(METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)); - static Record composeMetadataObject(final KinesisClientRecord kinesisRecord, final String 
streamName, final String shardId) { + public static Record composeMetadataObject(final UserRecord record, final String streamName, final String shardId) { final Map metadata = new HashMap<>(7, 1.0f); metadata.put(STREAM, streamName); metadata.put(SHARD_ID, shardId); - metadata.put(SEQUENCE_NUMBER, kinesisRecord.sequenceNumber()); - metadata.put(SUB_SEQUENCE_NUMBER, kinesisRecord.subSequenceNumber()); - metadata.put(SHARDED_SEQUENCE_NUMBER, "%s%020d".formatted(kinesisRecord.sequenceNumber(), kinesisRecord.subSequenceNumber())); - metadata.put(PARTITION_KEY, kinesisRecord.partitionKey()); + metadata.put(SEQUENCE_NUMBER, record.sequenceNumber()); + metadata.put(SUB_SEQUENCE_NUMBER, record.subSequenceNumber()); + metadata.put(SHARDED_SEQUENCE_NUMBER, "%s%020d".formatted(record.sequenceNumber(), record.subSequenceNumber())); + metadata.put(PARTITION_KEY, record.partitionKey()); - if (kinesisRecord.approximateArrivalTimestamp() != null) { - metadata.put(APPROX_ARRIVAL_TIMESTAMP, kinesisRecord.approximateArrivalTimestamp().toEpochMilli()); + if (record.approximateArrivalTimestamp() != null) { + metadata.put(APPROX_ARRIVAL_TIMESTAMP, record.approximateArrivalTimestamp().toEpochMilli()); } return new MapRecord(SCHEMA_METADATA, metadata); diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java new file mode 100644 index 000000000000..53fbfd41bd3c --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManager.java @@ -0,0 +1,568 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; +import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; +import software.amazon.awssdk.services.kinesis.model.Shard; + +import java.math.BigInteger; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Coordinates shard 
ownership and checkpoints across clustered processor instances using + * a single DynamoDB table. + * + *

<p>The table stores two record types under the same stream hash key:
 + * <ul>
 + *   <li>Shard rows: key {@code (streamName, shardId)} with lease/checkpoint fields</li>
 + *   <li>Node heartbeat rows: key {@code (streamName, "__node__#<nodeId>")}</li>
 + * </ul>
 + *
 + * <p>Lease lifecycle used during refresh:
 + * <ol>
 + *   <li>Discover active shard leases and identify currently available shards</li>
 + *   <li>Compute fair-share target from active node heartbeats</li>
 + *   <li>If this node is over target, mark excess shards for graceful relinquish</li>
 + *   <li>Continue renewing leases for owned shards and draining relinquishing shards</li>
 + *   <li>After drain deadline, explicitly release relinquishing shards</li>
 + *   <li>Acquire available shards until fair-share target is reached</li>
 + * </ol>
 + *
 + * <p>Graceful relinquish is designed to reduce duplicate replay at rebalance boundaries by:
 + * (a) stopping new fetches immediately (shard removed from {@code ownedShards}),
 + * (b) briefly retaining lease ownership to allow in-flight work to finish,
 + * then (c) explicitly releasing the lease for fast handoff.
 + *
 + * <p>
Shard split/merge: this implementation does not enforce parent-before-child + * ordering. When a shard is split or shards are merged, child shards become eligible for + * consumption immediately alongside any still-active parent shards. Callers that require strict + * ordering across split/merge boundaries would need to defer child shard assignment until the + * parent shard's {@code SHARD_END} has been reached and checkpointed. + */ +final class KinesisShardManager { + + private static final long DEFAULT_SHARD_CACHE_MILLIS = 60_000; + private static final long DEFAULT_LEASE_DURATION_MILLIS = 30_000; + private static final long DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS = 10_000; + private static final long DEFAULT_NODE_HEARTBEAT_EXPIRATION_MILLIS = + DEFAULT_LEASE_DURATION_MILLIS + DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS; + + private final KinesisClient kinesisClient; + private final DynamoDbClient dynamoDbClient; + private final ComponentLog logger; + private final String nodeId; + private final String checkpointTableName; + private final String streamName; + private final long shardCacheMillis; + private final long leaseDurationMillis; + private final long leaseRefreshIntervalMillis; + private final long nodeHeartbeatExpirationMillis; + private final long relinquishDrainMillis; + + private volatile ShardCache shardCache = new ShardCache(List.of(), Instant.EPOCH); + private final Set ownedShards = ConcurrentHashMap.newKeySet(); + private final Map pendingRelinquishDeadlines = new ConcurrentHashMap<>(); + private final AtomicBoolean leaseRefreshInProgress = new AtomicBoolean(false); + private final Map highestWrittenCheckpoints = new ConcurrentHashMap<>(); + private volatile Instant lastLeaseRefresh = Instant.EPOCH; + private volatile String activeCheckpointTableName; + + KinesisShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, final ComponentLog logger, + final String checkpointTableName, final String streamName) { + 
this(kinesisClient, dynamoDbClient, logger, checkpointTableName, streamName, + DEFAULT_SHARD_CACHE_MILLIS, DEFAULT_LEASE_DURATION_MILLIS, + DEFAULT_LEASE_REFRESH_INTERVAL_MILLIS, DEFAULT_NODE_HEARTBEAT_EXPIRATION_MILLIS); + } + + KinesisShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, final ComponentLog logger, + final String checkpointTableName, final String streamName, final long shardCacheMillis, + final long leaseDurationMillis, final long leaseRefreshIntervalMillis, final long nodeHeartbeatExpirationMillis) { + this.kinesisClient = kinesisClient; + this.dynamoDbClient = dynamoDbClient; + this.logger = logger; + this.nodeId = UUID.randomUUID().toString(); + this.checkpointTableName = checkpointTableName; + this.streamName = streamName; + this.shardCacheMillis = shardCacheMillis; + this.leaseDurationMillis = leaseDurationMillis; + this.leaseRefreshIntervalMillis = leaseRefreshIntervalMillis; + this.nodeHeartbeatExpirationMillis = nodeHeartbeatExpirationMillis; + this.relinquishDrainMillis = Math.max(2_000L, leaseRefreshIntervalMillis); + this.activeCheckpointTableName = checkpointTableName; + } + + void ensureCheckpointTableExists() { + final CheckpointTableUtils.TableSchema currentSchema = + CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName); + logger.debug("Checkpoint table [{}] detected as {}", checkpointTableName, currentSchema); + + final LegacyCheckpointMigrator migrator = + new LegacyCheckpointMigrator(dynamoDbClient, checkpointTableName, streamName, nodeId, logger); + + switch (currentSchema) { + case NOT_FOUND -> { + final String orphanedMigration = migrator.findMigrationTable(); + if (orphanedMigration == null) { + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + } else { + logger.info("Found orphaned migration table [{}]; renaming to [{}]", + orphanedMigration, 
checkpointTableName); + migrator.renameMigrationTable(orphanedMigration); + } + } + case NEW -> { + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + migrator.cleanupLingeringMigration(); + } + case LEGACY -> { + migrator.migrateAndRename(); + } + default -> throw new ProcessException( + "Unsupported DynamoDB schema for checkpoint table [%s]".formatted(checkpointTableName)); + } + + activeCheckpointTableName = checkpointTableName; + logger.info("Using checkpoint table [{}] for stream [{}]", activeCheckpointTableName, streamName); + } + + List getShards() { + final ShardCache current = shardCache; + if (!current.shards().isEmpty() && Instant.now().toEpochMilli() < current.refreshTime().toEpochMilli() + shardCacheMillis) { + return current.shards(); + } + return refreshShards(); + } + + int getCachedShardCount() { + return shardCache.shards().size(); + } + + Set getOwnedShardIds() { + return Collections.unmodifiableSet(ownedShards); + } + + List getOwnedShards() { + final List result = new ArrayList<>(); + for (final Shard shard : getShards()) { + if (ownedShards.contains(shard.shardId())) { + result.add(shard); + } + } + return result; + } + + boolean shouldProcessFetchedResult(final String shardId) { + return ownedShards.contains(shardId) || pendingRelinquishDeadlines.containsKey(shardId); + } + + void refreshLeasesIfNecessary(final int clusterMemberCount) { + if (Instant.now().toEpochMilli() < leaseRefreshIntervalMillis + lastLeaseRefresh.toEpochMilli()) { + return; + } + + if (!leaseRefreshInProgress.compareAndSet(false, true)) { + return; + } + + try { + final List allShards = getShards(); + final Set currentShardIds = new HashSet<>(); + for (final Shard shard : allShards) { + currentShardIds.add(shard.shardId()); + } + + final long now = Instant.now().toEpochMilli(); + updateNodeHeartbeat(now); + final Map> ownerToShards = new HashMap<>(); + final List availableShardIds = new ArrayList<>(); + + final Map> 
leaseItemsByShardId = queryAllLeaseItems(); + for (final String shardId : currentShardIds) { + final Map item = leaseItemsByShardId.get(shardId); + if (item != null && item.containsKey("leaseOwner")) { + final String owner = item.get("leaseOwner").s(); + final AttributeValue expiryAttr = item.get("leaseExpiry"); + final long expiry = expiryAttr == null ? 0 : Long.parseLong(expiryAttr.n()); + if (expiry >= now) { + ownerToShards.computeIfAbsent(owner, k -> new ArrayList<>()).add(shardId); + } else { + availableShardIds.add(shardId); + } + } else { + availableShardIds.add(shardId); + } + } + + final Set stillOwned = new HashSet<>(ownerToShards.getOrDefault(nodeId, List.of())); + ownedShards.retainAll(stillOwned); + pendingRelinquishDeadlines.keySet().removeIf(shardId -> !currentShardIds.contains(shardId) || !stillOwned.contains(shardId)); + + final int heartbeatNodes = countActiveNodes(now); + final int totalOwners = Math.max(heartbeatNodes, Math.max(1, clusterMemberCount)); + final int targetCount = (allShards.size() + totalOwners - 1) / totalOwners; + + final List currentlyOwned = new ArrayList<>(ownerToShards.getOrDefault(nodeId, List.of())); + final int excessCount = Math.max(0, currentlyOwned.size() - targetCount); + if (excessCount > 0) { + for (int index = 0; index < excessCount && index < currentlyOwned.size(); index++) { + final String shardToRelinquish = currentlyOwned.get(index); + if (!pendingRelinquishDeadlines.containsKey(shardToRelinquish)) { + pendingRelinquishDeadlines.put(shardToRelinquish, now + relinquishDrainMillis); + ownedShards.remove(shardToRelinquish); + logger.info("Starting graceful relinquish for shard {}", shardToRelinquish); + } + } + } + + for (final String shardId : ownedShards) { + tryAcquireLease(shardId); + } + + for (final Map.Entry pendingRelinquish : new ArrayList<>(pendingRelinquishDeadlines.entrySet())) { + final String shardId = pendingRelinquish.getKey(); + final long relinquishDeadline = pendingRelinquish.getValue(); + if 
(!stillOwned.contains(shardId)) { + pendingRelinquishDeadlines.remove(shardId); + continue; + } + + if (now < relinquishDeadline) { + tryAcquireLease(shardId); + } else { + try { + releaseLease(shardId); + pendingRelinquishDeadlines.remove(shardId); + logger.info("Completed graceful relinquish for shard {}", shardId); + } catch (final Exception e) { + logger.warn("Failed to complete graceful relinquish for shard {}", shardId, e); + } + } + } + + for (final String shardId : availableShardIds) { + if (ownedShards.size() >= targetCount) { + break; + } + if (pendingRelinquishDeadlines.containsKey(shardId)) { + continue; + } + if (tryAcquireLease(shardId)) { + ownedShards.add(shardId); + } + } + + ownedShards.removeIf(id -> !currentShardIds.contains(id)); + lastLeaseRefresh = Instant.now(); + } finally { + leaseRefreshInProgress.set(false); + } + } + + String readCheckpoint(final String shardId) { + final GetItemRequest getItemRequest = GetItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .consistentRead(true) + .build(); + final GetItemResponse response = dynamoDbClient.getItem(getItemRequest); + + if (response.hasItem() && response.item().containsKey("sequenceNumber")) { + final String value = response.item().get("sequenceNumber").s(); + if (isValidSequenceNumber(value)) { + logger.debug("Read checkpoint for shard {}: {}", shardId, value); + return value; + } + logger.warn("Ignoring non-numeric checkpoint [{}] for shard {} in table [{}]", value, shardId, activeCheckpointTableName); + } else { + logger.debug("No checkpoint found for shard {} in table [{}]", shardId, activeCheckpointTableName); + } + return null; + } + + void writeCheckpoints(final Map checkpoints) { + for (final Map.Entry entry : checkpoints.entrySet()) { + writeCheckpoint(entry.getKey(), entry.getValue()); + } + } + + void releaseAllLeases() { + final Set shardsToRelease = new HashSet<>(ownedShards); + 
shardsToRelease.addAll(pendingRelinquishDeadlines.keySet()); + for (final String shardId : shardsToRelease) { + try { + releaseLease(shardId); + } catch (final Exception e) { + logger.warn("Failed to release lease for shard {}", shardId, e); + } + } + + try { + final UpdateItemRequest heartbeatReleaseRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(nodeHeartbeatKey()) + .updateExpression("REMOVE nodeHeartbeat, lastUpdateTimestamp") + .build(); + dynamoDbClient.updateItem(heartbeatReleaseRequest); + } catch (final Exception e) { + logger.debug("Failed to clear node heartbeat record for node {}", nodeId, e); + } + } + + void close() { + ownedShards.clear(); + pendingRelinquishDeadlines.clear(); + highestWrittenCheckpoints.clear(); + leaseRefreshInProgress.set(false); + lastLeaseRefresh = Instant.EPOCH; + shardCache = new ShardCache(List.of(), Instant.EPOCH); + } + + private synchronized List refreshShards() { + final ShardCache current = shardCache; + if (!current.shards().isEmpty() && Instant.now().toEpochMilli() < current.refreshTime().toEpochMilli() + shardCacheMillis) { + return current.shards(); + } + + final List allShards = new ArrayList<>(); + String nextToken = null; + do { + final ListShardsRequest request = nextToken != null + ? 
ListShardsRequest.builder().nextToken(nextToken).build() + : ListShardsRequest.builder().streamName(streamName).build(); + + final ListShardsResponse response = kinesisClient.listShards(request); + allShards.addAll(response.shards()); + nextToken = response.nextToken(); + } while (nextToken != null); + + logger.debug("ListShards returned {} shards for stream {}", allShards.size(), streamName); + shardCache = new ShardCache(allShards, Instant.now()); + return allShards; + } + + private static boolean isValidSequenceNumber(final String value) { + if (value == null || value.isEmpty()) { + return false; + } + for (int idx = 0; idx < value.length(); idx++) { + if (!Character.isDigit(value.charAt(idx))) { + return false; + } + } + return true; + } + + private boolean tryAcquireLease(final String shardId) { + final long now = Instant.now().toEpochMilli(); + final long expiry = now + leaseDurationMillis; + final AttributeValue ownerVal = AttributeValue.builder().s(nodeId).build(); + final AttributeValue expiryVal = AttributeValue.builder().n(String.valueOf(expiry)).build(); + final AttributeValue nowVal = AttributeValue.builder().n(String.valueOf(now)).build(); + + try { + final UpdateItemRequest leaseRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("SET leaseOwner = :owner, leaseExpiry = :exp, lastUpdateTimestamp = :ts") + .conditionExpression("attribute_not_exists(leaseOwner) OR leaseOwner = :owner OR leaseExpiry < :now") + .expressionAttributeValues(Map.of( + ":owner", ownerVal, + ":exp", expiryVal, + ":now", nowVal, + ":ts", nowVal)) + .build(); + dynamoDbClient.updateItem(leaseRequest); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } catch (final Exception e) { + logger.warn("Failed to acquire lease for shard {}", shardId, e); + return false; + } + } + + private void updateNodeHeartbeat(final long now) { + final UpdateItemRequest heartbeatRequest = 
UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(nodeHeartbeatKey()) + .updateExpression("SET nodeHeartbeat = :heartbeat, lastUpdateTimestamp = :ts") + .expressionAttributeValues(Map.of( + ":heartbeat", AttributeValue.builder().n(String.valueOf(now)).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(heartbeatRequest); + } + + private Map> queryAllLeaseItems() { + final Map> itemsByShardId = new HashMap<>(); + final QueryRequest.Builder queryBuilder = QueryRequest.builder() + .tableName(activeCheckpointTableName) + .consistentRead(true) + .keyConditionExpression("streamName = :streamName") + .expressionAttributeValues(Map.of( + ":streamName", AttributeValue.builder().s(streamName).build())); + + Map exclusiveStartKey = null; + do { + final QueryRequest queryRequest = exclusiveStartKey == null + ? queryBuilder.build() + : queryBuilder.exclusiveStartKey(exclusiveStartKey).build(); + final QueryResponse queryResponse = dynamoDbClient.query(queryRequest); + for (final Map item : queryResponse.items()) { + final AttributeValue shardIdAttr = item.get(CheckpointTableUtils.ATTR_SHARD_ID); + if (shardIdAttr == null) { + continue; + } + final String shardId = shardIdAttr.s(); + if (shardId.startsWith(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX) + || CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID.equals(shardId)) { + continue; + } + itemsByShardId.put(shardId, item); + } + + exclusiveStartKey = queryResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + return itemsByShardId; + } + + private int countActiveNodes(final long now) { + final QueryRequest.Builder queryRequestBuilder = QueryRequest.builder() + .tableName(activeCheckpointTableName) + .consistentRead(true) + .keyConditionExpression("streamName = :streamName AND begins_with(shardId, :nodePrefix)") + .expressionAttributeValues(Map.of( + ":streamName", 
AttributeValue.builder().s(streamName).build(), + ":nodePrefix", AttributeValue.builder().s(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX).build())); + + Map exclusiveStartKey = null; + int activeNodes = 0; + do { + final QueryRequest queryRequest = exclusiveStartKey == null + ? queryRequestBuilder.build() + : queryRequestBuilder.exclusiveStartKey(exclusiveStartKey).build(); + final QueryResponse queryResponse = dynamoDbClient.query(queryRequest); + for (final Map item : queryResponse.items()) { + final AttributeValue heartbeatValue = item.get("nodeHeartbeat"); + if (heartbeatValue == null) { + continue; + } + + final long heartbeatMillis = Long.parseLong(heartbeatValue.n()); + if (now <= heartbeatMillis + nodeHeartbeatExpirationMillis) { + activeNodes++; + } + } + + exclusiveStartKey = queryResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + return Math.max(1, activeNodes); + } + + private void writeCheckpoint(final String shardId, final BigInteger checkpoint) { + final BigInteger written = highestWrittenCheckpoints.compute(shardId, + (key, existing) -> persistIfHigher(shardId, checkpoint, existing)); + + if (written != null && checkpoint.compareTo(written) < 0) { + logger.debug("Skipped checkpoint regression for shard {} (highest: {}, attempted: {})", shardId, written, checkpoint); + } + } + + /** + * Writes the checkpoint to DynamoDB if it is higher than the existing value. Returns the + * new highest checkpoint on success, or {@code existing} if the checkpoint was lower or + * the write failed. 
+ */ + private BigInteger persistIfHigher(final String shardId, final BigInteger checkpoint, final BigInteger existing) { + if (existing != null && checkpoint.max(existing).equals(existing)) { + return existing; + } + + try { + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest checkpointRequest = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("SET sequenceNumber = :seq, lastUpdateTimestamp = :ts, leaseExpiry = :exp") + .conditionExpression("leaseOwner = :owner") + .expressionAttributeValues(Map.of( + ":seq", AttributeValue.builder().s(checkpoint.toString()).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build(), + ":exp", AttributeValue.builder().n(String.valueOf(now + leaseDurationMillis)).build(), + ":owner", AttributeValue.builder().s(nodeId).build())) + .build(); + dynamoDbClient.updateItem(checkpointRequest); + logger.debug("Checkpointed shard {} at sequence {}", shardId, checkpoint); + + return checkpoint; + } catch (final ConditionalCheckFailedException e) { + logger.warn("Lost lease on shard {} during checkpoint; another node may have taken it", shardId); + } catch (final Exception e) { + logger.error("Failed to write checkpoint for shard {}", shardId, e); + } + + return existing; + } + + private Map checkpointKey(final String shardId) { + return Map.of( + CheckpointTableUtils.ATTR_STREAM_NAME, AttributeValue.builder().s(streamName).build(), + CheckpointTableUtils.ATTR_SHARD_ID, AttributeValue.builder().s(shardId).build()); + } + + private void releaseLease(final String shardId) { + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(activeCheckpointTableName) + .key(checkpointKey(shardId)) + .updateExpression("REMOVE leaseOwner, leaseExpiry") + .conditionExpression("leaseOwner = :owner") + .expressionAttributeValues(Map.of(":owner", AttributeValue.builder().s(nodeId).build())) + .build(); + 
dynamoDbClient.updateItem(request); + } + + private Map nodeHeartbeatKey() { + return checkpointKey(CheckpointTableUtils.NODE_HEARTBEAT_PREFIX + nodeId); + } + + private record ShardCache(List shards, Instant refreshTime) { } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java new file mode 100644 index 000000000000..d994094d2c47 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/LegacyCheckpointMigrator.java @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.apache.nifi.processor.exception.ProcessException; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; + +import java.time.Instant; +import java.util.Map; + +/** + * Handles one-time migration of checkpoint data from legacy KCL-format DynamoDB tables + * to the new composite-key schema used by {@link KinesisShardManager}, including the + * table rename lifecycle. Uses distributed locks in the target table to coordinate + * migration and rename operations across clustered nodes. 
+ */ +final class LegacyCheckpointMigrator { + + private static final String MIGRATION_TABLE_SUFFIX = "_migration"; + private static final String LEGACY_LEASE_KEY_ATTRIBUTE = "leaseKey"; + private static final String LEGACY_CHECKPOINT_ATTRIBUTE = "checkpoint"; + private static final String MIGRATION_STATUS_ATTRIBUTE = "migrationStatus"; + private static final String MIGRATION_STATUS_IN_PROGRESS = "IN_PROGRESS"; + private static final String MIGRATION_STATUS_COMPLETE = "COMPLETE"; + private static final long MIGRATION_LOCK_STALE_MILLIS = 600_000; + private static final long MIGRATION_WAIT_MILLIS = 2_000; + private static final int MIGRATION_WAIT_MAX_ATTEMPTS = 180; + private static final long RENAME_LOCK_STALE_MILLIS = 120_000; + private static final long RENAME_POLL_MILLIS = 1_000; + private static final int RENAME_POLL_MAX_ATTEMPTS = 60; + + private final DynamoDbClient dynamoDbClient; + private final String checkpointTableName; + private final String streamName; + private final String nodeId; + private final ComponentLog logger; + + LegacyCheckpointMigrator(final DynamoDbClient dynamoDbClient, final String checkpointTableName, + final String streamName, final String nodeId, final ComponentLog logger) { + this.dynamoDbClient = dynamoDbClient; + this.checkpointTableName = checkpointTableName; + this.streamName = streamName; + this.nodeId = nodeId; + this.logger = logger; + } + + String findMigrationTable() { + final String migrationTableName = checkpointTableName + MIGRATION_TABLE_SUFFIX; + if (CheckpointTableUtils.getTableSchema(dynamoDbClient, migrationTableName) + == CheckpointTableUtils.TableSchema.NEW) { + return migrationTableName; + } + return null; + } + + void cleanupLingeringMigration() { + final String lingeringMigration = findMigrationTable(); + if (lingeringMigration == null) { + return; + } + logger.info("Deleting orphaned migration table [{}]; legacy checkpoint table [{}] retains original data", lingeringMigration, checkpointTableName); + 
CheckpointTableUtils.deleteTable(dynamoDbClient, logger, lingeringMigration); + } + + void migrateAndRename() { + final String existingMigration = findMigrationTable(); + final String migrationTableName; + + if (existingMigration == null) { + migrationTableName = checkpointTableName + MIGRATION_TABLE_SUFFIX; + logger.info("Legacy checkpoint table detected; migrating via [{}]", migrationTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, migrationTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, migrationTableName); + migrateCheckpoints(checkpointTableName, migrationTableName); + } else { + migrationTableName = existingMigration; + logger.info("Found existing migration table [{}]; completing rename to [{}]", + migrationTableName, checkpointTableName); + } + + renameMigrationTable(migrationTableName); + } + + void renameMigrationTable(final String migrationTableName) { + if (acquireRenameLock(migrationTableName)) { + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableDeleted(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, migrationTableName, checkpointTableName); + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, migrationTableName); + } else { + waitForTableRenamed(migrationTableName); + } + } + + private boolean acquireRenameLock(final String migrationTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final Map key = migrationMarkerKey(); + final Map values = Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build()); + final UpdateItemRequest request = UpdateItemRequest.builder() + 
.tableName(migrationTableName) + .key(key) + .updateExpression("SET renameOwner = :owner, renameStartedAt = :now") + .conditionExpression("attribute_not_exists(renameOwner)") + .expressionAttributeValues(values) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Acquired rename lock for migration table [{}]", migrationTableName); + return true; + } catch (final ConditionalCheckFailedException e) { + logger.debug("Rename lock already held for migration table [{}]", migrationTableName); + return false; + } catch (final ResourceNotFoundException e) { + logger.debug("Migration table [{}] already deleted; rename must be complete", migrationTableName); + return false; + } + } + + private boolean forceAcquireStaleRenameLock(final String migrationTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final long staleThreshold = now - RENAME_LOCK_STALE_MILLIS; + final Map key = migrationMarkerKey(); + final Map values = Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build(), + ":staleThreshold", AttributeValue.builder().n(String.valueOf(staleThreshold)).build()); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(migrationTableName) + .key(key) + .updateExpression("SET renameOwner = :owner, renameStartedAt = :now") + .conditionExpression("attribute_exists(renameOwner) AND renameStartedAt < :staleThreshold") + .expressionAttributeValues(values) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Force-acquired stale rename lock for migration table [{}]", migrationTableName); + return true; + } catch (final ConditionalCheckFailedException | ResourceNotFoundException e) { + return false; + } + } + + private Map migrationMarkerKey() { + return Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build()); + } + + private void 
waitForTableRenamed(final String migrationTableName) { + for (int i = 0; i < RENAME_POLL_MAX_ATTEMPTS; i++) { + if (CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName) + == CheckpointTableUtils.TableSchema.NEW) { + logger.info("Migration table rename complete; table [{}] is now available", checkpointTableName); + return; + } + + try { + Thread.sleep(RENAME_POLL_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for migration table rename to complete", e); + } + } + + logger.warn("Timed out waiting for migration table rename; attempting stale lock takeover"); + if (forceAcquireStaleRenameLock(migrationTableName)) { + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableDeleted(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.createNewSchemaTable(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.waitForTableActive(dynamoDbClient, logger, checkpointTableName); + CheckpointTableUtils.copyCheckpointItems(dynamoDbClient, logger, migrationTableName, checkpointTableName); + CheckpointTableUtils.deleteTable(dynamoDbClient, logger, migrationTableName); + } else if (CheckpointTableUtils.getTableSchema(dynamoDbClient, checkpointTableName) + == CheckpointTableUtils.TableSchema.NEW) { + logger.info("Migration table rename completed during takeover attempt"); + } else { + throw new ProcessException( + "Unable to complete migration table rename for [%s]".formatted(checkpointTableName)); + } + } + + private void migrateCheckpoints(final String sourceTableName, final String targetTableName) { + if (isMigrationComplete(targetTableName)) { + logger.debug("Legacy checkpoint migration already complete for stream [{}]", streamName); + return; + } + + if (acquireMigrationLock(targetTableName)) { + performMigration(sourceTableName, targetTableName); + } else { + 
waitForMigrationComplete(targetTableName); + if (!isMigrationComplete(targetTableName)) { + logger.warn("Migration wait timed out with stale lock; attempting takeover for stream [{}]", streamName); + if (forceAcquireStaleMigrationLock(targetTableName)) { + performMigration(sourceTableName, targetTableName); + } else { + throw new ProcessException( + "Unable to acquire migration lock for stream [%s]".formatted(streamName)); + } + } + } + } + + private void performMigration(final String sourceTableName, final String targetTableName) { + try { + migrateLegacyCheckpoints(sourceTableName, targetTableName); + markMigrationComplete(targetTableName); + } catch (final Exception e) { + clearMigrationLock(targetTableName); + throw new ProcessException("Failed to migrate legacy checkpoints to [%s]".formatted(targetTableName), e); + } + } + + private boolean acquireMigrationLock(final String targetTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final PutItemRequest request = PutItemRequest.builder() + .tableName(targetTableName) + .item(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build(), + MIGRATION_STATUS_ATTRIBUTE, AttributeValue.builder().s(MIGRATION_STATUS_IN_PROGRESS).build(), + "migrationOwner", AttributeValue.builder().s(nodeId).build(), + "migrationStartedAt", AttributeValue.builder().n(String.valueOf(now)).build(), + "lastUpdateTimestamp", AttributeValue.builder().n(String.valueOf(now)).build())) + .conditionExpression("attribute_not_exists(streamName)") + .build(); + dynamoDbClient.putItem(request); + logger.info("Acquired checkpoint migration lock for stream [{}]", streamName); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } + } + + private boolean forceAcquireStaleMigrationLock(final String targetTableName) { + try { + final long now = Instant.now().toEpochMilli(); + final long 
staleThreshold = now - MIGRATION_LOCK_STALE_MILLIS; + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("SET migrationOwner = :owner, migrationStartedAt = :now, lastUpdateTimestamp = :now") + .conditionExpression("migrationStatus = :inProgress AND migrationStartedAt < :staleThreshold") + .expressionAttributeValues(Map.of( + ":owner", AttributeValue.builder().s(nodeId).build(), + ":now", AttributeValue.builder().n(String.valueOf(now)).build(), + ":inProgress", AttributeValue.builder().s(MIGRATION_STATUS_IN_PROGRESS).build(), + ":staleThreshold", AttributeValue.builder().n(String.valueOf(staleThreshold)).build())) + .build(); + dynamoDbClient.updateItem(request); + logger.info("Force-acquired stale migration lock for stream [{}]", streamName); + return true; + } catch (final ConditionalCheckFailedException e) { + return false; + } + } + + private void migrateLegacyCheckpoints(final String sourceTableName, final String targetTableName) { + logger.info("Starting legacy checkpoint migration from [{}] to [{}] for stream [{}]", + sourceTableName, targetTableName, streamName); + + Map exclusiveStartKey = null; + int scanned = 0; + int migrated = 0; + int skippedMissingAttr = 0; + int skippedNonNumeric = 0; + do { + final ScanRequest scanRequest = ScanRequest.builder() + .tableName(sourceTableName) + .exclusiveStartKey(exclusiveStartKey) + .build(); + final ScanResponse scanResponse = dynamoDbClient.scan(scanRequest); + + for (final Map item : scanResponse.items()) { + scanned++; + final AttributeValue leaseKeyAttr = item.get(LEGACY_LEASE_KEY_ATTRIBUTE); + final AttributeValue checkpointAttr = item.get(LEGACY_CHECKPOINT_ATTRIBUTE); + if (leaseKeyAttr == null || checkpointAttr == null) { + skippedMissingAttr++; + logger.warn("Skipping 
legacy item missing leaseKey or checkpoint: keys={}", item.keySet()); + continue; + } + + final String shardId = extractShardId(leaseKeyAttr.s()); + if (shardId == null || shardId.isEmpty()) { + skippedMissingAttr++; + continue; + } + + final String checkpoint = checkpointAttr.s(); + if (!isValidSequenceNumber(checkpoint)) { + skippedNonNumeric++; + logger.warn("Skipping non-numeric legacy checkpoint [{}] for shard {}", checkpoint, shardId); + continue; + } + + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(shardId).build())) + .updateExpression("SET sequenceNumber = :seq, lastUpdateTimestamp = :ts") + .expressionAttributeValues(Map.of( + ":seq", AttributeValue.builder().s(checkpoint).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(request); + migrated++; + } + + exclusiveStartKey = scanResponse.lastEvaluatedKey(); + } while (exclusiveStartKey != null && !exclusiveStartKey.isEmpty()); + + logger.info("Legacy checkpoint migration complete for stream [{}]: scanned={}, migrated={}, skippedNonNumeric={}, skippedMissingAttr={}", + streamName, scanned, migrated, skippedNonNumeric, skippedMissingAttr); + } + + private static String extractShardId(final String legacyLeaseKey) { + if (legacyLeaseKey == null || legacyLeaseKey.isEmpty()) { + return null; + } + final int separatorIndex = legacyLeaseKey.lastIndexOf(':'); + if (separatorIndex >= 0 && separatorIndex + 1 < legacyLeaseKey.length()) { + return legacyLeaseKey.substring(separatorIndex + 1); + } + return legacyLeaseKey; + } + + private static boolean isValidSequenceNumber(final String value) { + if (value == null || value.isEmpty()) { + return false; + } + for (int idx = 0; idx < value.length(); idx++) { + if 
(!Character.isDigit(value.charAt(idx))) { + return false; + } + } + return true; + } + + private void markMigrationComplete(final String targetTableName) { + final long now = Instant.now().toEpochMilli(); + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("SET migrationStatus = :status, migrationCompletedAt = :doneAt, lastUpdateTimestamp = :ts REMOVE migrationOwner") + .expressionAttributeValues(Map.of( + ":status", AttributeValue.builder().s(MIGRATION_STATUS_COMPLETE).build(), + ":doneAt", AttributeValue.builder().n(String.valueOf(now)).build(), + ":ts", AttributeValue.builder().n(String.valueOf(now)).build())) + .build(); + dynamoDbClient.updateItem(request); + } + + private void clearMigrationLock(final String targetTableName) { + final UpdateItemRequest request = UpdateItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .updateExpression("REMOVE migrationStatus, migrationOwner, migrationStartedAt") + .build(); + dynamoDbClient.updateItem(request); + } + + private boolean isMigrationComplete(final String targetTableName) { + final GetItemResponse response = dynamoDbClient.getItem(GetItemRequest.builder() + .tableName(targetTableName) + .key(Map.of( + "streamName", AttributeValue.builder().s(streamName).build(), + "shardId", AttributeValue.builder().s(CheckpointTableUtils.MIGRATION_MARKER_SHARD_ID).build())) + .build()); + if (!response.hasItem()) { + return false; + } + final AttributeValue status = response.item().get(MIGRATION_STATUS_ATTRIBUTE); + return status != null && MIGRATION_STATUS_COMPLETE.equals(status.s()); + } + + private void 
waitForMigrationComplete(final String targetTableName) { + for (int attempt = 0; attempt < MIGRATION_WAIT_MAX_ATTEMPTS; attempt++) { + if (isMigrationComplete(targetTableName)) { + logger.info("Observed completed checkpoint migration for stream [{}]", streamName); + return; + } + + try { + Thread.sleep(MIGRATION_WAIT_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ProcessException("Interrupted while waiting for checkpoint migration to complete", e); + } + } + + logger.warn("Timed out waiting for checkpoint migration to complete for stream [{}]; will check for stale lock", + streamName); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java deleted file mode 100644 index c4d1522a350d..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBuffer.java +++ /dev/null @@ -1,725 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis; - -import jakarta.annotation.Nullable; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferLease; -import software.amazon.kinesis.exceptions.InvalidStateException; -import software.amazon.kinesis.exceptions.KinesisClientLibDependencyException; -import software.amazon.kinesis.exceptions.ShutdownException; -import software.amazon.kinesis.exceptions.ThrottlingException; -import software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Queue; -import java.util.Random; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - -import static java.util.Collections.emptyList; - -/** - * A record buffer which limits the maximum memory usage across all shard buffers. - * If the memory limit is reached, adding new records will block until enough memory is freed. 
- */ -final class MemoryBoundRecordBuffer implements RecordBuffer.ForKinesisClientLibrary, RecordBuffer.ForProcessor { - - private final ComponentLog logger; - - private final long checkpointIntervalMillis; - private final BlockingMemoryTracker memoryTracker; - - private final AtomicLong bufferIdCounter = new AtomicLong(0); - - /** - * All shard buffers stored by their ids. - *

- * When a buffer is invalidated, it is removed from this map, but its id may still be present in the buffersToLease queue. - * Since the buffer can be invalidated concurrently, it's possible for some buffer operations to be called - * after the buffer was removed from this map. In that case the operations should take no effect. - */ - private final ConcurrentMap shardBuffers = new ConcurrentHashMap<>(); - - /** - * A queue with ids shard buffers available for leasing. - *

- * Note: when a buffer is invalidated its id is NOT removed from the queue immediately. - */ - private final Queue buffersToLease = new ConcurrentLinkedQueue<>(); - - MemoryBoundRecordBuffer(final ComponentLog logger, final long maxMemoryBytes, final Duration checkpointInterval) { - this.logger = logger; - this.memoryTracker = new BlockingMemoryTracker(logger, maxMemoryBytes); - this.checkpointIntervalMillis = checkpointInterval.toMillis(); - } - - @Override - public ShardBufferId createBuffer(final String shardId) { - final ShardBufferId id = new ShardBufferId(shardId, bufferIdCounter.getAndIncrement()); - - logger.debug("Creating new buffer for shard {} with id {}", shardId, id); - - shardBuffers.put(id, new ShardBuffer(id, logger, checkpointIntervalMillis)); - buffersToLease.add(id); - return id; - } - - @Override - public void addRecords(final ShardBufferId bufferId, final List records, final RecordProcessorCheckpointer checkpointer) { - if (records.isEmpty()) { - return; - } - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot add records with sequence and subsequence numbers: {}.{} - {}.{}", - bufferId, - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber()); - return; - } - - final RecordBatch recordBatch = new RecordBatch(records, checkpointer, calculateMemoryUsage(records)); - memoryTracker.reserveMemory(recordBatch, bufferId); - final boolean addedRecords = buffer.offer(recordBatch); - - if (addedRecords) { - logger.debug("Successfully added records with sequence and subsequence numbers: {}.{} - {}.{} to buffer with id {}", - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber(), - bufferId); - } else { - logger.debug("Buffer with id {} was invalidated. 
Cannot add records with sequence and subsequence numbers: {}.{} - {}.{}", - bufferId, - records.getFirst().sequenceNumber(), - records.getFirst().subSequenceNumber(), - records.getLast().sequenceNumber(), - records.getLast().subSequenceNumber()); - // If the buffer was invalidated, we should free memory reserved for these records. - memoryTracker.freeMemory(List.of(recordBatch), bufferId); - } - } - - @Override - public void checkpointEndedShard(final ShardBufferId bufferId, final RecordProcessorCheckpointer checkpointer) { - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot checkpoint the ended shard", bufferId); - return; - } - - logger.debug("Finishing consumption for buffer {}. Checkpointing the ended shard", bufferId); - buffer.checkpointEndedShard(checkpointer); - - logger.debug("Removing buffer with id {} after successful ended shard checkpoint", bufferId); - shardBuffers.remove(bufferId); - } - - @Override - public void shutdownShardConsumption(final ShardBufferId bufferId, final RecordProcessorCheckpointer checkpointer) { - final ShardBuffer buffer = shardBuffers.remove(bufferId); - - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot shutdown shard consumption", bufferId); - } else { - logger.debug("Shutting down the buffer {}. Checkpointing last consumed record", bufferId); - final Collection invalidatedBatches = buffer.shutdownBuffer(checkpointer); - memoryTracker.freeMemory(invalidatedBatches, bufferId); - } - } - - @Override - public void consumerLeaseLost(final ShardBufferId bufferId) { - final ShardBuffer buffer = shardBuffers.remove(bufferId); - - if (buffer == null) { - logger.debug("Buffer with id {} not found. 
Ignoring lease lost event", bufferId); - } else { - logger.debug("Lease lost for buffer {}: Invalidating", bufferId); - final Collection invalidatedBatches = buffer.invalidate(); - memoryTracker.freeMemory(invalidatedBatches, bufferId); - } - } - - @Override - public Optional acquireBufferLease() { - final Set seenBuffers = new HashSet<>(); - - while (true) { - final ShardBufferId bufferId = buffersToLease.poll(); - if (bufferId == null) { - // The queue is empty or all buffers were seen already. Nothing to consume. - return Optional.empty(); - } - - if (seenBuffers.contains(bufferId)) { - // If the same buffer is seen again, there is a high chance we iterated through most of the buffers and didn't find any that isn't empty. - // To avoid burning CPU we return empty here, even if some buffer received records in the meantime. It will be picked up in the next iteration. - buffersToLease.add(bufferId); - return Optional.empty(); - } - - final ShardBuffer buffer = shardBuffers.get(bufferId); - - if (buffer == null) { - // By the time the bufferId is polled, it might have been invalidated. No need to return it to the queue. - logger.debug("Buffer with id {} was removed while polling for lease. Continuing to poll", bufferId); - } else if (buffer.isEmpty()) { - seenBuffers.add(bufferId); - buffersToLease.add(bufferId); - logger.debug("Buffer with id {} is empty. Continuing to poll", bufferId); - } else { - logger.debug("Acquired lease for buffer {}", bufferId); - return Optional.of(new Lease(bufferId)); - } - } - } - - @Override - public List consumeRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to consume records from a buffer that was already returned to the pool. Ignoring"); - return emptyList(); - } - - final ShardBufferId bufferId = lease.bufferId(); - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. 
Cannot consume records", bufferId); - return emptyList(); - } - - return buffer.consumeRecords(); - } - - @Override - public void commitConsumedRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to commit records from a buffer that was already returned to the pool. Ignoring"); - return; - } - - final ShardBufferId bufferId = lease.bufferId(); - - final ShardBuffer buffer = shardBuffers.get(bufferId); - if (buffer == null) { - logger.debug("Buffer with id {} not found. Cannot commit consumed records", bufferId); - return; - } - - logger.debug("Committing consumed records for buffer {}", bufferId); - final List consumedBatches = buffer.commitConsumedRecords(); - memoryTracker.freeMemory(consumedBatches, bufferId); - } - - @Override - public void rollbackConsumedRecords(final Lease lease) { - if (lease.isReturnedToPool()) { - logger.warn("Attempting to rollback records from a buffer that was already returned to the pool. Ignoring"); - return; - } - - final ShardBufferId bufferId = lease.bufferId(); - final ShardBuffer buffer = shardBuffers.get(bufferId); - - if (buffer != null) { - logger.debug("Rolling back consumed records for buffer {}", bufferId); - buffer.rollbackConsumedRecords(); - } - } - - @Override - public void returnBufferLease(final Lease lease) { - if (lease.returnToPool()) { - final ShardBufferId bufferId = lease.bufferId(); - buffersToLease.add(bufferId); - logger.debug("The buffer {} is available for lease again", bufferId); - } else { - logger.warn("Attempting to return a buffer that was already returned to the pool. 
Ignoring"); - } - } - - static final class Lease implements ShardBufferLease { - - private final ShardBufferId bufferId; - private final AtomicBoolean returnedToPool = new AtomicBoolean(false); - - private Lease(final ShardBufferId bufferId) { - this.bufferId = bufferId; - } - - @Override - public String shardId() { - return bufferId.shardId(); - } - - private ShardBufferId bufferId() { - return bufferId; - } - - private boolean isReturnedToPool() { - return returnedToPool.get(); - } - - /** - * Marks the lease as returned to the pool. - * @return true if the lease was not returned before, false otherwise. - */ - private boolean returnToPool() { - final boolean wasReturned = returnedToPool.getAndSet(true); - return !wasReturned; - } - } - - /** - * A memory tracker which blocks a thread when the memory usage exceeds the allowed maximum. - *

- * In order to make progress, the memory consumption may exceed the limit, but any new records will not be accepted. - * This is done to support the case when a single record batch is larger than the allowed memory limit. - */ - private static class BlockingMemoryTracker { - - private static final long AWAIT_MILLIS = 100; - - private final ComponentLog logger; - - private final long maxMemoryBytes; - - private final AtomicLong consumedMemoryBytes = new AtomicLong(0); - /** - * Whenever memory is freed a latch opens. Then replaced with a new one. - */ - private final AtomicReference memoryAvailableLatch = new AtomicReference<>(new CountDownLatch(1)); - - BlockingMemoryTracker(final ComponentLog logger, final long maxMemoryBytes) { - this.logger = logger; - this.maxMemoryBytes = maxMemoryBytes; - } - - void reserveMemory(final RecordBatch recordBatch, final ShardBufferId bufferId) { - final long consumedBytes = recordBatch.batchSizeBytes(); - - if (consumedBytes == 0) { - logger.debug("The batch for buffer {} is empty. No need to reserve memory", bufferId); - return; - } - - while (true) { - final long currentlyConsumedBytes = consumedMemoryBytes.get(); - - if (currentlyConsumedBytes >= maxMemoryBytes) { - // Not enough memory available, need to wait. - try { - memoryAvailableLatch.get().await(AWAIT_MILLIS, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for available memory in RecordBuffer", e); - } - } else { - final long newConsumedBytes = currentlyConsumedBytes + consumedBytes; - if (consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, newConsumedBytes)) { - logger.debug("Reserved {} bytes for {} records for buffer {}. 
Total consumed memory: {} bytes", - consumedBytes, recordBatch.size(), bufferId, newConsumedBytes); - break; - } - // If we're here, the compare and set operation failed, as another thread has modified the gauge in meantime. - // Retrying the operation. - } - } - } - - void freeMemory(final Collection consumedBatches, final ShardBufferId bufferId) { - if (consumedBatches.isEmpty()) { - logger.debug("No batches were consumed from buffer {}. No need to free memory", bufferId); - return; - } - - long freedBytes = 0; - for (final RecordBatch batch : consumedBatches) { - freedBytes += batch.batchSizeBytes(); - } - - while (true) { - final long currentlyConsumedBytes = consumedMemoryBytes.get(); - if (currentlyConsumedBytes < freedBytes) { - throw new IllegalStateException("Attempting to free more memory than currently used"); - } - - final long newTotal = currentlyConsumedBytes - freedBytes; - if (consumedMemoryBytes.compareAndSet(currentlyConsumedBytes, newTotal)) { - logger.debug("Freed {} bytes for {} batches from buffer {}. Total consumed memory: {} bytes", - freedBytes, consumedBatches.size(), bufferId, newTotal); - - final CountDownLatch oldLatch = memoryAvailableLatch.getAndSet(new CountDownLatch(1)); - oldLatch.countDown(); // Release any waiting threads for free memory. - break; - } - // If we're here, the compare and set operation failed, as another thread has modified the gauge in meantime. - // Retrying the operation. 
- } - } - } - - private record RecordBatch(List records, - RecordProcessorCheckpointer checkpointer, - long batchSizeBytes) { - int size() { - return records.size(); - } - } - - private long calculateMemoryUsage(final Collection records) { - long totalBytes = 0; - for (final KinesisClientRecord record : records) { - final ByteBuffer data = record.data(); - if (data != null) { - totalBytes += data.capacity(); - } - } - return totalBytes; - } - - /** - * ShardBuffer stores all record batches for a single shard in two queues: - * - IN_PROGRESS: record batches that have been consumed but not yet checkpointed. - * - PENDING: record batches that have been added but not yet consumed. - *

- * When consuming records all PENDING batches are moved to IN_PROGRESS. - * After a successful checkpoint all IN_PROGRESS batches are cleared. - * After a rollback all IN_PROGRESS batches are kept, allowing to retry consumption. - *

- * Each batch preserves the original grouping of records as provided by Kinesis - * along with their associated checkpointer, ensuring atomicity. - */ - private static class ShardBuffer { - - private static final long AWAIT_MILLIS = 100; - - // Retry configuration. - private static final int MAX_RETRY_ATTEMPTS = 5; - private static final long BASE_RETRY_DELAY_MILLIS = 100; - private static final long MAX_RETRY_DELAY_MILLIS = 10_000; - private static final Random RANDOM = new Random(); - - private final ShardBufferId bufferId; - private final ComponentLog logger; - - private final long checkpointIntervalMillis; - private volatile long nextCheckpointTimeMillis; - - /** - * A last record checkpointer and sequence number that was ignored due to the checkpoint interval. - * If null, the last checkpoint was successful or no checkpoint was attempted yet. - */ - private volatile @Nullable LastIgnoredCheckpoint lastIgnoredCheckpoint; - - /** - * Queues for managing record batches with their checkpointers in different states. - */ - private final Queue inProgressBatches = new ConcurrentLinkedQueue<>(); - private final Queue pendingBatches = new ConcurrentLinkedQueue<>(); - /** - * Counter for tracking the number of batches in the buffer. Can be larger than the number of batches in the queues. - */ - private final AtomicInteger batchesCount = new AtomicInteger(0); - - /** - * A countdown latch that is used to signal when the buffer becomes empty. Used when ShardBuffer should be closed. 
- */ - private volatile @Nullable CountDownLatch emptyBufferLatch = null; - private final AtomicBoolean invalidated = new AtomicBoolean(false); - - ShardBuffer(final ShardBufferId bufferId, final ComponentLog logger, final long checkpointIntervalMillis) { - this.bufferId = bufferId; - this.logger = logger; - this.checkpointIntervalMillis = checkpointIntervalMillis; - this.nextCheckpointTimeMillis = System.currentTimeMillis() + checkpointIntervalMillis; - } - - /** - * @param recordBatch record batch with records to add. - * @return true if the records were added successfully, false if a buffer was invalidated. - */ - boolean offer(final RecordBatch recordBatch) { - if (invalidated.get()) { - return false; - } - - // Batches count must be always equal to or larger than the number of batches in the queues. - // Thus, the ordering of the operations. - batchesCount.incrementAndGet(); - pendingBatches.offer(recordBatch); - - return true; - } - - List consumeRecords() { - if (invalidated.get()) { - return emptyList(); - } - - RecordBatch pendingBatch; - while ((pendingBatch = pendingBatches.poll()) != null) { - inProgressBatches.offer(pendingBatch); - } - - final List recordsToConsume = new ArrayList<>(); - for (final RecordBatch batch : inProgressBatches) { - recordsToConsume.addAll(batch.records()); - } - - return recordsToConsume; - } - - List commitConsumedRecords() { - if (invalidated.get()) { - return emptyList(); - } - - final List checkpointedBatches = new ArrayList<>(); - RecordBatch batch; - while ((batch = inProgressBatches.poll()) != null) { - checkpointedBatches.add(batch); - } - - if (checkpointedBatches.isEmpty()) { - // The buffer could be invalidated in the meantime, or no records were consumed. - return emptyList(); - } - - // Batches count must always be equal to or larger than the number of batches in the queues. - // To achieve so, the count is decreased only after the queue has been emptied. 
- batchesCount.addAndGet(-checkpointedBatches.size()); - - final RecordProcessorCheckpointer lastBatchCheckpointer = checkpointedBatches.getLast().checkpointer(); - final KinesisClientRecord lastRecord = checkpointedBatches.getLast().records().getLast(); - - if (System.currentTimeMillis() >= nextCheckpointTimeMillis) { - checkpointSequenceNumber(lastBatchCheckpointer, lastRecord.sequenceNumber(), lastRecord.subSequenceNumber()); - nextCheckpointTimeMillis = System.currentTimeMillis() + checkpointIntervalMillis; - lastIgnoredCheckpoint = null; - } else { - // Saving the checkpointer for later, in case shutdown happens before the next checkpoint. - lastIgnoredCheckpoint = new LastIgnoredCheckpoint(lastBatchCheckpointer, lastRecord.sequenceNumber(), lastRecord.subSequenceNumber()); - } - - final CountDownLatch localEmptyBufferLatch = this.emptyBufferLatch; - if (localEmptyBufferLatch != null && isEmpty()) { - // If the latch is not null, it means we are waiting for the buffer to become empty. - localEmptyBufferLatch.countDown(); - } - - return checkpointedBatches; - } - - void rollbackConsumedRecords() { - if (invalidated.get()) { - return; - } - - for (final RecordBatch recordBatch : inProgressBatches) { - for (final KinesisClientRecord record : recordBatch.records()) { - record.data().rewind(); - } - } - } - - void checkpointEndedShard(final RecordProcessorCheckpointer checkpointer) { - while (true) { - if (invalidated.get()) { - return; - } - - // Setting the latch first, so that if the buffer is being emptied concurrently in commitConsumedRecords - // the latch is guaranteed to be visible to the commitConsumedRecords, and, therefore, opened. - // This will eliminate unnecessary downtimes when waiting for a latch to be opened. - final CountDownLatch localEmptyBufferLatch = new CountDownLatch(1); - this.emptyBufferLatch = localEmptyBufferLatch; - - if (batchesCount.get() == 0) { - // Buffer is empty, perform final checkpoint. 
- checkpointLastReceivedRecord(checkpointer); - return; - } - - // Wait for the records to be consumed first. - try { - localEmptyBufferLatch.await(AWAIT_MILLIS, TimeUnit.MILLISECONDS); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for records to be consumed", e); - } - } - } - - Collection shutdownBuffer(final RecordProcessorCheckpointer checkpointer) { - if (invalidated.getAndSet(true)) { - return emptyList(); - } - - if (batchesCount.get() == 0) { - checkpointLastReceivedRecord(checkpointer); - return emptyList(); - } - - // If there are still records in the buffer, checkpointing with the latest provided checkpointer is not safe. - // But, if the records were committed without checkpointing in the past, we can checkpoint them now. - final LastIgnoredCheckpoint ignoredCheckpoint = this.lastIgnoredCheckpoint; - if (ignoredCheckpoint != null) { - checkpointSequenceNumber( - ignoredCheckpoint.checkpointer(), - ignoredCheckpoint.sequenceNumber(), - ignoredCheckpoint.subSequenceNumber() - ); - } - - return drainInvalidatedBatches(); - } - - Collection invalidate() { - if (invalidated.getAndSet(true)) { - return emptyList(); - } - - return drainInvalidatedBatches(); - } - - private Collection drainInvalidatedBatches() { - if (!invalidated.get()) { - throw new IllegalStateException("Unable to drain invalidated batches for valid shard buffer: " + bufferId); - } - - final List batches = new ArrayList<>(); - RecordBatch batch; - // If both consumeRecords and drainInvalidatedBatches are called concurrently, invalidation must always consume all batches. - // Since consumeRecords moves batches from pending to in_progress, during invalidation pending batches should be drained first. 
- while ((batch = pendingBatches.poll()) != null) { - batches.add(batch); - } - while ((batch = inProgressBatches.poll()) != null) { - batches.add(batch); - } - - // No need to adjust batchesCount after invalidation. - - return batches; - } - - /** - * Checks if the buffer has any records. Can produce false negatives. - * - * @return whether there are any records in the buffer. - */ - boolean isEmpty() { - return invalidated.get() || batchesCount.get() == 0; - } - - private void checkpointLastReceivedRecord(final RecordProcessorCheckpointer checkpointer) { - logger.debug("Performing checkpoint for buffer with id {}. Checkpointing the last received record", bufferId); - - checkpointSafely(checkpointer::checkpoint); - } - - private void checkpointSequenceNumber(final RecordProcessorCheckpointer checkpointer, final String sequenceNumber, final long subSequenceNumber) { - logger.debug("Performing checkpoint for buffer with id {}. Sequence number: [{}], sub sequence number: [{}]", - bufferId, sequenceNumber, subSequenceNumber); - - checkpointSafely(() -> checkpointer.checkpoint(sequenceNumber, subSequenceNumber)); - } - - /** - * Performs checkpointing using exponential backoff and jitter, if needed. - * - * @param checkpointAction the action which performs the checkpointing. 
- */ - private void checkpointSafely(final CheckpointAction checkpointAction) { - for (int attempt = 1; attempt <= MAX_RETRY_ATTEMPTS; attempt++) { - try { - checkpointAction.doCheckpoint(); - if (attempt > 1) { - logger.debug("Checkpoint succeeded on attempt {}", attempt); - } - return; - } catch (final ThrottlingException | InvalidStateException | KinesisClientLibDependencyException e) { - if (attempt == MAX_RETRY_ATTEMPTS) { - logger.error("Failed to checkpoint after {} attempts, giving up", MAX_RETRY_ATTEMPTS, e); - return; - } - - final long delayMillis = calculateRetryDelay(attempt); - - logger.debug("Checkpoint failed on attempt {} with {}, retrying in {} ms", - attempt, e.getMessage(), delayMillis); - - try { - Thread.sleep(delayMillis); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - logger.warn("Thread interrupted while waiting to retry checkpoint. Exiting retry loop", ie); - return; - } - } catch (final ShutdownException e) { - logger.warn("Failed to checkpoint records due to shutdown. Ignoring checkpoint", e); - return; - } catch (final RuntimeException e) { - logger.warn("Failed to checkpoint records due to an error. Ignoring checkpoint", e); - return; - } - } - } - - private long calculateRetryDelay(final int attempt) { - final long desiredBaseDelayMillis = BASE_RETRY_DELAY_MILLIS * (1L << (attempt - 1)); - final long baseDelayMillis = Math.min(desiredBaseDelayMillis, MAX_RETRY_DELAY_MILLIS); - final long jitterMillis = RANDOM.nextLong(baseDelayMillis / 4); // Up to 25% jitter. - return baseDelayMillis + jitterMillis; - } - - private interface CheckpointAction { - - /** - * Throws the same set of exceptions as {@link RecordProcessorCheckpointer#checkpoint()} and {@link RecordProcessorCheckpointer#checkpoint(String, long)}. 
- */ - void doCheckpoint() throws KinesisClientLibDependencyException, InvalidStateException, ThrottlingException, ShutdownException, IllegalArgumentException; - } - - private record LastIgnoredCheckpoint(RecordProcessorCheckpointer checkpointer, String sequenceNumber, long subSequenceNumber) { - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java new file mode 100644 index 000000000000..e21dd3c7ccb2 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClient.java @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Shared-throughput Kinesis consumer that runs a continuous background fetch loop per shard.
 * Each owned shard gets its own virtual thread that repeatedly calls GetRecords and enqueues
 * results for the processor, mirroring the producer-consumer architecture of the KCL Scheduler.
 * This keeps data flowing between onTrigger invocations rather than fetching on-demand.
 *
 * <p>Concurrency is bounded by a semaphore with {@value #MAX_CONCURRENT_FETCHES} permits so
 * that at most that many GetRecords HTTP calls are in flight at any moment, preventing
 * connection-pool exhaustion. A second fair semaphore with {@value #MAX_QUEUED_RESULTS}
 * permits ensures that fetch threads block when the result queue is full, with FIFO ordering
 * guaranteeing that all shard threads get equal opportunity to enqueue results.
 *
 * <p>NOTE(review): {@code kinesisClient}, {@code logger}, and the queue helpers
 * ({@code enqueueResult}, {@code drainShardQueue}, {@code hasQueuedResults},
 * {@code totalQueuedResults}, {@code getShardLock}, {@code createFetchResult},
 * {@code shouldLogDiagnostics}, {@code getTimestampForInitialPosition}) are inherited from
 * {@code KinesisConsumerClient}, which is not visible in this chunk — their exact contracts
 * should be confirmed against that class.
 */
final class PollingKinesisClient extends KinesisConsumerClient {

    // Pause between polls when a GetRecords call returned no records for the shard.
    private static final long DEFAULT_EMPTY_SHARD_BACKOFF_NANOS = TimeUnit.MILLISECONDS.toNanos(500);
    // Pause before retrying after a throttled or failed Kinesis call.
    private static final long DEFAULT_ERROR_BACKOFF_NANOS = TimeUnit.SECONDS.toNanos(2);
    // Upper bound on fetch results waiting to be consumed by the processor.
    static final int MAX_QUEUED_RESULTS = 200;
    // Upper bound on simultaneously in-flight GetRecords HTTP calls.
    static final int MAX_CONCURRENT_FETCHES = 25;

    // One virtual thread per shard fetch loop; virtual threads are cheap, so no pooling is needed.
    private final ExecutorService fetchExecutor = Executors.newVirtualThreadPerTaskExecutor();
    // Per-shard loop state, keyed by shard id.
    private final Map<String, PollingShardState> pollingShardStates = new ConcurrentHashMap<>();
    // Fair (FIFO) semaphores: see the class javadoc for why fairness matters here.
    private final Semaphore fetchPermits = new Semaphore(MAX_CONCURRENT_FETCHES, true);
    private final Semaphore queuePermits = new Semaphore(MAX_QUEUED_RESULTS, true);
    private final long emptyShardBackoffNanos;
    private final long errorBackoffNanos;

    PollingKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger) {
        this(kinesisClient, logger, DEFAULT_EMPTY_SHARD_BACKOFF_NANOS, DEFAULT_ERROR_BACKOFF_NANOS);
    }

    /**
     * Constructor allowing backoff durations to be overridden (used by tests).
     *
     * @param kinesisClient         synchronous Kinesis SDK client used for all API calls
     * @param logger                component logger
     * @param emptyShardBackoffNanos pause between polls of an empty shard
     * @param errorBackoffNanos      pause after a failed or throttled call
     */
    PollingKinesisClient(final KinesisClient kinesisClient, final ComponentLog logger,
            final long emptyShardBackoffNanos, final long errorBackoffNanos) {
        super(kinesisClient, logger);
        this.emptyShardBackoffNanos = emptyShardBackoffNanos;
        this.errorBackoffNanos = errorBackoffNanos;
    }

    /**
     * Ensures a fetch loop is running for every owned shard. Starts a loop for newly seen
     * shards and restarts loops that died unexpectedly (loop flag cleared while the shard is
     * neither exhausted nor stopped).
     */
    @Override
    void startFetches(final List<Shard> shards, final String streamName, final int batchSize,
            final String initialStreamPosition, final KinesisShardManager shardManager) {
        if (fetchExecutor.isShutdown()) {
            return;
        }

        for (final Shard shard : shards) {
            final String shardId = shard.shardId();
            final PollingShardState existing = pollingShardStates.get(shardId);
            if (existing == null) {
                final PollingShardState state = new PollingShardState();
                // putIfAbsent + tryStartLoop together guarantee at most one loop per shard
                // even if two threads race through startFetches concurrently.
                if (pollingShardStates.putIfAbsent(shardId, state) == null && state.tryStartLoop()) {
                    launchFetchLoop(state, shardId, streamName, batchSize, initialStreamPosition, shardManager);
                }
            } else if (!existing.isExhausted() && !existing.isStopped() && !existing.isLoopRunning()
                    && existing.tryStartLoop()) {
                logger.warn("Restarting dead fetch loop for stream [{}] shard [{}]", streamName, shardId);
                launchFetchLoop(existing, shardId, streamName, batchSize, initialStreamPosition, shardManager);
            }
        }
    }

    /**
     * @return true if results are already queued or any shard loop may still produce results.
     */
    @Override
    boolean hasPendingFetches() {
        if (hasQueuedResults()) {
            return true;
        }
        for (final PollingShardState state : pollingShardStates.values()) {
            if (!state.isExhausted() && !state.isStopped()) {
                return true;
            }
        }
        return false;
    }

    // No-op: polling mode checkpoints externally, so nothing to acknowledge here.
    @Override
    void acknowledgeResults(final List<ShardFetchResult> results) {
    }

    /**
     * Requests a reset for each affected shard so its loop re-acquires an iterator from the
     * last checkpoint, and drains any results already queued for that shard.
     */
    @Override
    void rollbackResults(final List<ShardFetchResult> results) {
        for (final ShardFetchResult result : results) {
            final PollingShardState state = pollingShardStates.get(result.shardId());
            if (state != null) {
                resetAndDrainShard(result.shardId(), state);
            }
        }
    }

    /**
     * Stops and removes fetch state for shards no longer owned by this node.
     */
    @Override
    void removeUnownedShards(final Set<String> ownedShards) {
        pollingShardStates.entrySet().removeIf(entry -> {
            if (!ownedShards.contains(entry.getKey())) {
                entry.getValue().stop();
                return true;
            }
            return false;
        });
    }

    /**
     * Periodically logs a summary of loop states (active / exhausted / stopped / dead) and
     * current concurrency usage at DEBUG level.
     */
    @Override
    void logDiagnostics(final int ownedCount, final int cachedShardCount) {
        if (!shouldLogDiagnostics()) {
            return;
        }

        int active = 0;
        int exhausted = 0;
        int stopped = 0;
        int dead = 0;
        for (final PollingShardState state : pollingShardStates.values()) {
            if (state.isExhausted()) {
                exhausted++;
            } else if (state.isStopped()) {
                stopped++;
            } else if (!state.isLoopRunning()) {
                dead++;
            } else {
                active++;
            }
        }

        logger.debug("Kinesis polling diagnostics: discoveredShards={}, ownedShards={}, queueDepth={}, "
                + "fetchLoops={}, active={}, exhausted={}, stopped={}, dead={}, concurrentFetches={}",
                cachedShardCount, ownedCount, totalQueuedResults(), pollingShardStates.size(),
                active, exhausted, stopped, dead, MAX_CONCURRENT_FETCHES - fetchPermits.availablePermits());
    }

    // Each result removed from the queue frees one queue permit for the fetch loops.
    @Override
    protected void onResultPolled() {
        queuePermits.release();
    }

    /**
     * Stops all fetch loops, clears per-shard state, and shuts down the executor.
     */
    @Override
    void close() {
        for (final PollingShardState state : pollingShardStates.values()) {
            state.stop();
        }
        pollingShardStates.clear();
        fetchExecutor.shutdownNow();
        super.close();
    }

    /**
     * Submits the per-shard fetch loop to the executor. Propagates the caller's context
     * class loader into the virtual thread so NAR class loading works inside the loop.
     */
    private void launchFetchLoop(final PollingShardState state, final String shardId,
            final String streamName, final int batchSize, final String initialStreamPosition,
            final KinesisShardManager shardManager) {
        final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
        try {
            fetchExecutor.submit(() -> {
                Thread.currentThread().setContextClassLoader(contextClassLoader);
                try {
                    runFetchLoop(state, shardId, streamName, batchSize, initialStreamPosition, shardManager);
                } catch (final Throwable t) {
                    if (!state.isStopped()) {
                        logger.error("Fetch loop for shard {} terminated unexpectedly", shardId, t);
                    }
                } finally {
                    // Always clear the running flag so startFetches can restart a dead loop.
                    state.markLoopStopped();
                }
            });
        } catch (final RejectedExecutionException e) {
            state.markLoopStopped();
            logger.debug("Executor shut down; cannot start fetch loop for stream [{}] shard [{}]", streamName, shardId);
        }
    }

    /**
     * Main per-shard loop: acquire an iterator, call GetRecords, enqueue non-empty batches,
     * and advance the iterator until the shard is exhausted, stopped, or the thread is
     * interrupted. Permit discipline: a queue permit is acquired BEFORE fetching and is
     * either handed to the queued result (released later by {@link #onResultPolled()}) or
     * released in the finally block — it is never leaked.
     */
    private void runFetchLoop(final PollingShardState state, final String shardId,
            final String streamName, final int batchSize, final String initialStreamPosition,
            final KinesisShardManager shardManager) {

        state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager));

        while (!Thread.currentThread().isInterrupted() && !state.isStopped()) {
            try {
                if (state.isExhausted()) {
                    return;
                }

                if (state.isResetRequested()) {
                    // Rollback or expired iterator: drop queued results for this shard,
                    // give back their permits, and restart from the last checkpoint.
                    state.clearReset();
                    final int drained = drainShardQueue(shardId);
                    if (drained > 0) {
                        queuePermits.release(drained);
                    }
                    state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager));
                }

                if (state.getIterator() == null) {
                    state.setIterator(getShardIterator(state, streamName, shardId, initialStreamPosition, shardManager));
                    if (state.getIterator() == null) {
                        sleepNanos(errorBackoffNanos);
                        continue;
                    }
                }

                // Block here (FIFO) when the result queue is full.
                try {
                    queuePermits.acquire();
                } catch (final InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }

                boolean queuePermitConsumed = false;
                try {
                    try {
                        fetchPermits.acquire();
                    } catch (final InterruptedException e) {
                        Thread.currentThread().interrupt();
                        return;
                    }

                    final GetRecordsResponse response;
                    try {
                        response = fetchRecords(shardId, state, batchSize);
                    } finally {
                        // Release the HTTP-call permit as soon as the call completes.
                        fetchPermits.release();
                    }
                    if (response == null) {
                        // fetchRecords already logged and backed off; retry the loop.
                        continue;
                    }

                    final List<Record> records = response.records();
                    if (!records.isEmpty()) {
                        // -1 signals "lag unknown" when the service omits millisBehindLatest.
                        final long millisBehind = response.millisBehindLatest() != null ? response.millisBehindLatest() : -1;
                        queuePermitConsumed = enqueueIfActive(shardId, state, createFetchResult(shardId, records, millisBehind));
                    }

                    state.setIterator(response.nextShardIterator());
                    if (state.getIterator() == null) {
                        // A null next iterator means the shard has been closed (e.g. resharding).
                        state.markExhausted();
                        return;
                    }

                    if (records.isEmpty()) {
                        sleepNanos(emptyShardBackoffNanos);
                    }
                } finally {
                    if (!queuePermitConsumed) {
                        queuePermits.release();
                    }
                }
            } catch (final Exception e) {
                if (!state.isStopped()) {
                    logger.warn("Unexpected error in fetch loop for shard [{}]; will retry", shardId, e);
                    state.setIterator(null);
                    sleepNanos(errorBackoffNanos);
                }
            }
        }
    }

    /**
     * Performs one GetRecords call. Returns null (after logging and backing off, and after
     * requesting an iterator reset where appropriate) on any failure; the loop retries.
     */
    private GetRecordsResponse fetchRecords(final String shardId, final PollingShardState state, final int batchSize) {
        final GetRecordsRequest request = GetRecordsRequest.builder()
                .shardIterator(state.getIterator())
                .limit(batchSize)
                .build();

        try {
            return kinesisClient.getRecords(request);
        } catch (final ProvisionedThroughputExceededException | LimitExceededException e) {
            // Throttling: keep the current iterator and retry after backoff.
            logger.debug("GetRecords throttled for shard {}; will retry after backoff", shardId);
            sleepNanos(errorBackoffNanos);
            return null;
        } catch (final ExpiredIteratorException e) {
            // Iterator lifetime exceeded: force re-acquisition from the checkpoint.
            logger.info("Shard iterator expired for shard {}; will re-acquire", shardId);
            state.requestReset();
            sleepNanos(errorBackoffNanos);
            return null;
        } catch (final SdkClientException e) {
            // Transient client-side (network) failure: the iterator is still valid.
            if (!state.isStopped()) {
                logger.warn("GetRecords failed for shard {}; will retry with existing iterator", shardId, e);
                sleepNanos(errorBackoffNanos);
            }
            return null;
        } catch (final Exception e) {
            if (!state.isStopped()) {
                logger.error("GetRecords failed for shard {}", shardId, e);
                state.requestReset();
                sleepNanos(errorBackoffNanos);
            }
            return null;
        }
    }

    /**
     * Enqueues a result unless a reset was requested for the shard in the meantime.
     * Synchronizes on the per-shard lock so enqueue cannot interleave with
     * {@link #resetAndDrainShard}, which would otherwise strand a stale result in the queue.
     *
     * @return true if the result was enqueued (and thus now owns a queue permit)
     */
    private boolean enqueueIfActive(final String shardId, final PollingShardState state, final ShardFetchResult result) {
        synchronized (getShardLock(shardId)) {
            if (state.isResetRequested()) {
                return false;
            }
            enqueueResult(result);
            return true;
        }
    }

    /**
     * Marks the shard for reset and drains its queued results, returning their permits.
     * Uses the same per-shard lock as {@link #enqueueIfActive} — see that method.
     */
    private void resetAndDrainShard(final String shardId, final PollingShardState state) {
        synchronized (getShardLock(shardId)) {
            state.requestReset();
            final int drained = drainShardQueue(shardId);
            if (drained > 0) {
                queuePermits.release(drained);
            }
        }
    }

    // Sleeps without throwing; re-asserts the interrupt flag so the loop condition exits.
    private static void sleepNanos(final long nanos) {
        try {
            TimeUnit.NANOSECONDS.sleep(nanos);
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Acquires a shard iterator, resuming AFTER_SEQUENCE_NUMBER when a checkpoint exists,
     * otherwise falling back to the configured initial stream position (AT_TIMESTAMP uses
     * {@code getTimestampForInitialPosition()}). Counts against the fetch-permit budget
     * because it performs Kinesis API calls. Returns null on any failure; callers back off
     * and retry.
     */
    private String getShardIterator(final PollingShardState state, final String streamName,
            final String shardId, final String initialStreamPosition, final KinesisShardManager shardManager) {
        try {
            fetchPermits.acquire();
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }

        try {
            final String lastSequenceNumber;
            try {
                lastSequenceNumber = shardManager.readCheckpoint(shardId);
            } catch (final Exception e) {
                if (!state.isStopped()) {
                    logger.warn("Failed to read checkpoint for shard {}; will retry", shardId, e);
                }
                return null;
            }

            final ShardIteratorType iteratorType;
            final String startingSequenceNumber;
            final Instant timestamp;
            if (lastSequenceNumber != null) {
                // Resume exactly after the last checkpointed record.
                iteratorType = ShardIteratorType.AFTER_SEQUENCE_NUMBER;
                startingSequenceNumber = lastSequenceNumber;
                timestamp = null;
            } else {
                // No checkpoint: honor the processor's configured initial position.
                iteratorType = ShardIteratorType.fromValue(initialStreamPosition);
                startingSequenceNumber = null;
                timestamp = (iteratorType == ShardIteratorType.AT_TIMESTAMP) ? getTimestampForInitialPosition() : null;
            }

            logger.debug("Getting shard iterator for shard {} with type={}, startingSeq={}, timestamp={}",
                    shardId, iteratorType, startingSequenceNumber, timestamp);

            final GetShardIteratorRequest.Builder iteratorRequestBuilder = GetShardIteratorRequest.builder()
                    .streamName(streamName)
                    .shardId(shardId)
                    .shardIteratorType(iteratorType);

            if (startingSequenceNumber != null) {
                iteratorRequestBuilder.startingSequenceNumber(startingSequenceNumber);
            }
            if (timestamp != null) {
                iteratorRequestBuilder.timestamp(timestamp);
            }

            return kinesisClient.getShardIterator(iteratorRequestBuilder.build()).shardIterator();
        } catch (final Exception e) {
            if (!state.isStopped()) {
                logger.error("Failed to get shard iterator for shard {} in stream {}", shardId, streamName, e);
            }
            return null;
        } finally {
            fetchPermits.release();
        }
    }

    /**
     * Mutable per-shard loop state. All fields are volatile because they are read and
     * written from both the shard's fetch thread and processor/management threads.
     */
    static final class PollingShardState {
        // Current shard iterator; null means "needs to be (re)acquired".
        private volatile String currentIterator;
        // True once the shard returned a null next iterator (shard closed).
        private volatile boolean shardExhausted;
        // True once the shard is no longer owned or the client is closing.
        private volatile boolean stopped;
        // True when queued results must be discarded and the iterator re-acquired.
        private volatile boolean resetRequested;
        // CAS-guarded flag ensuring at most one fetch loop runs per shard.
        private final AtomicBoolean loopRunning = new AtomicBoolean();

        String getIterator() {
            return currentIterator;
        }

        void setIterator(final String iterator) {
            currentIterator = iterator;
        }

        boolean isExhausted() {
            return shardExhausted;
        }

        void markExhausted() {
            shardExhausted = true;
        }

        boolean isStopped() {
            return stopped;
        }

        void stop() {
            stopped = true;
        }

        boolean isResetRequested() {
            return resetRequested;
        }

        void requestReset() {
            resetRequested = true;
        }

        void clearReset() {
            resetRequested = false;
        }

        boolean tryStartLoop() {
            return loopRunning.compareAndSet(false, true);
        }

        void markLoopStopped() {
            loopRunning.set(false);
        }

        boolean isLoopRunning() {
            return loopRunning.get();
        }
    }
}
/**
 * Deaggregates KPL (Kinesis Producer Library) aggregated records into individual user records.
 *
 * <p>KPL aggregation packs multiple user records into a single Kinesis record using a protobuf
 * envelope with a 4-byte magic header and a 16-byte MD5 trailer. Non-aggregated records
 * pass through unchanged as a single {@link UserRecord} with {@code subSequenceNumber=0}.
 *
 * <p>If a record has the magic header but fails MD5 verification or protobuf parsing, it falls
 * back to passthrough to avoid data loss.
 *
 * <p>We could make direct use of the KPL's protobuf definition and generated classes,
 * but doing so requires bringing in 20+ transitive dependencies. Since the protobuf format is
 * simple and well-documented, we implement a minimal custom parser using the protobuf wire format
 * instead. Additionally, we have integration tests to verify compatibility.
 *
 * @see <a href="https://github.com/awslabs/amazon-kinesis-producer/blob/master/aggregation-format.md">KPL Aggregation Format</a>
 */
final class ProducerLibraryDeaggregator {

    // 4-byte magic prefix identifying a KPL-aggregated payload.
    static final byte[] KPL_MAGIC = {(byte) 0xF3, (byte) 0x89, (byte) 0x9A, (byte) 0xC2};
    private static final int MD5_DIGEST_LENGTH = 16;
    // Magic + MD5 trailer + at least one protobuf byte between them.
    private static final int MIN_AGGREGATED_LENGTH = KPL_MAGIC.length + MD5_DIGEST_LENGTH + 1;

    // Field numbers of the KPL AggregatedRecord protobuf message.
    private static final int FIELD_PARTITION_KEY_TABLE = 1;
    private static final int FIELD_EXPLICIT_HASH_KEY_TABLE = 2;
    private static final int FIELD_RECORDS = 3;

    // Field numbers of the embedded Record sub-message.
    private static final int RECORD_FIELD_PARTITION_KEY_INDEX = 1;
    private static final int RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX = 2;
    private static final int RECORD_FIELD_DATA = 3;

    // Utility class: no instances.
    private ProducerLibraryDeaggregator() {
    }

    /**
     * Deaggregates a list of Kinesis records, expanding any KPL-aggregated records into
     * their constituent sub-records.
     *
     * @param shardId the shard these records were fetched from
     * @param records raw Kinesis records from the API
     * @return list of deaggregated records preserving original order
     */
    static List<UserRecord> deaggregate(final String shardId, final List<Record> records) {
        final List<UserRecord> result = new ArrayList<>();
        for (final Record record : records) {
            deaggregateRecord(shardId, record, result);
        }
        return result;
    }

    /**
     * Expands a single record into {@code out}: passthrough when not aggregated or when the
     * envelope fails verification/parsing, otherwise one UserRecord per sub-record.
     */
    private static void deaggregateRecord(final String shardId, final Record record, final List<UserRecord> out) {
        // asByteArrayUnsafe avoids a copy; the bytes are only read, never mutated.
        final byte[] data = record.data().asByteArrayUnsafe();

        if (!isAggregated(data)) {
            out.add(passthrough(shardId, record, data));
            return;
        }

        // Protobuf body sits between the magic prefix and the trailing MD5 digest.
        final int protobufOffset = KPL_MAGIC.length;
        final int protobufLength = data.length - KPL_MAGIC.length - MD5_DIGEST_LENGTH;

        if (!verifyMd5(data, protobufOffset, protobufLength)) {
            out.add(passthrough(shardId, record, data));
            return;
        }

        try {
            parseAggregatedRecord(shardId, record, data, protobufOffset, protobufLength, out);
        } catch (final Exception e) {
            // Deliberate best-effort: malformed envelopes fall back to passthrough
            // rather than losing the record (see class javadoc).
            out.add(passthrough(shardId, record, data));
        }
    }

    /**
     * @return true if the payload is long enough to be a KPL envelope and starts with the magic bytes.
     */
    static boolean isAggregated(final byte[] data) {
        if (data.length < MIN_AGGREGATED_LENGTH) {
            return false;
        }
        return data[0] == KPL_MAGIC[0]
                && data[1] == KPL_MAGIC[1]
                && data[2] == KPL_MAGIC[2]
                && data[3] == KPL_MAGIC[3];
    }

    /**
     * Verifies the 16-byte MD5 trailer against a digest of the protobuf body.
     * Integrity check only (per the KPL format) — not a security boundary.
     */
    private static boolean verifyMd5(final byte[] data, final int protobufOffset, final int protobufLength) {
        final MessageDigest md5 = getMd5Digest();
        md5.update(data, protobufOffset, protobufLength);
        final byte[] computed = md5.digest();
        final int md5Offset = protobufOffset + protobufLength;
        return Arrays.equals(computed, 0, MD5_DIGEST_LENGTH, data, md5Offset, md5Offset + MD5_DIGEST_LENGTH);
    }

    private static MessageDigest getMd5Digest() {
        try {
            return MessageDigest.getInstance("MD5");
        } catch (final NoSuchAlgorithmException e) {
            // MD5 is a mandatory JCA algorithm; absence indicates a broken runtime.
            throw new IllegalStateException("MD5 algorithm not available", e);
        }
    }

    /**
     * Parses the AggregatedRecord protobuf body and appends one {@link UserRecord} per
     * sub-record to {@code out}, assigning sub-sequence numbers by position. Sub-records
     * whose partition-key index is out of range fall back to the outer record's partition key.
     *
     * @throws Exception on any wire-format violation; the caller falls back to passthrough
     */
    private static void parseAggregatedRecord(final String shardId, final Record kinesisRecord,
            final byte[] data, final int protobufOffset, final int protobufLength, final List<UserRecord> out) throws Exception {

        final List<String> partitionKeyTable = new ArrayList<>();
        final List<byte[]> subRecordData = new ArrayList<>();
        final List<Integer> subRecordPartitionKeyIndexes = new ArrayList<>();

        final CodedInputStream input = CodedInputStream.newInstance(data, protobufOffset, protobufLength);
        while (!input.isAtEnd()) {
            final int tag = input.readTag();
            final int fieldNumber = WireFormat.getTagFieldNumber(tag);
            switch (fieldNumber) {
                case FIELD_PARTITION_KEY_TABLE:
                    partitionKeyTable.add(input.readString());
                    break;
                case FIELD_EXPLICIT_HASH_KEY_TABLE:
                    // Explicit hash keys are not surfaced; read and discard to advance the stream.
                    input.readString();
                    break;
                case FIELD_RECORDS:
                    // Length-delimited embedded Record message: read the length prefix and
                    // constrain parsing to that window via pushLimit/popLimit.
                    final int length = input.readRawVarint32();
                    final int oldLimit = input.pushLimit(length);
                    int partitionKeyIndex = 0;
                    byte[] subRecordPayload = new byte[0];
                    while (!input.isAtEnd()) {
                        final int innerTag = input.readTag();
                        final int innerField = WireFormat.getTagFieldNumber(innerTag);
                        switch (innerField) {
                            case RECORD_FIELD_PARTITION_KEY_INDEX:
                                // Index into partitionKeyTable; narrowing cast is safe for
                                // realistic table sizes (bounds-checked below anyway).
                                partitionKeyIndex = (int) input.readUInt64();
                                break;
                            case RECORD_FIELD_EXPLICIT_HASH_KEY_INDEX:
                                input.readUInt64();
                                break;
                            case RECORD_FIELD_DATA:
                                subRecordPayload = input.readByteArray();
                                break;
                            default:
                                input.skipField(innerTag);
                                break;
                        }
                    }
                    input.popLimit(oldLimit);
                    subRecordData.add(subRecordPayload);
                    subRecordPartitionKeyIndexes.add(partitionKeyIndex);
                    break;
                default:
                    input.skipField(tag);
                    break;
            }
        }

        // All sub-records share the outer record's sequence number and arrival time;
        // the list index becomes the sub-sequence number.
        final String sequenceNumber = kinesisRecord.sequenceNumber();
        final Instant arrival = kinesisRecord.approximateArrivalTimestamp();
        final String fallbackPartitionKey = kinesisRecord.partitionKey();

        for (int i = 0; i < subRecordData.size(); i++) {
            final int partitionKeyTableIndex = subRecordPartitionKeyIndexes.get(i);
            final String partitionKey = partitionKeyTableIndex < partitionKeyTable.size()
                    ? partitionKeyTable.get(partitionKeyTableIndex) : fallbackPartitionKey;

            final UserRecord record = new UserRecord(shardId, sequenceNumber, i, partitionKey, subRecordData.get(i), arrival);
            out.add(record);
        }
    }

    /**
     * Wraps a non-aggregated (or unparseable) record as a single UserRecord with
     * subSequenceNumber 0, preserving its original metadata.
     */
    private static UserRecord passthrough(final String shardId, final Record record, final byte[] data) {
        return new UserRecord(
                shardId,
                record.sequenceNumber(),
                0,
                record.partitionKey(),
                data,
                record.approximateArrivalTimestamp());
    }
}
- */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processor.ProcessSession; -import org.apache.nifi.processor.exception.ProcessException; -import org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverter; -import org.apache.nifi.schema.access.SchemaNotFoundException; -import org.apache.nifi.serialization.MalformedRecordException; -import org.apache.nifi.serialization.RecordReader; -import org.apache.nifi.serialization.RecordReaderFactory; -import org.apache.nifi.serialization.RecordSetWriter; -import org.apache.nifi.serialization.RecordSetWriterFactory; -import org.apache.nifi.serialization.WriteResult; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import static java.util.Collections.emptyMap; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; - -final class ReaderRecordProcessor { - - private final RecordReaderFactory recordReaderFactory; - private final KinesisRecordConverter recordConverter; - private final RecordSetWriterFactory recordWriterFactory; - private final ComponentLog logger; - - ReaderRecordProcessor( - final RecordReaderFactory recordReaderFactory, - final KinesisRecordConverter recordConverter, - final RecordSetWriterFactory recordWriterFactory, - final ComponentLog logger) { - 
this.recordReaderFactory = recordReaderFactory; - this.recordConverter = recordConverter; - this.recordWriterFactory = recordWriterFactory; - this.logger = logger; - } - - ProcessingResult processRecords( - final ProcessSession session, - final String streamName, - final String shardId, - final List records) { - final List successFlowFiles = new ArrayList<>(); - final List failureFlowFiles = new ArrayList<>(); - - ActiveFlowFile activeFlowFile = null; - - for (final KinesisClientRecord kinesisRecord : records) { - final int dataSize = kinesisRecord.data().remaining(); - final byte[] data = new byte[dataSize]; - kinesisRecord.data().get(data); - - try (final InputStream in = new ByteArrayInputStream(data); - final RecordReader reader = recordReaderFactory.createRecordReader(emptyMap(), in, data.length, logger)) { - - Record record; - while ((record = reader.nextRecord()) != null) { - final Record convertedRecord = recordConverter.convert(record, kinesisRecord, streamName, shardId); - final RecordSchema writeSchema = recordWriterFactory.getSchema(emptyMap(), convertedRecord.getSchema()); - - if (activeFlowFile == null) { - activeFlowFile = ActiveFlowFile.startNewFile(logger, session, recordWriterFactory, writeSchema, streamName, shardId); - } else if (!writeSchema.equals(activeFlowFile.schema())) { - // If the write schema has changed, we need to complete the current FlowFile and start a new one. 
- final FlowFile completedFlowFile = activeFlowFile.complete(); - successFlowFiles.add(completedFlowFile); - - activeFlowFile = ActiveFlowFile.startNewFile(logger, session, recordWriterFactory, writeSchema, streamName, shardId); - } - - activeFlowFile.writeRecord(convertedRecord, kinesisRecord); - } - } catch (final IOException | MalformedRecordException | SchemaNotFoundException e) { - logger.error("Reader or Writer failed to process Kinesis Record with Stream Name [{}] Shard Id [{}] Sequence Number [{}] SubSequence Number [{}]", - streamName, shardId, kinesisRecord.sequenceNumber(), kinesisRecord.subSequenceNumber(), e); - final FlowFile failureFlowFile = createParseFailureFlowFile(session, streamName, shardId, kinesisRecord, e); - failureFlowFiles.add(failureFlowFile); - } - } - - if (activeFlowFile != null) { - final FlowFile completedFlowFile = activeFlowFile.complete(); - successFlowFiles.add(completedFlowFile); - } - - return new ProcessingResult(successFlowFiles, failureFlowFiles); - } - - private static FlowFile createParseFailureFlowFile( - final ProcessSession session, - final String streamName, - final String shardId, - final KinesisClientRecord record, - final Exception e) { - FlowFile flowFile = session.create(); - - record.data().rewind(); - flowFile = session.write(flowFile, out -> { - try (final WritableByteChannel channel = Channels.newChannel(out)) { - channel.write(record.data()); - } - }); - - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, record, record); - - final Throwable cause = e.getCause() != null ? 
e.getCause() : e; - attributes.put(RECORD_ERROR_MESSAGE, cause.toString()); - - flowFile = session.putAllAttributes(flowFile, attributes); - session.getProvenanceReporter().receive(flowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - return flowFile; - } - - record ProcessingResult(List successFlowFiles, List parseFailureFlowFiles) { - } - - /** - * A class that manages a single {@link FlowFile} with a static schema that is currently being written to. - * On a schema change the current {@link ActiveFlowFile} should be completed a new instance of this class - * with a new schema should be created. - * - * An {@link ActiveFlowFile} must have at least one record written to it before it can be completed. - */ - private static final class ActiveFlowFile { - - private final ComponentLog logger; - - private final ProcessSession session; - private final FlowFile flowFile; - private final RecordSetWriter writer; - private final RecordSchema schema; - - private final String streamName; - private final String shardId; - - private KinesisClientRecord firstRecord; - private KinesisClientRecord lastRecord; - - private ActiveFlowFile( - final ComponentLog logger, - final ProcessSession session, - final FlowFile flowFile, - final RecordSetWriter writer, - final RecordSchema schema, - final String streamName, - final String shardId) { - this.logger = logger; - this.session = session; - this.flowFile = flowFile; - this.writer = writer; - this.schema = schema; - this.streamName = streamName; - this.shardId = shardId; - } - - static ActiveFlowFile startNewFile( - final ComponentLog logger, - final ProcessSession session, - final RecordSetWriterFactory recordWriterFactory, - final RecordSchema writeSchema, - final String streamName, - final String shardId) throws SchemaNotFoundException { - final FlowFile flowFile = session.create(); - final OutputStream outputStream = session.write(flowFile); - - try { - final RecordSetWriter writer = 
recordWriterFactory.createWriter(logger, writeSchema, outputStream, flowFile); - writer.beginRecordSet(); - - return new ActiveFlowFile(logger, session, flowFile, writer, writeSchema, streamName, shardId); - - } catch (final SchemaNotFoundException e) { - logger.debug("Failed to find writeSchema for Kinesis stream record: {}", e.getMessage()); - try { - outputStream.close(); - } catch (final IOException ioe) { - e.addSuppressed(ioe); - } - throw e; - - } catch (final IOException e) { - final ProcessException processException = new ProcessException("Failed to create a writer for a FlowFile", e); - - logger.debug("Stopping Kinesis records processing. Failed to create a writer for a FlowFile: {}", e.getMessage()); - try { - outputStream.close(); - } catch (final IOException ioe) { - processException.addSuppressed(ioe); - } - throw processException; - } - } - - RecordSchema schema() { - return schema; - } - - void writeRecord(final Record record, final KinesisClientRecord kinesisRecord) { - try { - writer.write(record); - } catch (final IOException e) { - logger.debug("Stopping Kinesis records processing. 
Failed to write to a FlowFile: {}", e.getMessage()); - throw new ProcessException("Failed to write a record into a FlowFile", e); - } - - if (firstRecord == null) { - firstRecord = kinesisRecord; - } - lastRecord = kinesisRecord; - } - - FlowFile complete() { - if (firstRecord == null || lastRecord == null) { - throw new IllegalStateException("Cannot complete an ActiveFlowFile that has no records"); - } - - try { - final WriteResult finalResult = writer.finishRecordSet(); - writer.close(); - - final Map attributes = ConsumeKinesisAttributes.fromKinesisRecords(streamName, shardId, firstRecord, lastRecord); - attributes.putAll(finalResult.getAttributes()); - attributes.put(RECORD_COUNT, String.valueOf(finalResult.getRecordCount())); - attributes.put(MIME_TYPE, writer.getMimeType()); - - final FlowFile completedFlowFile = session.putAllAttributes(flowFile, attributes); - session.getProvenanceReporter().receive(completedFlowFile, ProvenanceTransitUriFormat.toTransitUri(streamName, shardId)); - - return completedFlowFile; - - } catch (final IOException e) { - final ProcessException processException = new ProcessException("Failed to complete a FlowFile", e); - - logger.debug("Stopping Kinesis records processing. 
Failed to complete a FlowFile: {}", e.getMessage()); - try { - writer.close(); - } catch (final IOException ioe) { - processException.addSuppressed(ioe); - } - - throw processException; - } - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java deleted file mode 100644 index 12fa52c6fe7a..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/RecordBuffer.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; -import java.util.Optional; - -/** - * RecordBuffer keeps track of all created Shard buffers, including exclusive read access via leasing. 
- * It acts as the main interface between KCL callbacks and the {@link ConsumeKinesis} processor, - * routing events to appropriate Shard buffers and ensuring thread-safe operations. - */ -interface RecordBuffer { - - /** - * Interface for interactions from the Kinesis Client Library to the Record Buffer. - * Reflects the methods called by {@link software.amazon.kinesis.processor.ShardRecordProcessor}. - */ - interface ForKinesisClientLibrary { - - ShardBufferId createBuffer(String shardId); - - void addRecords(ShardBufferId bufferId, List records, RecordProcessorCheckpointer checkpointer); - - /** - * Called when a shard ends - waits until the buffer is flushed then performs the final checkpoint. - */ - void checkpointEndedShard(ShardBufferId bufferId, RecordProcessorCheckpointer checkpointer); - - /** - * Called when a consumer is shut down. Performs the checkpoint and returns - * without waiting for the buffer to be flushed. - */ - void shutdownShardConsumption(ShardBufferId bufferId, RecordProcessorCheckpointer checkpointer); - - /** - * Called when lease is lost - immediately invalidates the buffer to prevent further operations. - */ - void consumerLeaseLost(ShardBufferId bufferId); - } - - /** - * Interface for interactions from {@link ConsumeKinesis} processor to the Record Buffer. - */ - interface ForProcessor { - - /** - * Acquires an exclusive lease for a buffer that has data available for consumption. - * If no data is available in the buffers, returns an empty Optional. - *

- * After acquiring a lease, the processor can consume records from the buffer. - * After consuming the records the processor must always {@link #returnBufferLease(LEASE)}. - */ - Optional acquireBufferLease(); - - /** - * Consumes records from the buffer associated with the given lease. - * The records have to be committed or rolled back later. - */ - List consumeRecords(LEASE lease); - - void commitConsumedRecords(LEASE lease); - - void rollbackConsumedRecords(LEASE lease); - - /** - * Returns the lease for a buffer back to the pool making it available for consumption again. - * The method can be called multiple times with the same lease, but only the first call will actually take an effect. - */ - void returnBufferLease(LEASE lease); - } - - record ShardBufferId(String shardId, long bufferId) { - } - - interface ShardBufferLease { - String shardId(); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java similarity index 70% rename from nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java rename to nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java index e7575d6fcf77..9b88d71a9a4a 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormat.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/ShardFetchResult.java @@ -16,12 +16,16 @@ */ package org.apache.nifi.processors.aws.kinesis; -final class ProvenanceTransitUriFormat { +import java.math.BigInteger; +import java.util.List; - static String 
toTransitUri(final String streamName, final String shardId) { - return "kinesis:stream/" + streamName + "/" + shardId; +record ShardFetchResult(String shardId, List records, long millisBehindLatest) { + + BigInteger firstSequenceNumber() { + return new BigInteger(records.getFirst().sequenceNumber()); } - private ProvenanceTransitUriFormat() { + BigInteger lastSequenceNumber() { + return new BigInteger(records.getLast().sequenceNumber()); } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java new file mode 100644 index 000000000000..e44de3c506f1 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/UserRecord.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import java.time.Instant; + +/** + * A single user record extracted from a Kinesis record. For non-aggregated records, + * the fields map directly from the Kinesis API {@code Record}. 
For KPL-aggregated + * records, each sub-record within the aggregate gets its own instance with a unique + * {@code subSequenceNumber}. + * + * @param shardId the shard from which this record was fetched + * @param sequenceNumber the Kinesis sequence number of the enclosing record + * @param subSequenceNumber zero for non-aggregated records; index within the aggregate for KPL records + * @param partitionKey the partition key (from the enclosing record or the KPL sub-record) + * @param data the user payload bytes + * @param approximateArrivalTimestamp approximate time the enclosing record arrived at Kinesis + */ +record UserRecord( + String shardId, + String sequenceNumber, + long subSequenceNumber, + String partitionKey, + byte[] data, + Instant approximateArrivalTimestamp) { +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java deleted file mode 100644 index d454577d420b..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverter.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.FIELD_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.composeMetadataObject; - -public final class InjectMetadataRecordConverter implements KinesisRecordConverter { - - @Override - public Record convert(final Record record, final KinesisClientRecord kinesisRecord, final String streamName, final String shardId) { - final List schemaFields = new ArrayList<>(record.getSchema().getFields()); - schemaFields.add(FIELD_METADATA); - final RecordSchema schema = new SimpleRecordSchema(schemaFields); - - final Record metadata = composeMetadataObject(kinesisRecord, streamName, shardId); - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.put(METADATA, metadata); - - return new MapRecord(schema, recordValues); - } -} diff --git 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java deleted file mode 100644 index 83a65fd81f2f..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverter.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.record.Record; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -public interface KinesisRecordConverter { - - Record convert(Record record, KinesisClientRecord kinesisRecord, String streamName, String shardId); -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java deleted file mode 100644 index b67ae22072cd..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/ValueRecordConverter.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.record.Record; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -public final class ValueRecordConverter implements KinesisRecordConverter { - - @Override - public Record convert(final Record record, final KinesisClientRecord kinesisRecord, final String streamName, final String shardId) { - return record; - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java deleted file mode 100644 index 693077f4437e..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/main/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.FIELD_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.composeMetadataObject; - -public final class WrapperRecordConverter implements KinesisRecordConverter { - - private static final String VALUE = "value"; - - @Override - public Record convert(final Record record, final KinesisClientRecord kinesisRecord, final String streamName, final String shardId) { - final Record metadata = composeMetadataObject(kinesisRecord, streamName, shardId); - - final RecordSchema convertedSchema = new SimpleRecordSchema(List.of( - FIELD_METADATA, - new RecordField(VALUE, RecordFieldType.RECORD.getRecordDataType(record.getSchema()))) - ); - final Map convertedRecord = Map.of( - METADATA, metadata, - VALUE, record - ); - - return new MapRecord(convertedSchema, convertedRecord); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java new file mode 100644 index 000000000000..22fc1505299b --- /dev/null +++ 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/CheckpointTableUtilsTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class CheckpointTableUtilsTest { + + private static final String 
STREAM_NAME = "my-stream"; + private static final String SHARD_ID_1 = "shardId-0001"; + private static final String SHARD_ID_2 = "shardId-0002"; + private static final String SOURCE_TABLE = "source-table"; + private static final String DEST_TABLE = "dest-table"; + + private static AttributeValue str(final String value) { + return AttributeValue.builder().s(value).build(); + } + + @Test + void testCopyCheckpointItemsCopiesShardItems() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + final Map item = Map.of( + "streamName", str(STREAM_NAME), + "shardId", str(SHARD_ID_1), + "sequenceNumber", str("12345")); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(item).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); + + final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb, times(1)).putItem(putCaptor.capture()); + assertEquals(item, putCaptor.getValue().item()); + } + + @Test + void testCopyCheckpointItemsSkipsNodeAndMigrationMarkers() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + final Map nodeItem = Map.of( + "streamName", str(STREAM_NAME), + "shardId", str("__node__#node-a")); + final Map migrationMarkerItem = Map.of( + "streamName", str(STREAM_NAME), + "shardId", str("__migration__")); + final Map shardItem = Map.of( + "streamName", str(STREAM_NAME), + "shardId", str(SHARD_ID_2), + "sequenceNumber", str("67890")); + + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( + ScanResponse.builder().items(List.of(nodeItem, migrationMarkerItem, shardItem)).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + 
CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); + + final ArgumentCaptor putCaptor = ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb, times(1)).putItem(putCaptor.capture()); + assertEquals(shardItem, putCaptor.getValue().item()); + } + + @Test + void testCopyCheckpointItemsSkipsAllMarkers() { + final DynamoDbClient dynamoDb = mock(DynamoDbClient.class); + final ComponentLog logger = mock(ComponentLog.class); + + final Map nodeItem = Map.of( + "streamName", str(STREAM_NAME), + "shardId", str("__node__#node-b")); + + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn( + ScanResponse.builder().items(List.of(nodeItem)).build()); + + CheckpointTableUtils.copyCheckpointItems(dynamoDb, logger, SOURCE_TABLE, DEST_TABLE); + + verify(dynamoDb, never()).putItem(any(PutItemRequest.class)); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java index 6a5b866f25da..28e833222083 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisIT.java @@ -16,764 +16,1048 @@ */ package org.apache.nifi.processors.aws.kinesis; -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.flowfile.attributes.CoreAttributes; +import com.google.protobuf.ByteString; +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.nifi.avro.AvroReader; +import org.apache.nifi.avro.AvroRecordSetWriter; +import 
org.apache.nifi.controller.AbstractControllerService; import org.apache.nifi.json.JsonRecordSetWriter; import org.apache.nifi.json.JsonTreeReader; -import org.apache.nifi.processor.Processor; -import org.apache.nifi.processor.Relationship; -import org.apache.nifi.processor.exception.FlowFileHandlingException; +import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderControllerService; import org.apache.nifi.processors.aws.region.RegionUtil; import org.apache.nifi.provenance.ProvenanceEventRecord; -import org.apache.nifi.provenance.ProvenanceEventType; -import org.apache.nifi.reporting.InitializationException; -import org.apache.nifi.schema.access.SchemaAccessUtils; -import org.apache.nifi.schema.inference.SchemaInferenceUtil; +import org.apache.nifi.serialization.RecordSetWriter; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.RecordSet; import org.apache.nifi.util.MockFlowFile; -import org.apache.nifi.util.MockProcessSession; -import org.apache.nifi.util.SharedSessionState; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; -import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.api.parallel.ExecutionMode; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.EnabledIfDockerAvailable; +import org.testcontainers.junit.jupiter.Testcontainers; import 
org.testcontainers.localstack.LocalStackContainer; import org.testcontainers.utility.DockerImageName; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.kinesis.KinesisClient; -import software.amazon.awssdk.services.kinesis.model.Consumer; -import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; -import software.amazon.awssdk.services.kinesis.model.ListStreamConsumersResponse; -import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; -import software.amazon.awssdk.services.kinesis.model.ScalingType; -import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.awssdk.services.kinesis.model.CreateStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.PutRecordRequest; import software.amazon.awssdk.services.kinesis.model.StreamStatus; - -import java.net.URI; -import java.time.Duration; -import java.util.Collection; +import software.amazon.kinesis.retrieval.kpl.Messages; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; -import java.util.concurrent.Callable; -import java.util.concurrent.atomic.AtomicLong; -import java.util.stream.IntStream; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.MINUTES; -import static java.util.stream.Collectors.groupingBy; -import static 
java.util.stream.Collectors.mapping; -import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toSet; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_PARSE_FAILURE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_SUCCESS; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.JsonRecordAssert.assertFlowFileRecordPayloads; -import static org.junit.jupiter.api.Assertions.assertAll; +import java.util.stream.Collectors; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Timeout.ThreadMode.SEPARATE_THREAD; -/** - * Tests run in parallel to optimize execution time as Kinesis consumer coordination takes a lot. 
- */ -@Execution(ExecutionMode.CONCURRENT) -@Timeout(value = 5, unit = MINUTES, threadMode = SEPARATE_THREAD) +@Testcontainers +@EnabledIfDockerAvailable class ConsumeKinesisIT { - private static final Logger logger = LoggerFactory.getLogger(ConsumeKinesisIT.class); - private static final DockerImageName LOCALSTACK_IMAGE = DockerImageName.parse("localstack/localstack:4.12.0"); + @Container + private static final LocalStackContainer LOCALSTACK = new LocalStackContainer( + DockerImageName.parse("localstack/localstack:4")); + + private static final String AVRO_SCHEMA_A = """ + { + "type": "record", + "name": "A", + "fields": [ + {"name": "id", "type": "int"}, + {"name": "name", "type": "string"} + ] + }"""; + + private static final String AVRO_SCHEMA_B = """ + { + "type": "record", + "name": "B", + "fields": [ + {"name": "code", "type": "string"}, + {"name": "value", "type": "double"} + ] + }"""; - private static final LocalStackContainer localstack = new LocalStackContainer(LOCALSTACK_IMAGE).withServices("kinesis", "dynamodb", "cloudwatch"); + private TestRunner runner; + private KinesisClient kinesisClient; + private int credentialServiceCounter = 0; - private static KinesisClient kinesisClient; - private static DynamoDbClient dynamoDbClient; + @BeforeEach + void setUp() throws Exception { + kinesisClient = KinesisClient.builder() + .endpointOverride(LOCALSTACK.getEndpoint()) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create(LOCALSTACK.getAccessKey(), LOCALSTACK.getSecretKey()))) + .region(Region.of(LOCALSTACK.getRegion())) + .build(); - private String streamName; - private String applicationName; - private TestRunner runner; - private TestKinesisStreamClient streamClient; + runner = TestRunners.newTestRunner(FastTimingConsumeKinesis.class); - @BeforeAll - static void oneTimeSetup() { - localstack.start(); + final JsonTreeReader reader = new JsonTreeReader(); + runner.addControllerService("json-reader", reader); + 
runner.enableControllerService(reader); - final AwsCredentialsProvider credentialsProvider = StaticCredentialsProvider.create( - AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey()) - ); + final JsonRecordSetWriter normalWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", normalWriter); + runner.enableControllerService(normalWriter); - kinesisClient = KinesisClient.builder() - .endpointOverride(localstack.getEndpoint()) - .credentialsProvider(credentialsProvider) - .region(Region.of(localstack.getRegion())) - .build(); + final FailingRecordSetWriterFactory failingWriter = new FailingRecordSetWriterFactory(); + runner.addControllerService("failing-writer", failingWriter); + runner.enableControllerService(failingWriter); - dynamoDbClient = DynamoDbClient.builder() - .endpointOverride(localstack.getEndpoint()) - .credentialsProvider(credentialsProvider) - .region(Region.of(localstack.getRegion())) - .build(); + addCredentialService(runner, "creds"); + runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "test-app-" + System.currentTimeMillis()); + runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "creds"); + runner.setProperty(RegionUtil.REGION, LOCALSTACK.getRegion()); + runner.setProperty(ConsumeKinesis.ENDPOINT_OVERRIDE, LOCALSTACK.getEndpoint().toString()); + runner.setProperty(ConsumeKinesis.MAX_BATCH_DURATION, "200 ms"); } - @AfterAll - static void tearDown() { + @AfterEach + void tearDown() { if (kinesisClient != null) { kinesisClient.close(); } - if (dynamoDbClient != null) { - dynamoDbClient.close(); - } - localstack.stop(); } - @BeforeEach - void setUp() throws InitializationException { - final UUID testId = UUID.randomUUID(); - streamName = "%s-kinesis-stream-%s".formatted(getClass().getSimpleName(), testId); - streamClient = new TestKinesisStreamClient(kinesisClient, streamName); - applicationName = "%s-test-kinesis-app-%s".formatted(getClass().getSimpleName(), testId); - runner = 
createTestRunner(streamName, applicationName); - } + @Test + void testFlowFilePerRecordStrategy() throws Exception { + final String streamName = "per-record-test"; + final int recordCount = 5; - @AfterEach - void tearDownEach() { - runner.stop(); + createStream(streamName); + publishRecords(streamName, recordCount); - if (streamClient != null) { - try { - streamClient.deleteStream(); - } catch (final Exception e) { - logger.warn("Failed to delete stream {}: {}", streamName, e.getMessage()); - } + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(recordCount, flowFiles.size(), "Expected one FlowFile per Kinesis record"); + + for (final MockFlowFile ff : flowFiles) { + assertEquals("1", ff.getAttribute("record.count")); + assertEquals(streamName, ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); + assertEquals(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE), ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_PARTITION_KEY)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SUBSEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SUBSEQUENCE)); + + final String content = ff.getContent(); + assertTrue(content.startsWith("{"), "Expected raw JSON content: " + content); } - // Removing tables generated by KCL. 
- deleteTable(applicationName); - deleteTable(applicationName + "-CoordinatorState"); - deleteTable(applicationName + "-WorkerMetricStats"); + final Set emittedShardIds = flowFiles.stream() + .map(ff -> ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)) + .collect(Collectors.toSet()); + final List receiveEvents = runner.getProvenanceEvents().stream() + .filter(event -> "RECEIVE".equals(event.getEventType().name())) + .toList(); + assertEquals(recordCount, receiveEvents.size(), "Expected one RECEIVE event per emitted FlowFile"); + for (final ProvenanceEventRecord receiveEvent : receiveEvents) { + final String transitUri = receiveEvent.getTransitUri(); + assertNotNull(transitUri, "RECEIVE event should include a transit URI"); + assertTrue(emittedShardIds.stream().anyMatch(shardId -> transitUri.endsWith("/" + shardId)), + "RECEIVE transit URI should include one of the emitted shard IDs: " + transitUri); + } + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } - private void deleteTable(final String tableName) { - try { - dynamoDbClient.deleteTable(req -> req.tableName(tableName)); - } catch (final Exception e) { - logger.warn("Failed to delete DynamoDB table {}: {}", tableName, e.getMessage()); + @Test + void testRecordOrientedStrategy() throws Exception { + final String streamName = "record-oriented-test"; + final int recordCount = 5; + + createStream(streamName); + publishRecords(streamName, recordCount); + + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one 
FlowFile"); + + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + assertNotNull(ff.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE)); } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeSingleMessageFromSingleShard() { - streamClient.createStream(1); - - final String testMessage = "Hello, Kinesis!"; - streamClient.putRecord("test-partition-key", testMessage); + void testRecordOrientedStrategyWithInjectedMetadata() throws Exception { + final String streamName = "record-oriented-metadata-test"; + final int recordCount = 5; - runProcessorWithInitAndWaitForFiles(runner, 1); + createStream(streamName); + publishRecords(streamName, recordCount); - runner.assertTransferCount(REL_SUCCESS, 1); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final MockFlowFile flowFile = flowFiles.getFirst(); - - flowFile.assertContentEquals(testMessage); - flowFile.assertAttributeEquals("aws.kinesis.partition.key", "test-partition-key"); - assertNotNull(flowFile.getAttribute("aws.kinesis.first.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.last.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.shard.id")); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, "INJECT_METADATA"); + 
runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); + + long totalRecords = 0; + for (final MockFlowFile flowFile : flowFiles) { + totalRecords += Long.parseLong(flowFile.getAttribute("record.count")); + + final String content = flowFile.getContent(); + assertTrue(content.contains("\"kinesisMetadata\""), "Expected injected kinesisMetadata object"); + assertTrue(content.contains("\"stream\":\"" + streamName + "\""), "Expected stream in injected metadata"); + assertTrue(content.contains("\"shardId\":\""), "Expected shardId in injected metadata"); + assertTrue(content.contains("\"sequenceNumber\":\""), "Expected sequenceNumber in injected metadata"); + assertTrue(content.contains("\"subSequenceNumber\":0"), "Expected default subSequenceNumber in injected metadata"); + assertTrue(content.contains("\"partitionKey\":\""), "Expected partitionKey in injected metadata"); + } - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFile); + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - // Creates an enhanced fan-out consumer by default. 
- assertEquals( - List.of(applicationName), - streamClient.getEnhancedFanOutConsumerNames(), - "Expected a single enhanced fan-out consumer with an application name"); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeSingleMessageFromSingleShard_withoutEnhancedFanOut() { - runner.setProperty(ConsumeKinesis.CONSUMER_TYPE, ConsumeKinesis.ConsumerType.SHARED_THROUGHPUT); + void testDemarcatorStrategy() throws Exception { + final String streamName = "demarcator-test"; + final int recordCount = 3; + + createStream(streamName); + publishRecords(streamName, recordCount); + + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "LINE_DELIMITED"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); - streamClient.createStream(1); + final StringBuilder allContent = new StringBuilder(); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + allContent.append(ff.getContent()); + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - final String testMessage = "Hello, Kinesis!"; - streamClient.putRecord("test-partition-key", testMessage); + final String[] lines = allContent.toString().split("\n"); + assertEquals(recordCount, lines.length, "Expected one line per Kinesis record"); + for (final String line : lines) { + assertTrue(line.startsWith("{"), "Expected JSON content: " + line); + } - runProcessorWithInitAndWaitForFiles(runner, 1); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, 
counter.longValue()); + } - runner.assertTransferCount(REL_SUCCESS, 1); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final MockFlowFile flowFile = flowFiles.getFirst(); + @Test + void testDemarcatorStrategyWithCustomDelimiter() throws Exception { + final String streamName = "custom-delim-test"; + final int recordCount = 3; + final String delimiter = "|||"; - flowFile.assertContentEquals(testMessage); - flowFile.assertAttributeEquals("aws.kinesis.partition.key", "test-partition-key"); - assertNotNull(flowFile.getAttribute("aws.kinesis.first.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.last.sequence.number")); - assertNotNull(flowFile.getAttribute("aws.kinesis.shard.id")); + createStream(streamName); + publishRecords(streamName, recordCount); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFile); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "DEMARCATOR"); + runner.setProperty(ConsumeKinesis.MESSAGE_DEMARCATOR, delimiter); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); + + final StringBuilder allContent = new StringBuilder(); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + allContent.append(ff.getContent()); + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "Total record count across all FlowFiles"); - assertTrue( - streamClient.getEnhancedFanOutConsumerNames().isEmpty(), - "No enhanced fan-out consumers should be created for Shared Throughput consumer type"); + final String[] parts = allContent.toString().split("\\|\\|\\|"); + assertEquals(recordCount, parts.length, "Expected records separated by custom delimiter"); + for (final String part : parts) { + assertTrue(part.startsWith("{"), "Expected 
JSON content: " + part); + } + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeManyMessagesFromSingleShardWithOrdering() { - final int messageCount = 10; + void testFailedWriteRollsBackAndRecordsAreReConsumed() throws Exception { + final String streamName = "rollback-test"; + final int recordCount = 5; + + createStream(streamName); + publishRecords(streamName, recordCount); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); - streamClient.createStream(1); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "failing-writer"); + runUntilOutput(runner); - final List messages = IntStream.range(0, messageCount).mapToObj(i -> "Message-" + i).toList(); - streamClient.putRecords("partition-key", messages); + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); - runProcessorWithInitAndWaitForFiles(runner, messageCount); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertTransferCount(REL_SUCCESS, messageCount); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - final List flowFilesContent = flowFiles.stream().map(MockFlowFile::getContent).toList(); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile transferred to success"); - assertEquals(messages, flowFilesContent); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(recordCount, totalRecords, "All records should be re-consumed after rollback"); - 
assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter, "Records Consumed counter should be set"); + assertEquals(recordCount, counter.longValue()); } @Test - void testConsumeMessagesFromMultipleShardsStream() { - final int shardCount = 5; - final int messagesPerPartitionKey = 5; + void testClusterSimulationDistributesShards() throws Exception { + final String streamName = "cluster-sim-test"; + final int shardCount = 4; + final String appName = "cluster-app-" + System.currentTimeMillis(); + + createStream(streamName, shardCount); - streamClient.createStream(shardCount); + final TestRunner runner1 = createConfiguredRunner(streamName, appName); + final TestRunner runner2 = createConfiguredRunner(streamName, appName); - // Every shard has message with the same payload. - final List shardMessages = IntStream.range(0, messagesPerPartitionKey).mapToObj(String::valueOf).toList(); + runner1.run(1, false, true); + runner2.run(1, false, true); - IntStream.range(0, shardCount).forEach(shard -> streamClient.putRecords(String.valueOf(shard), shardMessages)); + int recordId = 0; + final long deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); - // Run processor and wait for all records - final int totalMessages = shardCount * messagesPerPartitionKey; - runProcessorWithInitAndWaitForFiles(runner, totalMessages); + runner1.run(1, false, false); + runner2.run(1, false, false); - // Verify results - runner.assertTransferCount(REL_SUCCESS, totalMessages); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); + if (hasFlowFiles(runner1) && hasFlowFiles(runner2)) { + break; + } + } - final Map> partitionKey2Messages = flowFiles.stream() - .collect(groupingBy( - f -> f.getAttribute("aws.kinesis.partition.key"), - mapping(MockFlowFile::getContent, toList()) 
- )); + runner1.run(1, true, false); + runner2.run(1, true, false); - assertEquals(shardCount, partitionKey2Messages.size()); + final List flowFiles1 = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List flowFiles2 = runner2.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); - assertAll( - partitionKey2Messages.values().stream() - .map(actual -> () -> assertEquals(shardMessages, actual)) - ); + assertFalse(flowFiles1.isEmpty(), "Runner 1 should have received data"); + assertFalse(flowFiles2.isEmpty(), "Runner 2 should have received data"); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles); + final Set uniqueRecords = new HashSet<>(); + for (final MockFlowFile ff : flowFiles1) { + uniqueRecords.add(ff.getContent()); + } + for (final MockFlowFile ff : flowFiles2) { + uniqueRecords.add(ff.getContent()); + } + assertEquals(flowFiles1.size() + flowFiles2.size(), uniqueRecords.size(), + "No duplicate records should be consumed"); } @Test - @Disabled("Does not work with LocalStack: https://github.com/localstack/localstack/issues/12833. Enable only when using real Kinesis.") - void testResharding_inParallelWithConsumption() { - // Using partition keys with uniformally distributed hashes to ensure the data is distributed across split shards. - final List partitionKeys = List.of( - "pk-0-14", // 035517ff4ca68849589f43842c07362f - "pk-1-14", // 5f045ae51eea9bd124d76041a6a27073 - "pk-2-2", // 85fb9a2b01b009904eb8a6fa13a21d6c - "pk-3-2" // dbf24a6e26910143c60188e2fcb53b4f - ); - - // Data to be produced at each stage - final int partitionRecordsPerStage = 3; - final int totalStages = 5; // initial + 4 resharding stages with data - final int totalRecordsPerPartition = partitionRecordsPerStage * totalStages; - final int totalRecords = partitionKeys.size() * totalRecordsPerPartition; - - // Start resharding and data production operations in background thread. 
- final Thread reshardingThread = new Thread(() -> { - int messageSeq = 0; // For each partition key the message content are sequential numbers. - - streamClient.createStream(1); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(2); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(4); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(3); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - messageSeq += partitionRecordsPerStage; - - streamClient.reshardStream(2); - putRecords(partitionKeys, partitionRecordsPerStage, messageSeq); - }); - - reshardingThread.start(); - - runProcessorWithInitAndWaitForFiles(runner, totalRecords); - - runner.assertTransferCount(REL_SUCCESS, totalRecords); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - - final Map> partitionKeyToMessages = flowFiles.stream() - .collect(groupingBy( - f -> f.getAttribute("aws.kinesis.partition.key"), - mapping(MockFlowFile::getContent, toList()) - )); - - final List expectedPartitionMessages = IntStream.range(0, totalRecordsPerPartition).mapToObj(Integer::toString).toList(); - assertAll( - partitionKeyToMessages.entrySet().stream() - .map(actual -> () -> assertEquals( - expectedPartitionMessages, - actual.getValue(), - "Partition messages do not match expected for partition key: " + actual.getKey())) - ); - } - - private void putRecords(final Collection partitionKeys, final int count, final int startIndex) { - IntStream.range(startIndex, startIndex + count).forEach(i -> { - final String message = Integer.toString(i); - partitionKeys.forEach(partitionKey -> streamClient.putRecord(partitionKey, message)); - }); + void testClusterScaleDownAndScaleUpRebalancesShards() throws Exception { + final 
String streamName = "cluster-rebalance-test"; + final int shardCount = 6; + final String appName = "cluster-rebalance-app-" + System.currentTimeMillis(); + + createStream(streamName, shardCount); + + final TestRunner runner1 = createConfiguredRunner(streamName, appName); + final TestRunner runner2 = createConfiguredRunner(streamName, appName); + + runner1.run(1, false, true); + runner2.run(1, false, true); + + int recordId = 0; + long deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + runner2.run(1, false, false); + if (hasFlowFiles(runner1) && hasFlowFiles(runner2)) { + break; + } + } + assertFalse(runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 1 should receive data in stage one"); + assertFalse(runner2.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 2 should receive data in stage one"); + + final int runner1Checkpoint = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).size(); + runner2.run(1, true, false); + + deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + + final Set shards = new HashSet<>(); + for (final MockFlowFile ff : getNewFlowFiles(runner1, runner1Checkpoint)) { + shards.add(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + } + if (shards.size() == shardCount) { + break; + } + } + + final Set stageTwoShards = new HashSet<>(); + for (final MockFlowFile ff : getNewFlowFiles(runner1, runner1Checkpoint)) { + stageTwoShards.add(ff.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + } + assertEquals(shardCount, stageTwoShards.size(), "Single active runner should consume from all shards"); + + final TestRunner runner3 = createConfiguredRunner(streamName, appName); + final TestRunner runner4 = 
createConfiguredRunner(streamName, appName); + final int runner1StageThreeCheckpoint = runner1.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).size(); + + runner3.run(1, false, true); + runner4.run(1, false, true); + + deadline = System.currentTimeMillis() + 30_000; + while (System.currentTimeMillis() < deadline) { + publishRecord(streamName, recordId++); + Thread.sleep(10); + runner1.run(1, false, false); + runner3.run(1, false, false); + runner4.run(1, false, false); + + if (!getNewFlowFiles(runner1, runner1StageThreeCheckpoint).isEmpty() + && hasFlowFiles(runner3) && hasFlowFiles(runner4)) { + break; + } + } + + runner1.run(1, true, false); + runner3.run(1, true, false); + runner4.run(1, true, false); + + assertFalse(getNewFlowFiles(runner1, runner1StageThreeCheckpoint).isEmpty(), + "Runner 1 should receive data in stage three"); + assertFalse(runner3.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 3 should receive data in stage three"); + assertFalse(runner4.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(), + "Runner 4 should receive data in stage three"); } @Test - void testSessionRollback() throws InterruptedException { - streamClient.createStream(1); + void testKplAggregatedRecordsFlowFilePerRecord() throws Exception { + final String streamName = "kpl-per-record-test"; + createStream(streamName); - // Initialize the processor. 
- runner.run(1, false, true); + publishAggregatedRecord(streamName, "agg-pk-1", + List.of("pk-a", "pk-b"), + List.of("{\"id\":1,\"name\":\"Alice\"}", "{\"id\":2,\"name\":\"Bob\"}", "{\"id\":3,\"name\":\"Charlie\"}"), + List.of(0, 1, 0)); - final ConsumeKinesis processor = (ConsumeKinesis) runner.getProcessor(); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - final String firstMessage = "Initial-Rollback-Message"; - streamClient.putRecord("key", firstMessage); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(3, flowFiles.size(), "Each sub-record in the aggregate should produce its own FlowFile"); - // First attempt with a failing session - should rollback. - while (true) { - final MockProcessSession failingSession = createFailingSession(processor); - try { - processor.onTrigger(runner.getProcessContext(), failingSession); - } catch (final FlowFileHandlingException ignored) { - failingSession.assertAllFlowFilesTransferred(REL_SUCCESS, 0); - break; // Expected rollback occurred - } - Thread.sleep(1000); + final Set contents = new HashSet<>(); + for (final MockFlowFile ff : flowFiles) { + assertEquals("1", ff.getAttribute("record.count")); + assertEquals(streamName, ff.getAttribute(ConsumeKinesis.ATTR_STREAM_NAME)); + contents.add(ff.getContent()); } - // Write another message. 
- final String secondMessage = "Another-Test-Message"; - streamClient.putRecord("key", secondMessage); + assertTrue(contents.contains("{\"id\":1,\"name\":\"Alice\"}")); + assertTrue(contents.contains("{\"id\":2,\"name\":\"Bob\"}")); + assertTrue(contents.contains("{\"id\":3,\"name\":\"Charlie\"}")); + + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(3, counter.longValue()); + } + + @Test + void testKplAggregatedRecordsRecordOrientedWithMetadata() throws Exception { + final String streamName = "kpl-metadata-test"; + createStream(streamName); - runProcessorAndWaitForFiles(runner, 2); + publishAggregatedRecord(streamName, "agg-pk-1", + List.of("pk-inner"), + List.of("{\"id\":10,\"name\":\"Alpha\"}", "{\"id\":20,\"name\":\"Beta\"}"), + List.of(0, 0)); - // Verify the messages are written in the correct order. - runner.assertTransferCount(REL_SUCCESS, 2); - final List flowFiles = runner.getFlowFilesForRelationship(REL_SUCCESS); - flowFiles.getFirst().assertContentEquals(firstMessage); - flowFiles.getLast().assertContentEquals(secondMessage); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, "INJECT_METADATA"); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty()); + + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + final String content = ff.getContent(); + assertTrue(content.contains("\"kinesisMetadata\""), "Expected injected metadata"); + assertTrue(content.contains("\"stream\":\"" + streamName + "\"")); + assertTrue(content.contains("\"partitionKey\":\"pk-inner\""), 
"Expected inner partition key, not the outer one"); + } + assertEquals(2, totalRecords); - assertReceiveProvenanceEvents(runner.getProvenanceEvents(), flowFiles.getFirst(), flowFiles.getLast()); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(2, counter.longValue()); } @Test - void testRecordProcessingWithSchemaChangesAndInvalidRecords() throws InitializationException { - streamClient.createStream(1); + void testKplAggregatedMixedWithPlainRecords() throws Exception { + final String streamName = "kpl-mixed-test"; + createStream(streamName); - final TestRunner recordRunner = createRecordTestRunner(streamName, applicationName); + publishRecords(streamName, 2); - final List testRecords = List.of( - "{\"name\":\"John\",\"age\":30}", // Schema A - "{\"name\":\"Jane\",\"age\":25}", // Schema A - "{invalid json}", - "{\"id\":\"123\",\"value\":\"test\"}" // Schema B - ); + publishAggregatedRecord(streamName, "agg-pk", + List.of("pk-agg"), + List.of("{\"id\":100,\"name\":\"Aggregated-1\"}", "{\"id\":101,\"name\":\"Aggregated-2\"}", "{\"id\":102,\"name\":\"Aggregated-3\"}"), + List.of(0, 0, 0)); - testRecords.forEach(record -> streamClient.putRecord("key", record)); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - runProcessorWithInitAndWaitForFiles(recordRunner, 3); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(5, flowFiles.size(), "2 plain + 3 deaggregated sub-records"); - // Verify successful records. 
- recordRunner.assertTransferCount(REL_SUCCESS, 2); - final List successFlowFiles = recordRunner.getFlowFilesForRelationship(REL_SUCCESS); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(5, counter.longValue()); + } - final MockFlowFile firstFlowFile = successFlowFiles.getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecordPayloads(firstFlowFile, testRecords.getFirst(), testRecords.get(1)); + @Test + void testKplMultipleAggregatedRecords() throws Exception { + final String streamName = "kpl-multi-agg-test"; + createStream(streamName); + + for (int batch = 0; batch < 3; batch++) { + final List payloads = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + final int id = batch * 5 + i; + payloads.add("{\"id\":" + id + ",\"name\":\"rec-" + id + "\"}"); + } + publishAggregatedRecord(streamName, "agg-pk-" + batch, + List.of("pk-" + batch), payloads, payloads.stream().map(p -> 0).toList()); + } - final MockFlowFile secondFlowFile = successFlowFiles.get(1); - assertEquals("1", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecordPayloads(secondFlowFile, testRecords.getLast()); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runUntilOutput(runner); - // Verify failure record. 
- recordRunner.assertTransferCount(REL_PARSE_FAILURE, 1); - final List parseFailureFlowFiles = recordRunner.getFlowFilesForRelationship(REL_PARSE_FAILURE); - final MockFlowFile parseFailureFlowFile = parseFailureFlowFiles.getFirst(); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(15, flowFiles.size()); - parseFailureFlowFile.assertContentEquals(testRecords.get(2)); - assertNotNull(parseFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); + final Set contents = new HashSet<>(); + for (final MockFlowFile ff : flowFiles) { + contents.add(ff.getContent()); + } + for (int i = 0; i < 15; i++) { + assertTrue(contents.contains("{\"id\":" + i + ",\"name\":\"rec-" + i + "\"}"), + "Missing deaggregated record with id=" + i); + } - // Verify provenance events. - assertReceiveProvenanceEvents(recordRunner.getProvenanceEvents(), firstFlowFile, secondFlowFile, parseFailureFlowFile); + final Long counter = runner.getCounterValue("Records Consumed"); + assertNotNull(counter); + assertEquals(15, counter.longValue()); } @Test - void testRecordProcessingWithDemarcator() throws InitializationException { - streamClient.createStream(1); + void testAvroSameSchemaProducesSingleFlowFile() throws Exception { + final String streamName = "avro-same-schema-test"; + createStream(streamName); - final TestRunner demarcatorTestRunner = createDemarcatorTestRunner(streamName, applicationName, System.lineSeparator()); + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); - final List testRecords = List.of( - "{\"name\":\"John\",\"age\":30}", // Schema A - "{\"name\":\"Jane\",\"age\":25}", // Schema A - "{invalid json}", - "{\"id\":\"123\",\"value\":\"test\"}" // Schema B - ); + for (int i = 0; i < 4; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "record-" + i); + publishAvroRecord(streamName, "pk-" + i, schemaA, rec); + } - testRecords.forEach(record -> 
streamClient.putRecord("key", record)); + configureAvroRecordOriented(streamName); + runUntilOutput(runner); - runProcessorWithInitAndWaitForFiles(demarcatorTestRunner, 1); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertFalse(flowFiles.isEmpty(), "Expected at least one FlowFile"); - // All records from the same shard are put as is into the same FlowFile. - demarcatorTestRunner.assertTransferCount(REL_SUCCESS, 1); - final List successFlowFiles = demarcatorTestRunner.getFlowFilesForRelationship(REL_SUCCESS); + long totalRecords = 0; + for (final MockFlowFile ff : flowFiles) { + totalRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(4, totalRecords, "All 4 same-schema Avro records should be consumed"); + } - final MockFlowFile flowFile = successFlowFiles.getFirst(); - assertEquals("4", flowFile.getAttribute(RECORD_COUNT)); - flowFile.assertContentEquals(String.join(System.lineSeparator(), testRecords)); + @Test + void testAvroDifferentSchemasSplitFlowFiles() throws Exception { + final String streamName = "avro-diff-schema-test"; + createStream(streamName); + + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); + + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "a-" + i); + publishAvroRecord(streamName, "pk-a-" + i, schemaA, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaB); + rec.put("code", "code-" + i); + rec.put("value", i * 1.5); + publishAvroRecord(streamName, "pk-b-" + i, schemaB, rec); + } - // Verify provenance events. 
- assertReceiveProvenanceEvents(demarcatorTestRunner.getProvenanceEvents(), flowFile); - } + configureAvroRecordOriented(streamName); + runUntilOutput(runner); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(2, flowFiles.size(), "A,A,B,B should produce 2 FlowFiles"); - private static void assertReceiveProvenanceEvents(final List actualEvents, final FlowFile... expectedFlowFiles) { - assertReceiveProvenanceEvents(actualEvents, List.of(expectedFlowFiles)); + assertEquals("2", flowFiles.get(0).getAttribute("record.count")); + assertEquals("2", flowFiles.get(1).getAttribute("record.count")); } - private static void assertReceiveProvenanceEvents(final List actualEvents, final Collection expectedFlowFiles) { - assertEquals(expectedFlowFiles.size(), actualEvents.size(), "Each produced FlowFile must have a provenance event"); + @Test + void testAvroInterleavedSchemasSplitFlowFiles() throws Exception { + final String streamName = "avro-interleaved-test"; + createStream(streamName); + + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); + + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", i); + rec.put("name", "a1-" + i); + publishAvroRecord(streamName, "pk-a1-" + i, schemaA, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaB); + rec.put("code", "code-" + i); + rec.put("value", i * 2.0); + publishAvroRecord(streamName, "pk-b-" + i, schemaB, rec); + } + for (int i = 0; i < 2; i++) { + final GenericRecord rec = new GenericData.Record(schemaA); + rec.put("id", 100 + i); + rec.put("name", "a2-" + i); + publishAvroRecord(streamName, "pk-a2-" + i, schemaA, rec); + } - assertAll( - actualEvents.stream().map(event -> () -> - assertEquals(ProvenanceEventType.RECEIVE, event.getEventType(), "Unexpected Provenance Event Type")) - ); + 
configureAvroRecordOriented(streamName); + runUntilOutput(runner); - final Set eventFlowFileUuids = actualEvents.stream() - .map(ProvenanceEventRecord::getFlowFileUuid) - .collect(toSet()); + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + assertEquals(3, flowFiles.size(), "A,A,B,B,A,A should produce 3 FlowFiles (no demux)"); - assertAll( - expectedFlowFiles.stream() - .map(flowFile -> flowFile.getAttribute(CoreAttributes.UUID.key())) - .map(uuid -> () -> - assertTrue(eventFlowFileUuids.contains(uuid), "Expected Provenance Event for FlowFile UUID: %s was not present".formatted(uuid))) - ); + assertEquals("2", flowFiles.get(0).getAttribute("record.count")); + assertEquals("2", flowFiles.get(1).getAttribute("record.count")); + assertEquals("2", flowFiles.get(2).getAttribute("record.count")); } - private TestRunner createTestRunner(final String streamName, final String applicationName) throws InitializationException { - final TestRunner runner = TestRunners.newTestRunner(TestConsumeKinesis.class); + @Test + void testAvroMixedWithCorruptDataRoutesToParseFailure() throws Exception { + final String streamName = "avro-corrupt-mix-test"; + createStream(streamName); - final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); - runner.addControllerService("credentials", credentialsService); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, localstack.getAccessKey()); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, localstack.getSecretKey()); - runner.enableControllerService(credentialsService); + final Schema schemaA = new Schema.Parser().parse(AVRO_SCHEMA_A); + final Schema schemaB = new Schema.Parser().parse(AVRO_SCHEMA_B); - runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "credentials"); - runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); - 
runner.setProperty(ConsumeKinesis.APPLICATION_NAME, applicationName); - runner.setProperty(RegionUtil.REGION, localstack.getRegion()); - runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, ConsumeKinesis.InitialPosition.TRIM_HORIZON); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.FLOW_FILE); + final GenericRecord recA = new GenericData.Record(schemaA); + recA.put("id", 1); + recA.put("name", "valid-a"); + publishAvroRecord(streamName, "pk-a", schemaA, recA); - runner.setProperty(ConsumeKinesis.METRICS_PUBLISHING, ConsumeKinesis.MetricsPublishing.CLOUDWATCH); + publishCorruptRecord(streamName, "pk-corrupt", "THIS_IS_NOT_AVRO_DATA"); - runner.setProperty(ConsumeKinesis.MAX_BYTES_TO_BUFFER, "10 MB"); + final GenericRecord recB = new GenericData.Record(schemaB); + recB.put("code", "valid-b"); + recB.put("value", 3.14); + publishAvroRecord(streamName, "pk-b", schemaB, recB); - runner.assertValid(); - return runner; - } + configureAvroRecordOriented(streamName); + runUntilOutput(runner); + + final List successFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); - private TestRunner createRecordTestRunner(final String streamName, final String applicationName) throws InitializationException { - final TestRunner runner = createTestRunner(streamName, applicationName); + long totalSuccessRecords = 0; + for (final MockFlowFile ff : successFlowFiles) { + totalSuccessRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(2, totalSuccessRecords, "Both valid Avro records should be in success"); + assertEquals(1, failureFlowFiles.size(), "Corrupt record should be routed to parse failure"); + failureFlowFiles.getFirst().assertContentEquals("THIS_IS_NOT_AVRO_DATA"); + } - final JsonTreeReader jsonReader = new JsonTreeReader(); - runner.addControllerService("json-reader", jsonReader); - 
runner.setProperty(jsonReader, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaInferenceUtil.INFER_SCHEMA.getValue()); - runner.enableControllerService(jsonReader); + @Test + void testParseFailureWithSomeValidJsonRecords() throws Exception { + final String streamName = "json-parse-failure-test"; + createStream(streamName); - final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); - runner.addControllerService("json-writer", jsonWriter); - runner.setProperty(jsonWriter, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaAccessUtils.INHERIT_RECORD_SCHEMA.getValue()); - runner.enableControllerService(jsonWriter); + publishRecord(streamName, 0); + publishCorruptRecord(streamName, "pk-bad-1", "CORRUPT_DATA_1"); + publishRecord(streamName, 1); + publishCorruptRecord(streamName, "pk-bad-2", "CORRUPT_DATA_2"); + publishRecord(streamName, 2); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.RECORD); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertValid(); - return runner; + final List successFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + + long totalSuccessRecords = 0; + for (final MockFlowFile ff : successFlowFiles) { + totalSuccessRecords += Long.parseLong(ff.getAttribute("record.count")); + } + assertEquals(3, totalSuccessRecords, "All valid JSON records should route to success"); + assertEquals(2, failureFlowFiles.size(), "Both corrupt records should route to parse failure"); } - private TestRunner createDemarcatorTestRunner(final String streamName, final String applicationName, final String demarcator) throws InitializationException { - final 
TestRunner runner = createTestRunner(streamName, applicationName); + @Test + void testAllCorruptRecordsRouteToParseFailure() throws Exception { + final String streamName = "all-corrupt-test"; + createStream(streamName); + + publishCorruptRecord(streamName, "pk-1", "NOT_VALID_DATA_1"); + publishCorruptRecord(streamName, "pk-2", "NOT_VALID_DATA_2"); + publishCorruptRecord(streamName, "pk-3", "NOT_VALID_DATA_3"); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, ConsumeKinesis.ProcessingStrategy.DEMARCATOR); - runner.setProperty(ConsumeKinesis.MESSAGE_DEMARCATOR, demarcator); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runUntilOutput(runner); - runner.assertValid(); - return runner; + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + final List failureFlowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + assertEquals(3, failureFlowFiles.size(), "All corrupt records should route to parse failure"); } - private void runProcessorWithInitAndWaitForFiles(final TestRunner runner, final int expectedFlowFileCount) { - runProcessorAndWaitForFiles(runner, expectedFlowFileCount, true); + private void addCredentialService(final TestRunner testRunner, final String serviceId) throws Exception { + final AWSCredentialsProviderControllerService credsSvc = new AWSCredentialsProviderControllerService(); + testRunner.addControllerService(serviceId, credsSvc); + testRunner.setProperty(credsSvc, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, LOCALSTACK.getAccessKey()); + testRunner.setProperty(credsSvc, AWSCredentialsProviderControllerService.SECRET_KEY, LOCALSTACK.getSecretKey()); + testRunner.enableControllerService(credsSvc); } - private void runProcessorAndWaitForFiles(final TestRunner runner, final int 
expectedFlowFileCount) { - runProcessorAndWaitForFiles(runner, expectedFlowFileCount, false); - } + private TestRunner createConfiguredRunner(final String streamName, final String appName) throws Exception { + final TestRunner configuredRunner = TestRunners.newTestRunner(FastTimingConsumeKinesis.class); + + final String credsId = "creds-" + (credentialServiceCounter++); + addCredentialService(configuredRunner, credsId); - private void runProcessorAndWaitForFiles(final TestRunner runner, final int expectedFlowFileCount, final boolean withInit) { - logger.info("Running processor and waiting for {} files", expectedFlowFileCount); + configuredRunner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + configuredRunner.setProperty(ConsumeKinesis.APPLICATION_NAME, appName); + configuredRunner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, credsId); + configuredRunner.setProperty(RegionUtil.REGION, LOCALSTACK.getRegion()); + configuredRunner.setProperty(ConsumeKinesis.ENDPOINT_OVERRIDE, LOCALSTACK.getEndpoint().toString()); + configuredRunner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + configuredRunner.setProperty(ConsumeKinesis.MAX_BATCH_DURATION, "200 ms"); - if (withInit) { - runner.run(1, false, true); + return configuredRunner; + } + + /** + * Runs the processor with retries until at least one FlowFile appears on any output relationship + * (success or parse failure), or a 30-second deadline is reached. This guards against + * timing-sensitive tests on slow systems where LocalStack may not propagate records within a + * single 200ms batch window. 
+ */ + private void runUntilOutput(final TestRunner testRunner) throws InterruptedException { + final long deadline = System.currentTimeMillis() + 30_000; + testRunner.run(1, false, true); + while (!hasOutput(testRunner) && System.currentTimeMillis() < deadline) { + Thread.sleep(200); + testRunner.run(1, false, false); } + testRunner.run(1, true, false); + } + + private boolean hasOutput(final TestRunner testRunner) { + return !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty() + || !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE).isEmpty(); + } - final Set relationships = runner.getProcessor().getRelationships(); + private boolean hasFlowFiles(final TestRunner testRunner) { + return !testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).isEmpty(); + } - while (true) { - runner.run(1, false, false); + private List getNewFlowFiles(final TestRunner testRunner, final int startIndex) { + final List allFlowFiles = testRunner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + return new ArrayList<>(allFlowFiles.subList(startIndex, allFlowFiles.size())); + } - final int currentCount = relationships.stream() - .map(runner::getFlowFilesForRelationship) - .mapToInt(Collection::size) - .sum(); - logger.info("Current files count: {}, expected: {}", currentCount, expectedFlowFileCount); + private void createStream(final String streamName) { + createStream(streamName, 1); + } - if (currentCount >= expectedFlowFileCount) { + private void createStream(final String streamName, final int shardCount) { + final CreateStreamRequest request = CreateStreamRequest.builder() + .streamName(streamName) + .shardCount(shardCount) + .build(); + kinesisClient.createStream(request); + waitForStreamActive(streamName); + } + + private void waitForStreamActive(final String streamName) { + final DescribeStreamRequest request = DescribeStreamRequest.builder().streamName(streamName).build(); + for (int i = 0; i < 300; i++) { + if 
(kinesisClient.describeStream(request).streamDescription().streamStatus() == StreamStatus.ACTIVE) { return; } try { - Thread.sleep(5_000); + Thread.sleep(100); } catch (final InterruptedException e) { Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for files", e); + throw new RuntimeException(e); } } - } - private MockProcessSession createFailingSession(final Processor processor) { - final SharedSessionState sharedState = new SharedSessionState(processor, new AtomicLong()); - - return MockProcessSession.builder(sharedState, processor) - .failCommit() - .build(); + throw new RuntimeException("Stream " + streamName + " did not become ACTIVE within 30 seconds"); } - public static class TestConsumeKinesis extends ConsumeKinesis { - - @Override - URI getKinesisEndpointOverride() { - return localstack.getEndpoint(); + private void publishAggregatedRecord(final String streamName, final String outerPartitionKey, + final List partitionKeys, final List jsonPayloads, final List pkIndices) { + final Messages.AggregatedRecord.Builder builder = Messages.AggregatedRecord.newBuilder(); + for (final String pk : partitionKeys) { + builder.addPartitionKeyTable(pk); } - - @Override - URI getDynamoDbEndpointOverride() { - return localstack.getEndpoint(); + for (int i = 0; i < jsonPayloads.size(); i++) { + builder.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(pkIndices.get(i)) + .setData(ByteString.copyFromUtf8(jsonPayloads.get(i)))); } - @Override - URI getCloudwatchEndpointOverride() { - return localstack.getEndpoint(); + final byte[] protobufBytes = builder.build().toByteArray(); + final byte[] payload; + try { + final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + out.write(ProducerLibraryDeaggregator.KPL_MAGIC); + out.write(protobufBytes); + out.write(md5); + payload = out.toByteArray(); + } catch (final Exception e) { + 
throw new RuntimeException(e); } - } - - /** - * Test client wrapper for Kinesis operations with built-in retry logic and error handling. - */ - private static class TestKinesisStreamClient { - private static final Logger logger = LoggerFactory.getLogger(TestKinesisStreamClient.class); - - private static final int MAX_RETRIES = 10; - private static final long INITIAL_RETRY_DELAY_MILLIS = 1_000; - private static final long MAX_RETRY_DELAY_MILLIS = 60 * 1_000; - private static final Duration STREAM_WAIT_TIMEOUT = Duration.ofMinutes(2); - - private final KinesisClient kinesisClient; - private final String streamName; + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(outerPartitionKey) + .data(SdkBytes.fromByteArray(payload)) + .build()); + } - TestKinesisStreamClient(KinesisClient kinesisClient, String streamName) { - this.kinesisClient = kinesisClient; - this.streamName = streamName; + private void publishRecords(final String streamName, final int count) { + for (int i = 0; i < count; i++) { + publishRecord(streamName, i); } + } - void createStream(final int shardCount) { - logger.info("Creating stream: {} with {} shards", streamName, shardCount); + private void publishRecord(final String streamName, final int recordId) { + final String json = """ + {"id": %d, "name": "record-%d"}""".formatted(recordId, recordId); + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey("key-" + recordId) + .data(SdkBytes.fromUtf8String(json)) + .build()); + } - executeWithRetry( - "createStream", - () -> kinesisClient.createStream(req -> req.streamName(streamName).shardCount(shardCount))); + private void publishCorruptRecord(final String streamName, final String partitionKey, final String corruptData) { + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(partitionKey) + .data(SdkBytes.fromUtf8String(corruptData)) + .build()); + } - waitForStreamActive(); - 
logger.info("Stream {} is now active", streamName); + private static byte[] serializeAvroContainer(final Schema schema, final GenericRecord... records) { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream()) { + final GenericDatumWriter datumWriter = new GenericDatumWriter<>(schema); + try (final DataFileWriter fileWriter = new DataFileWriter<>(datumWriter)) { + fileWriter.create(schema, baos); + for (final GenericRecord record : records) { + fileWriter.append(record); + } + } + return baos.toByteArray(); + } catch (final IOException e) { + throw new RuntimeException(e); } + } - List getEnhancedFanOutConsumerNames() { - final String arn = describeStream().streamARN(); + private void publishAvroRecord(final String streamName, final String partitionKey, final Schema schema, + final GenericRecord record) { + final byte[] avroBytes = serializeAvroContainer(schema, record); + kinesisClient.putRecord(PutRecordRequest.builder() + .streamName(streamName) + .partitionKey(partitionKey) + .data(SdkBytes.fromByteArray(avroBytes)) + .build()); + } - final ListStreamConsumersResponse response = executeWithRetry( - "listStreamConsumers", - () -> kinesisClient.listStreamConsumers(req -> req.streamARN(arn)) - ); + private void configureAvroRecordOriented(final String streamName) throws Exception { + final AvroReader avroReader = new AvroReader(); + runner.addControllerService("avro-reader", avroReader); + runner.enableControllerService(avroReader); - return response.consumers().stream() - .map(Consumer::consumerName) - .toList(); - } + final AvroRecordSetWriter avroWriter = new AvroRecordSetWriter(); + runner.addControllerService("avro-writer", avroWriter); + runner.enableControllerService(avroWriter); - private StreamDescription describeStream() { - final DescribeStreamResponse response = executeWithRetry( - "describeStream", - () -> kinesisClient.describeStream(req -> req.streamName(streamName)) - ); + runner.setProperty(ConsumeKinesis.STREAM_NAME, streamName); + 
runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "avro-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "avro-writer"); + } - return response.streamDescription(); + static class FailingRecordSetWriterFactory extends AbstractControllerService implements RecordSetWriterFactory { + @Override + public RecordSchema getSchema(final Map variables, final RecordSchema readSchema) { + return readSchema; } - void deleteStream() { - logger.info("Deleting stream: {}", streamName); - - executeWithRetry( - "deleteStream", - () -> kinesisClient.deleteStream(req -> req.streamName(streamName).enforceConsumerDeletion(true))); + @Override + public RecordSetWriter createWriter(final ComponentLog logger, final RecordSchema schema, + final OutputStream out, final Map variables) { + return new FailingRecordSetWriter(); } + } - void putRecord(final String partitionKey, final String data) { - final SdkBytes bytes = SdkBytes.fromString(data, UTF_8); - - executeWithRetry( - "putRecord", - () -> kinesisClient.putRecord(req -> req.streamName(streamName).partitionKey(partitionKey).data(bytes))); + private static class FailingRecordSetWriter implements RecordSetWriter { + @Override + public WriteResult write(final Record record) throws IOException { + throw new IOException("Simulated write failure for rollback testing"); } - void putRecords(final String partitionKey, final List data) { - final List records = data.stream() - .map(it -> PutRecordsRequestEntry.builder() - .data(SdkBytes.fromString(it, UTF_8)) - .partitionKey(partitionKey) - .build()) - .toList(); - - executeWithRetry( - "putRecords", - () -> kinesisClient.putRecords(req -> req.streamName(streamName).records(records))); + @Override + public WriteResult write(final RecordSet recordSet) throws IOException { + throw new IOException("Simulated write failure for rollback testing"); } - /** - * Adjusts a number of shards for the stream. 
- * Note: in order to ensure new shards become active, the method waits for 30 seconds. - */ - void reshardStream(final int targetShardCount) { - logger.info("Resharding stream {} to {} shards", streamName, targetShardCount); + @Override + public void beginRecordSet() { + } - executeWithRetry( - "reshardStream", - () -> kinesisClient.updateShardCount(req -> req.streamName(streamName).targetShardCount(targetShardCount).scalingType(ScalingType.UNIFORM_SCALING))); + @Override + public WriteResult finishRecordSet() { + return WriteResult.of(0, Map.of()); + } - waitForStreamActive(); + @Override + public String getMimeType() { + return "application/json"; + } - try { - // After resharding new messages can still be put into the older shards for some time, so we wait a bit. - Thread.sleep(30_000); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } + @Override + public void flush() { + } - logger.info("Stream {} resharding completed", streamName); - } - - private void waitForStreamActive() { - final long timeoutMillis = System.currentTimeMillis() + STREAM_WAIT_TIMEOUT.toMillis(); - - while (System.currentTimeMillis() < timeoutMillis) { - try { - final StreamStatus status = describeStream().streamStatus(); - if (status == StreamStatus.ACTIVE) { - return; - } - - logger.info("Stream {} status: {}, waiting...", streamName, status); - Thread.sleep(1000); - } catch (final InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for stream to be active", e); - } catch (final RuntimeException e) { - logger.warn("Error checking stream status for {}: {}", streamName, e.getMessage()); - try { - Thread.sleep(1000); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted while waiting for stream to be active", ie); - } - } - } + @Override + public void close() { + } 
+ } - throw new IllegalStateException("Stream " + streamName + " did not become active within timeout"); - } - - private T executeWithRetry(final String operation, final Callable op) { - Exception lastException = null; - - for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) { - try { - return op.call(); - } catch (final Exception e) { - lastException = e; - logger.warn("Attempt {} of {} failed for operation {}: {}", - attempt, MAX_RETRIES, operation, e.getMessage()); - - if (attempt < MAX_RETRIES) { - try { - final long delayMillis = INITIAL_RETRY_DELAY_MILLIS * (1 << (attempt - 1)); - Thread.sleep(Math.min(delayMillis, MAX_RETRY_DELAY_MILLIS)); - } catch (final InterruptedException ie) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Thread interrupted during retry delay", ie); - } - } - } - } + public static class FastTimingConsumeKinesis extends ConsumeKinesis { + private static final long TEST_SHARD_CACHE_MILLIS = 500; + private static final long TEST_LEASE_DURATION_MILLIS = 3_000; + private static final long TEST_LEASE_REFRESH_INTERVAL_MILLIS = 1_000; + private static final long TEST_NODE_HEARTBEAT_EXPIRATION_MILLIS = 4_000; - throw new IllegalStateException("Operation " + operation + " failed after " + MAX_RETRIES + " attempts", lastException); + @Override + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return new KinesisShardManager( + kinesisClient, + dynamoDbClient, + logger, + checkpointTableName, + streamName, + TEST_SHARD_CACHE_MILLIS, + TEST_LEASE_DURATION_MILLIS, + TEST_LEASE_REFRESH_INTERVAL_MILLIS, + TEST_NODE_HEARTBEAT_EXPIRATION_MILLIS); } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java 
b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java index 1024896b2c9c..351f161ff6c0 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ConsumeKinesisTest.java @@ -16,85 +16,534 @@ */ package org.apache.nifi.processors.aws.kinesis; +import org.apache.nifi.json.JsonRecordSetWriter; +import org.apache.nifi.json.JsonTreeReader; +import org.apache.nifi.logging.ComponentLog; import org.apache.nifi.processor.Relationship; import org.apache.nifi.processors.aws.credentials.provider.service.AWSCredentialsProviderControllerService; -import org.apache.nifi.processors.aws.region.RegionUtil; -import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.util.MockFlowFile; +import org.apache.nifi.util.PropertyMigrationResult; import org.apache.nifi.util.TestRunner; import org.apache.nifi.util.TestRunners; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Shard; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Set; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.PROCESSING_STRATEGY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.DEMARCATOR; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.FLOW_FILE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.ProcessingStrategy.RECORD; -import static 
org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_PARSE_FAILURE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesis.REL_SUCCESS; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; class ConsumeKinesisTest { - private TestRunner testRunner; + private TestRunner runner; @BeforeEach - void setUp() { - testRunner = createTestRunner(); + void setUp() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + final JsonTreeReader reader = new JsonTreeReader(); + runner.addControllerService("json-reader", reader); + runner.enableControllerService(reader); + + final JsonRecordSetWriter writer = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", writer); + runner.enableControllerService(writer); + } + + private void setCommonProperties() throws Exception { + final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); + runner.addControllerService("creds", credentialsService); + runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, "AK_STUB"); + runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, "SK_STUB"); + runner.enableControllerService(credentialsService); + + runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "test-app"); + runner.setProperty(ConsumeKinesis.STREAM_NAME, "test-stream"); + runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "creds"); } @Test - void getRelationshipsForFlowFileProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, FLOW_FILE); + void testProcessingStrategyValidation() throws Exception { + setCommonProperties(); + + 
runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + runner.assertValid(); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "LINE_DELIMITED"); + runner.assertValid(); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.assertNotValid(); - final Set relationships = testRunner.getProcessor().getRelationships(); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.assertNotValid(); - assertEquals(Set.of(REL_SUCCESS), relationships); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.assertValid(); } @Test - void getRelationshipsForRecordProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, RECORD); + void testAllValidRecordsRoutedToSuccess() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", "{\"name\":\"Bob\"}"), + testRecord("3", "{\"name\":\"Charlie\"}")); - final Set relationships = testRunner.getProcessor().getRelationships(); + triggerWithRecords(records); - assertEquals(Set.of(REL_SUCCESS, REL_PARSE_FAILURE), relationships); + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "3"); } @Test - void getRelationshipsForDemarcatorProcessingStrategy() { - testRunner.setProperty(PROCESSING_STRATEGY, DEMARCATOR); + void testSingleInvalidRecordRoutedToParseFailure() throws Exception { + assertInvalidRecordAtPosition("1", "THIS IS NOT JSON", + testRecord("1", "THIS IS NOT JSON"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "{\"name\":\"Charlie\"}")); + assertInvalidRecordAtPosition("2", "CORRUPT DATA HERE", + testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "CORRUPT DATA HERE"), testRecord("3", "{\"name\":\"Charlie\"}")); + 
assertInvalidRecordAtPosition("3", "NOT VALID JSON!!!", + testRecord("1", "{\"name\":\"Alice\"}"), testRecord("2", "{\"name\":\"Bob\"}"), testRecord("3", "NOT VALID JSON!!!")); + } + + @Test + void testMultipleInvalidRecordsInBatch() throws Exception { + final List records = List.of( + testRecord("1", "BAD FIRST"), + testRecord("2", "{\"name\":\"Bob\"}"), + testRecord("3", "BAD THIRD"), + testRecord("4", "{\"name\":\"Dave\"}"), + testRecord("5", "BAD FIFTH")); + + triggerWithRecords(records); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 3); - final Set relationships = testRunner.getProcessor().getRelationships(); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); - assertEquals(Set.of(REL_SUCCESS), relationships); + final List failures = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE); + final List failureSequences = new ArrayList<>(); + for (final MockFlowFile flowFile : failures) { + failureSequences.add(flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE)); + } + assertTrue(failureSequences.contains("1")); + assertTrue(failureSequences.contains("3")); + assertTrue(failureSequences.contains("5")); } - private static TestRunner createTestRunner() { - final TestRunner runner = TestRunners.newTestRunner(ConsumeKinesis.class); + @Test + void testAllInvalidRecordsRoutedToParseFailure() throws Exception { + final List records = List.of( + testRecord("1", "BAD1"), + testRecord("2", "BAD2"), + testRecord("3", "BAD3")); + + triggerWithRecords(records); - final AWSCredentialsProviderControllerService credentialsService = new AWSCredentialsProviderControllerService(); - try { - runner.addControllerService("credentials", credentialsService); - } catch (final InitializationException e) { - throw new RuntimeException(e); + 
runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 0); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 3); + } + + @Test + void testFlowFilePerRecordDeliversAllRecords() throws Exception { + final List records = List.of( + testRecord("1", "record-one"), + testRecord("2", "record-two"), + testRecord("3", "record-three")); + + triggerWithStrategy(records, "FLOW_FILE", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 3); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + for (final MockFlowFile flowFile : flowFiles) { + flowFile.assertAttributeEquals("record.count", "1"); + flowFile.assertAttributeEquals(ConsumeKinesis.ATTR_STREAM_NAME, "test-stream"); + flowFile.assertAttributeEquals(ConsumeKinesis.ATTR_SHARD_ID, "shardId-000000000001"); + final String firstSequence = flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SEQUENCE); + final String lastSequence = flowFile.getAttribute(ConsumeKinesis.ATTR_LAST_SEQUENCE); + assertEquals(firstSequence, lastSequence); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_PARTITION_KEY)); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_FIRST_SUBSEQUENCE)); + assertNotNull(flowFile.getAttribute(ConsumeKinesis.ATTR_LAST_SUBSEQUENCE)); } - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.ACCESS_KEY_ID, "123"); - runner.setProperty(credentialsService, AWSCredentialsProviderControllerService.SECRET_KEY, "123"); - runner.enableControllerService(credentialsService); - runner.setProperty(ConsumeKinesis.AWS_CREDENTIALS_PROVIDER_SERVICE, "credentials"); - runner.setProperty(ConsumeKinesis.STREAM_NAME, "stream"); - runner.setProperty(ConsumeKinesis.APPLICATION_NAME, "application"); - runner.setProperty(RegionUtil.REGION, "us-west-2"); - runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, ConsumeKinesis.InitialPosition.TRIM_HORIZON); - runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, 
ConsumeKinesis.ProcessingStrategy.FLOW_FILE); + flowFiles.get(0).assertContentEquals("record-one"); + flowFiles.get(1).assertContentEquals("record-two"); + flowFiles.get(2).assertContentEquals("record-three"); + } + + @Test + void testDemarcatorDeliversAllRecords() throws Exception { + final List records = List.of( + testRecord("1", "line-one"), + testRecord("2", "line-two"), + testRecord("3", "line-three")); + + triggerWithStrategy(records, "LINE_DELIMITED", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertContentEquals("line-one\nline-two\nline-three"); + success.assertAttributeEquals("record.count", "3"); + } - runner.setProperty(ConsumeKinesis.METRICS_PUBLISHING, ConsumeKinesis.MetricsPublishing.CLOUDWATCH); + @Test + void testDelimitedUsesLatestArrivalTimestamp() throws Exception { + final Instant firstArrival = Instant.parse("2025-01-15T00:00:00Z"); + final Instant secondArrival = Instant.parse("2025-01-15T00:00:05Z"); + final Instant thirdArrival = Instant.parse("2025-01-15T00:00:03Z"); + final List records = List.of( + testRecord("1", "line-one", firstArrival), + testRecord("2", "line-two", secondArrival), + testRecord("3", "line-three", thirdArrival)); + + triggerWithStrategy(records, "LINE_DELIMITED", "shardId-000000000001"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals(ConsumeKinesis.ATTR_ARRIVAL_TIMESTAMP, String.valueOf(secondArrival.toEpochMilli())); + } - runner.setProperty(ConsumeKinesis.MAX_BYTES_TO_BUFFER, "10 MB"); + @Test + void testMultipleShardsNoDataLoss() throws Exception { + final ShardFetchResult shard1Result = new ShardFetchResult("shard-A", + List.of(testRecord("10", "{\"id\":1}"), testRecord("20", "{\"id\":2}")), 0L); + 
final ShardFetchResult shard2Result = new ShardFetchResult("shard-B", + List.of(testRecord("30", "{\"id\":3}"), testRecord("40", "{\"id\":4}")), 0L); + + triggerWithResults(List.of(shard1Result, shard2Result), "RECORD"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 2); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final List flowFiles = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS); + final Set shardsSeen = new LinkedHashSet<>(); + long totalRecords = 0; + for (final MockFlowFile flowFile : flowFiles) { + shardsSeen.add(flowFile.getAttribute(ConsumeKinesis.ATTR_SHARD_ID)); + totalRecords += Long.parseLong(flowFile.getAttribute("record.count")); + } + assertEquals(Set.of("shard-A", "shard-B"), shardsSeen); + assertEquals(4, totalRecords); + } + + @Test + void testRecordMetadataInjectionPreservesRecordCount() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", "{\"name\":\"Bob\"}"), + testRecord("3", "{\"name\":\"Charlie\"}")); + + triggerWithOutputStrategy(records, "INJECT_METADATA"); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "3"); + + final String content = success.getContent(); + assertTrue(content.contains("kinesisMetadata")); + assertTrue(content.contains("\"stream\"")); + assertTrue(content.contains("\"shardId\"")); + assertTrue(content.contains("\"sequenceNumber\"")); + assertTrue(content.contains("\"partitionKey\"")); + } + + @Test + void testUseWrapperOutputStrategy() throws Exception { + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + testRecord("2", "{\"name\":\"Bob\"}")); + + triggerWithOutputStrategy(records, "USE_WRAPPER"); + + 
runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 0); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + + final String content = success.getContent(); + assertTrue(content.contains("kinesisMetadata")); + assertTrue(content.contains("value")); + assertTrue(content.contains("Alice")); + } + + @Test + void testAtTimestampInitialPositionRequiresTimestamp() throws Exception { + setCommonProperties(); + runner.setProperty(ConsumeKinesis.INITIAL_STREAM_POSITION, "AT_TIMESTAMP"); + runner.assertNotValid(); + + runner.setProperty(ConsumeKinesis.STREAM_POSITION_TIMESTAMP, "2025-01-15T00:00:00Z"); + runner.assertValid(); + } + + @Test + void testPropertyMigrationRenamesMaxBytesToBuffer() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + setCommonProperties(); + runner.setProperty("Max Bytes to Buffer", "5 MB"); + + final PropertyMigrationResult result = runner.migrateProperties(); + assertTrue(result.getPropertiesRenamed().containsKey("Max Bytes to Buffer")); + assertEquals("Max Batch Size", result.getPropertiesRenamed().get("Max Bytes to Buffer")); + assertEquals("5 MB", runner.getProcessContext().getProperty(ConsumeKinesis.MAX_BATCH_SIZE).getValue()); + } + + @Test + void testPropertyMigrationRemovesCheckpointInterval() throws Exception { + runner = TestRunners.newTestRunner(ConsumeKinesis.class); + + setCommonProperties(); + runner.setProperty("Checkpoint Interval", "5 min"); + + final PropertyMigrationResult result = runner.migrateProperties(); + assertTrue(result.getPropertiesRemoved().contains("Checkpoint Interval")); + } + + @Test + void testDynamicRelationships() throws Exception { + setCommonProperties(); - return runner; + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "FLOW_FILE"); + assertEquals(Set.of("success"), 
collectRelationshipNames()); + + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + assertEquals(Set.of("success", "parse.failure"), collectRelationshipNames()); + } + + @Test + void testEmptyRecordDoesNotCauseStuckState() throws Exception { + final UserRecord emptyRecord = new UserRecord("shardId-000000000001", "2", 0, "pk-2", new byte[0], Instant.now()); + + final List records = List.of( + testRecord("1", "{\"name\":\"Alice\"}"), + emptyRecord, + testRecord("3", "{\"name\":\"Charlie\"}")); + + triggerWithRecords(records); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + } + + private void triggerWithRecords(final List records) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); + final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + + runner.run(); + } + + private Set collectRelationshipNames() { + final Set names = new 
LinkedHashSet<>(); + for (final Relationship relationship : runner.getProcessor().getRelationships()) { + names.add(relationship.getName()); + } + return names; + } + + private void assertInvalidRecordAtPosition(final String expectedFailureSequence, final String expectedFailureContent, + final UserRecord... records) throws Exception { + triggerWithRecords(List.of(records)); + + runner.assertTransferCount(ConsumeKinesis.REL_SUCCESS, 1); + runner.assertTransferCount(ConsumeKinesis.REL_PARSE_FAILURE, 1); + + final MockFlowFile success = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_SUCCESS).getFirst(); + success.assertAttributeEquals("record.count", "2"); + + final MockFlowFile failure = runner.getFlowFilesForRelationship(ConsumeKinesis.REL_PARSE_FAILURE).getFirst(); + failure.assertContentEquals(expectedFailureContent); + failure.assertAttributeEquals(ConsumeKinesis.ATTR_FIRST_SEQUENCE, expectedFailureSequence); + assertNotNull(failure.getAttribute(ConsumeKinesis.ATTR_RECORD_ERROR_MESSAGE)); + } + + private void triggerWithOutputStrategy(final List records, final String outputStrategy) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager("shardId-000000000001"); + final ShardFetchResult fetchResult = new ShardFetchResult("shardId-000000000001", records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, "RECORD"); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + 
runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + runner.setProperty(ConsumeKinesis.OUTPUT_STRATEGY, outputStrategy); + + runner.run(); + } + + private void triggerWithStrategy(final List records, final String processingStrategy, + final String shardId) throws Exception { + final KinesisShardManager mockShardManager = buildShardManager(shardId); + final ShardFetchResult fetchResult = new ShardFetchResult(shardId, records, 0L); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, fetchResult); + runner = TestRunners.newTestRunner(processor); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, processingStrategy); + + runner.run(); + } + + private void triggerWithResults(final List results, final String processingStrategy) throws Exception { + final Set shardIds = new LinkedHashSet<>(); + for (final ShardFetchResult fetchResult : results) { + shardIds.add(fetchResult.shardId()); + } + + final KinesisShardManager mockShardManager = buildShardManager(shardIds.toArray(new String[0])); + final TestableConsumeKinesis processor = new TestableConsumeKinesis(mockShardManager, results); + runner = TestRunners.newTestRunner(processor); + + final JsonTreeReader jsonReader = new JsonTreeReader(); + runner.addControllerService("json-reader", jsonReader); + runner.enableControllerService(jsonReader); + + final JsonRecordSetWriter jsonWriter = new JsonRecordSetWriter(); + runner.addControllerService("json-writer", jsonWriter); + runner.enableControllerService(jsonWriter); + + setCommonProperties(); + runner.setProperty(ConsumeKinesis.PROCESSING_STRATEGY, processingStrategy); + runner.setProperty(ConsumeKinesis.RECORD_READER, "json-reader"); + runner.setProperty(ConsumeKinesis.RECORD_WRITER, "json-writer"); + + runner.run(); + } + + private static KinesisShardManager buildShardManager(final String... 
shardIds) { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + final List shards = new ArrayList<>(); + for (final String id : shardIds) { + shards.add(Shard.builder().shardId(id).build()); + } + when(mockShardManager.getOwnedShards()).thenReturn(shards); + when(mockShardManager.getCachedShardCount()).thenReturn(shardIds.length); + when(mockShardManager.shouldProcessFetchedResult(anyString())).thenReturn(true); + return mockShardManager; + } + + private static UserRecord testRecord(final String sequenceNumber, final String data) { + return testRecord(sequenceNumber, data, Instant.now()); + } + + private static UserRecord testRecord(final String sequenceNumber, final String data, final Instant arrivalTimestamp) { + return new UserRecord( + "shardId-000000000001", + sequenceNumber, + 0, + "pk-" + sequenceNumber, + data.getBytes(StandardCharsets.UTF_8), + arrivalTimestamp); + } + + static class TestableConsumeKinesis extends ConsumeKinesis { + private final KinesisShardManager mockShardManager; + private final List preloadedResults; + + TestableConsumeKinesis(final KinesisShardManager mockShardManager, final ShardFetchResult preloadedResult) { + this(mockShardManager, List.of(preloadedResult)); + } + + TestableConsumeKinesis(final KinesisShardManager mockShardManager, final List preloadedResults) { + this.mockShardManager = mockShardManager; + this.preloadedResults = preloadedResults; + } + + @Override + protected KinesisShardManager createShardManager(final KinesisClient kinesisClient, final DynamoDbClient dynamoDbClient, + final ComponentLog logger, final String checkpointTableName, final String streamName) { + return mockShardManager; + } + + @Override + protected KinesisConsumerClient createConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger, + final boolean efoMode) { + final KinesisConsumerClient client = new StubConsumerClient(mock(KinesisClient.class), logger); + for (final ShardFetchResult result : 
preloadedResults) { + client.enqueueResult(result); + } + return client; + } + } + + static class StubConsumerClient extends KinesisConsumerClient { + StubConsumerClient(final KinesisClient kinesisClient, final ComponentLog logger) { + super(kinesisClient, logger); + } + + @Override + void startFetches(final List shards, final String streamName, final int batchSize, + final String initialStreamPosition, final KinesisShardManager shardManager) { + } + + @Override + boolean hasPendingFetches() { + return hasQueuedResults(); + } + + @Override + void acknowledgeResults(final List results) { + } + + @Override + void rollbackResults(final List results) { + } + + @Override + void removeUnownedShards(final Set ownedShards) { + } + + @Override + void logDiagnostics(final int ownedCount, final int cachedShardCount) { + } } } diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java deleted file mode 100644 index e0f3a40e47ee..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/JsonRecordAssert.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.util.MockFlowFile; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; - -final class JsonRecordAssert { - - private static final ObjectMapper MAPPER = new ObjectMapper(); - - static void assertFlowFileRecords(final FlowFile flowFile, final KinesisClientRecord... expectedRecords) { - assertFlowFileRecords(flowFile, List.of(expectedRecords)); - } - - static void assertFlowFileRecords(final FlowFile flowFile, final List expectedRecords) { - final List expectedPayloads = expectedRecords.stream() - .map(KinesisRecordPayload::extract) - .toList(); - assertFlowFileRecordPayloads(flowFile, expectedPayloads); - } - - static void assertFlowFileRecordPayloads(final FlowFile flowFile, final String... 
expectedPayloads) { - assertFlowFileRecordPayloads(flowFile, List.of(expectedPayloads)); - } - - static void assertFlowFileRecordPayloads(final FlowFile flowFile, final List expectedPayloads) { - try { - final MockFlowFile mockFlowFile = assertInstanceOf(MockFlowFile.class, flowFile, "A passed FlowFile should be an instance of MockFlowFile"); - - final JsonNode node = MAPPER.readTree(mockFlowFile.getContent()); - final ArrayNode array = assertInstanceOf(ArrayNode.class, node, "FlowFile content is expected to be an array"); - - assertEquals(expectedPayloads.size(), array.size(), "Array size mismatch"); - - for (int i = 0; i < expectedPayloads.size(); i++) { - final JsonNode recordNode = array.get(i); - final JsonNode expectedNode = MAPPER.readTree(expectedPayloads.get(i)); - assertEquals(recordNode, expectedNode, "Record at index " + i + " does not match expected JSON"); - } - - } catch (final JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - private JsonRecordAssert() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java new file mode 100644 index 000000000000..b7c9e102c0df --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisConsumerClientTest.java @@ -0,0 +1,502 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; +import software.amazon.awssdk.services.kinesis.model.StartingPosition; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardRequest; +import software.amazon.awssdk.services.kinesis.model.SubscribeToShardResponseHandler; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KinesisConsumerClientTest { + + /** + * Verifies that when an EFO subscription 
expires and is renewed, the renewal uses the + * maximum of lastAcknowledged, lastQueued, and DynamoDB checkpoint. Here the acknowledged + * sequence (99999) exceeds the checkpoint (11111), so the renewal should use 99999. + */ + @Test + void testSubscriptionRenewalUsesLastAcknowledgedSequenceNumber() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("11111"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(1, capturedRequests.size()); + assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, capturedRequests.get(0).startingPosition().type()); + assertEquals("11111", capturedRequests.get(0).startingPosition().sequenceNumber(), + "Initial subscription should use the DynamoDB checkpoint"); + + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.setLastQueuedSequenceNumber(new BigInteger("99999")); + consumer.resetForRenewal(); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, capturedRequests.get(1).startingPosition().type()); + assertEquals("99999", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued, checkpoint) = lastQueued"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that when a ShardConsumer has no acknowledged data, the renewal falls back to + * the DynamoDB checkpoint. 
+ */ + @Test + void testSubscriptionRenewalFallsBackToCheckpointWhenNoQueuedData() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("55555"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.resetForRenewal(); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("55555", capturedRequests.get(0).startingPosition().sequenceNumber()); + assertEquals("55555", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal with no acknowledged data should fall back to DynamoDB checkpoint"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that renewal uses the lastQueuedSequenceNumber when it exceeds the checkpoint. 
+ */ + @Test + void testSubscriptionRenewalUsesLastQueuedSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("10000"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + consumer.setLastQueuedSequenceNumber(new BigInteger("20000")); + consumer.resetForRenewal(); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("20000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=20000, checkpoint=10000) = lastQueued"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that renewal always uses the maximum of lastQueued and the DynamoDB checkpoint. + * This prevents one-event replay duplicates caused by races between onNext counter updates + * and concurrent handler onError callbacks. 
+ */ + @Test + void testSubscriptionRenewalAlwaysUsesMaxSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000"); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=90000, checkpoint=50000) = 90000"); + + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "95000"); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(3, capturedRequests.size()); + assertEquals("95000", capturedRequests.get(2).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=95000, checkpoint=50000) = 95000"); + + verify(mockShardManager, times(3)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that polling a queued result does not affect the renewal position. The renewal + * uses max(lastQueued, checkpoint) regardless of whether results have been polled. 
+ */ + @Test + void testSubscriptionRenewalAfterPollBeforeAcknowledgeUsesMaxSequence() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + simulateExpiredSubscriptionWithState(client, "shardId-000000000001", "90000"); + client.enqueueResult(shardFetchResult("shardId-000000000001", "90000")); + + final ShardFetchResult polled = client.pollShardResult("shardId-000000000001"); + assertNotNull(polled, "Expected queued result to be polled"); + + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + assertEquals(2, capturedRequests.size()); + assertEquals("90000", capturedRequests.get(1).startingPosition().sequenceNumber(), + "Renewal should use max(lastQueued=90000, checkpoint=50000) = 90000"); + + verify(mockShardManager, times(2)).readCheckpoint("shardId-000000000001"); + } + + /** + * Verifies that acknowledging multiple fetched results from the same shard requests only one + * additional EFO event for that shard. 
+ */ + @Test + void testAcknowledgeResultsRequestsNextOncePerShard() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn("50000"); + + final List capturedRequests = new ArrayList<>(); + final EnhancedFanOutClient client = createEfoClient(capturedRequests); + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", mockShardManager); + + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer("shardId-000000000001"); + final Subscription subscription = mock(Subscription.class); + consumer.setSubscription(subscription); + consumer.pause(); + + client.acknowledgeResults(List.of( + shardFetchResult("shardId-000000000001", "60000"), + shardFetchResult("shardId-000000000001", "61000"))); + + verify(subscription, times(1)).request(1); + } + + /** + * Verifies that concurrent startup calls do not create duplicate initial subscriptions + * for the same shard. 
+ */ + @Test + void testConcurrentStartFetchesCreatesSingleInitialSubscriptionPerShard() throws Exception { + final KinesisShardManager mockShardManager = mock(KinesisShardManager.class); + when(mockShardManager.readCheckpoint("shardId-000000000001")).thenReturn(null); + + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + final List capturedRequests = new ArrayList<>(); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenAnswer(invocation -> { + capturedRequests.add(invocation.getArgument(0)); + Thread.sleep(50L); + return new CompletableFuture<>(); + }); + + final EnhancedFanOutClient client = new EnhancedFanOutClient(mock(KinesisClient.class), mock(ComponentLog.class)); + client.initializeForTest(mockAsyncClient, "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); + + final List shards = List.of(Shard.builder().shardId("shardId-000000000001").build()); + final CountDownLatch startLatch = new CountDownLatch(1); + + final Thread t1 = new Thread(() -> runStartFetches(client, shards, mockShardManager, startLatch)); + final Thread t2 = new Thread(() -> runStartFetches(client, shards, mockShardManager, startLatch)); + t1.start(); + t2.start(); + startLatch.countDown(); + t1.join(5_000L); + t2.join(5_000L); + + assertEquals(1, capturedRequests.size(), + "Concurrent startup should create only one initial SubscribeToShard request per shard"); + } + + private static EnhancedFanOutClient createEfoClient(final List capturedRequests) { + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenAnswer(invocation -> { + capturedRequests.add(invocation.getArgument(0)); + return CompletableFuture.completedFuture(null); + }); + + final EnhancedFanOutClient client = new EnhancedFanOutClient(mock(KinesisClient.class), 
mock(ComponentLog.class)); + client.initializeForTest(mockAsyncClient, "arn:aws:kinesis:us-east-1:123456789:stream/test/consumer/test:1"); + return client; + } + + private static void simulateExpiredSubscriptionWithState( + final EnhancedFanOutClient client, + final String shardId, + final String lastQueuedSeq) { + final EnhancedFanOutClient.ShardConsumer consumer = client.getShardConsumer(shardId); + consumer.resetForRenewal(); + consumer.setLastQueuedSequenceNumber(new BigInteger(lastQueuedSeq)); + } + + /** + * Verifies that a stale error callback from an old subscription does not corrupt the state + * of a newer subscription. This tests the generation counter mechanism that prevents a race + * between the response handler's onError and the subscriber's onError when they fire on + * different threads for the same error, with a new subscription created in between. + * + *

The race without the generation counter: + * <ol> + *   <li>Response handler onError fires on a Netty thread: subscribing = false</li> + *   <li>A NiFi thread creates new subscription B: subscribing = true</li> + *   <li>The old subscriber's onError fires: it nulls B's subscription and resets subscribing = false</li> + *   <li>B's state is corrupted, leading to duplicate subscriptions</li> + * </ol> + *
+ */ + @Test + void testStaleErrorCallbackDoesNotCorruptNewSubscription() throws Exception { + final ComponentLog mockLogger = mock(ComponentLog.class); + final KinesisAsyncClient mockAsyncClient = mock(KinesisAsyncClient.class); + + when(mockAsyncClient.subscribeToShard(any(SubscribeToShardRequest.class), any(SubscribeToShardResponseHandler.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + final EnhancedFanOutClient.ShardConsumer consumer = + new EnhancedFanOutClient.ShardConsumer("shardId-000000000001", result -> { }, new ConcurrentLinkedQueue<>(), mockLogger); + + final StartingPosition pos = StartingPosition.builder() + .type(ShardIteratorType.TRIM_HORIZON) + .build(); + + consumer.subscribe(mockAsyncClient, "test-arn", pos); + final int gen1 = consumer.getSubscriptionGeneration(); + assertEquals(1, gen1); + + consumer.setSubscription(mock(Subscription.class)); + + assertNotNull(consumer.getSubscription(), "Subscription should be set after onSubscribe"); + + consumer.endSubscriptionIfCurrent(gen1); + assertFalse(consumer.isSubscribing(), "subscribing should be false after endSubscription"); + + consumer.subscribe(mockAsyncClient, "test-arn", pos); + final int gen2 = consumer.getSubscriptionGeneration(); + assertEquals(2, gen2); + + consumer.setSubscription(mock(Subscription.class)); + + assertNotNull(consumer.getSubscription(), "New subscription should be set"); + + consumer.endSubscriptionIfCurrent(gen1); + + assertNotNull(consumer.getSubscription(), + "Stale callback (gen1) must NOT null out gen2's subscription"); + assertTrue(consumer.isSubscribing(), + "Stale callback (gen1) must NOT reset gen2's subscribing flag"); + + consumer.endSubscriptionIfCurrent(gen2); + + assertFalse(consumer.isSubscribing(), + "Current-generation callback should clean up normally"); + } + + private static ShardFetchResult shardFetchResult(final String shardId, final String sequenceNumber) { + final UserRecord record = new UserRecord(shardId, sequenceNumber, 0, 
"pk", "{}".getBytes(), null); + return new ShardFetchResult(shardId, List.of(record), 0L); + } + + /** + * Verifies the per-shard exclusive processing lock: claimShard returns false when a shard is + * already claimed, and releaseShards makes it claimable again. Different shards are independent. + */ + @Test + void testShardClaimExclusivity() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + assertTrue(client.claimShard("shard-1"), "First claim should succeed"); + assertFalse(client.claimShard("shard-1"), "Duplicate claim should fail"); + assertTrue(client.claimShard("shard-2"), "Different shard claim should succeed independently"); + + client.releaseShards(List.of("shard-1")); + assertTrue(client.claimShard("shard-1"), "Claim after release should succeed"); + assertFalse(client.claimShard("shard-2"), "shard-2 was not released and should still be claimed"); + } + + /** + * Verifies that enqueueResult stores results in per-shard queues and that + * pollShardResult retrieves them in FIFO order. + */ + @Test + void testEnqueueResultIsPolledFromShardQueue() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + final ShardFetchResult result = shardFetchResult("shard-1", "12345"); + + client.enqueueResult(result); + + final ShardFetchResult polled = client.pollShardResult("shard-1"); + assertNotNull(polled, "Enqueued result should be available for polling from its shard queue"); + assertEquals("shard-1", polled.shardId()); + } + + /** + * Verifies that per-shard queues preserve FIFO order. When multiple results for the + * same shard are enqueued (potentially interleaved with other shards), polling that + * shard's queue must return them in enqueue order. This prevents the out-of-order + * delivery that occurred with the old flat-queue requeue-to-back design. 
+ */ + @Test + void testPerShardOrderingPreservedAcrossEnqueues() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + client.enqueueResult(shardFetchResult("shard-5", "100")); + client.enqueueResult(shardFetchResult("shard-3", "500")); + client.enqueueResult(shardFetchResult("shard-5", "200")); + client.enqueueResult(shardFetchResult("shard-3", "600")); + client.enqueueResult(shardFetchResult("shard-5", "300")); + + assertEquals(new BigInteger("100"), client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals(new BigInteger("200"), client.pollShardResult("shard-5").firstSequenceNumber()); + assertEquals(new BigInteger("300"), client.pollShardResult("shard-5").firstSequenceNumber()); + assertNull(client.pollShardResult("shard-5"), "Queue should be empty after draining"); + + assertEquals(new BigInteger("500"), client.pollShardResult("shard-3").firstSequenceNumber()); + assertEquals(new BigInteger("600"), client.pollShardResult("shard-3").firstSequenceNumber()); + assertNull(client.pollShardResult("shard-3"), "Queue should be empty after draining"); + } + + /** + * Verifies that getShardIdsWithResults returns only shards with non-empty queues, + * and that totalQueuedResults reflects the actual count across all shard queues. 
+ */ + @Test + void testQueueIntrospectionMethods() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + assertFalse(client.hasQueuedResults()); + assertEquals(0, client.totalQueuedResults()); + assertTrue(client.getShardIdsWithResults().isEmpty()); + + client.enqueueResult(shardFetchResult("shard-1", "10")); + client.enqueueResult(shardFetchResult("shard-2", "20")); + client.enqueueResult(shardFetchResult("shard-1", "30")); + + assertTrue(client.hasQueuedResults()); + assertEquals(3, client.totalQueuedResults()); + assertEquals(Set.of("shard-1", "shard-2"), new HashSet<>(client.getShardIdsWithResults())); + + client.pollShardResult("shard-1"); + client.pollShardResult("shard-1"); + + assertEquals(1, client.totalQueuedResults()); + assertEquals(List.of("shard-2"), client.getShardIdsWithResults()); + } + + /** + * Reproduces the scenario from the original out-of-order delivery bug. With a single + * flat queue and requeue-to-back, result B would end up behind C after being requeued, + * causing C to be delivered first. Per-shard queues eliminate this: when Task-2 cannot + * claim shard-1, it simply skips the shard queue. B and C remain in FIFO order for the + * next consumer. 
+ */ + @Test + void testPerShardQueuesPreventOutOfOrderDeliveryAcrossInvocations() { + final PollingKinesisClient client = new PollingKinesisClient(mock(KinesisClient.class), mock(ComponentLog.class)); + + final ShardFetchResult resultA = shardFetchResult("shard-1", "100"); + final ShardFetchResult resultB = shardFetchResult("shard-1", "200"); + final ShardFetchResult resultC = shardFetchResult("shard-1", "300"); + final ShardFetchResult otherShardResult = shardFetchResult("shard-2", "999"); + + client.enqueueResult(resultA); + client.enqueueResult(resultB); + client.enqueueResult(otherShardResult); + client.enqueueResult(resultC); + + // Task-1: claims shard-1, polls A(100) from its per-shard queue + assertTrue(client.claimShard("shard-1"), "Task-1 should claim shard-1"); + final ShardFetchResult polledA = client.pollShardResult("shard-1"); + assertEquals(new BigInteger("100"), polledA.firstSequenceNumber()); + + // Task-2: cannot claim shard-1 (held by Task-1), so it skips shard-1 entirely. + // B and C remain at the head of shard-1's queue, undisturbed. 
+ assertFalse(client.claimShard("shard-1"), "Task-2 cannot claim shard-1 (held by Task-1)"); + + // Task-2: claims shard-2 and polls its result + assertTrue(client.claimShard("shard-2")); + final ShardFetchResult polledOther = client.pollShardResult("shard-2"); + assertEquals(new BigInteger("999"), polledOther.firstSequenceNumber()); + + // Task-1 commits A and releases shard-1 + client.releaseShards(List.of("shard-1")); + + // Next invocation: shard-1's queue still has B(200) then C(300) in correct order + final ShardFetchResult firstPoll = client.pollShardResult("shard-1"); + final ShardFetchResult secondPoll = client.pollShardResult("shard-1"); + + assertNotNull(firstPoll, "Expected shard-1 queue to have B"); + assertNotNull(secondPoll, "Expected shard-1 queue to have C"); + assertEquals(new BigInteger("200"), firstPoll.firstSequenceNumber(), "First poll must be B(200), not C(300)"); + assertEquals(new BigInteger("300"), secondPoll.firstSequenceNumber(), "Second poll must be C(300)"); + assertNull(client.pollShardResult("shard-1"), "shard-1 queue should be empty"); + } + + private static void runStartFetches(final KinesisConsumerClient client, final List shards, + final KinesisShardManager shardManager, final CountDownLatch startLatch) { + try { + startLatch.await(); + client.startFetches(shards, "test-stream", 100, "TRIM_HORIZON", shardManager); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java deleted file mode 100644 index 8a03d71d244f..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisRecordPayload.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache 
Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import static java.nio.charset.StandardCharsets.UTF_8; - -final class KinesisRecordPayload { - - static String extract(final KinesisClientRecord record) { - record.data().rewind(); - - final byte[] buffer = new byte[record.data().remaining()]; - record.data().get(buffer); - - return new String(buffer, UTF_8); - } - - private KinesisRecordPayload() { - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java new file mode 100644 index 000000000000..31c900d701f7 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/KinesisShardManagerTest.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest; +import software.amazon.awssdk.services.dynamodb.model.CreateTableResponse; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteTableResponse; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableRequest; +import software.amazon.awssdk.services.dynamodb.model.DescribeTableResponse; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemResponse; +import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement; +import software.amazon.awssdk.services.dynamodb.model.KeyType; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemResponse; +import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; +import 
software.amazon.awssdk.services.dynamodb.model.ScanRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanResponse; +import software.amazon.awssdk.services.dynamodb.model.TableDescription; +import software.amazon.awssdk.services.dynamodb.model.TableStatus; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; +import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse; +import software.amazon.awssdk.services.kinesis.KinesisClient; + +import java.math.BigInteger; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class KinesisShardManagerTest { + + private DynamoDbClient dynamoDb; + private KinesisShardManager manager; + + @BeforeEach + void setUp() { + dynamoDb = mock(DynamoDbClient.class); + manager = new KinesisShardManager(mock(KinesisClient.class), dynamoDb, mock(ComponentLog.class), "test-table", "test-stream"); + } + + /** + * Verifies that writeCheckpoints enforces monotonic checkpoint advancement. + * A later call with a lower sequence number must not overwrite a higher one. 
+ */ + @Test + void testCheckpointMonotonicity() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("30000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("70000"))); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(2)).updateItem(captor.capture()); + + final List requests = captor.getAllValues(); + assertEquals("50000", requests.get(0).expressionAttributeValues().get(":seq").s()); + assertEquals("70000", requests.get(1).expressionAttributeValues().get(":seq").s(), + "Only increasing checkpoints should be written to DynamoDB"); + } + + /** + * Verifies that checkpoints for different shards are tracked independently. + * A lower checkpoint on shard-1 must not affect shard-2. 
+ */ + @Test + void testCheckpointMonotonicityPerShard() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of( + "shard-1", new BigInteger("50000"), + "shard-2", new BigInteger("20000"))); + manager.writeCheckpoints(Map.of( + "shard-1", new BigInteger("30000"), + "shard-2", new BigInteger("40000"))); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(3)).updateItem(captor.capture()); + + long shard1Writes = 0; + long shard2Writes = 0; + for (final UpdateItemRequest request : captor.getAllValues()) { + if ("shard-1".equals(request.key().get("shardId").s())) { + shard1Writes++; + } else if ("shard-2".equals(request.key().get("shardId").s())) { + shard2Writes++; + } + } + + assertEquals(1, shard1Writes, "shard-1 regression (30000 < 50000) should be skipped"); + assertEquals(2, shard2Writes, "shard-2 advance (40000 > 20000) should be written"); + } + + /** + * Verifies that close() resets the monotonic checkpoint guard, allowing a fresh start. + */ + @Test + void testCloseResetsCheckpointGuard() { + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); + manager.close(); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("30000"))); + + verify(dynamoDb, times(2)).updateItem(any(UpdateItemRequest.class)); + } + + /** + * Verifies that readCheckpoint returns null when no checkpoint item exists in DynamoDB. + * A wrong return value here would cause the processor to either skip data (if it returned + * a stale sequence) or re-process from TRIM_HORIZON unnecessarily. 
+ */ + @Test + void testReadCheckpointReturnsNullForMissingCheckpoint() { + final GetItemResponse emptyResponse = GetItemResponse.builder().item(Map.of()).build(); + when(dynamoDb.getItem(any(GetItemRequest.class))).thenReturn(emptyResponse); + + assertNull(manager.readCheckpoint("shard-1"), + "Missing checkpoint must return null so the processor starts from the initial stream position"); + } + + /** + * Verifies that readCheckpoint ignores a non-numeric sequence number stored in DynamoDB. + * Passing a corrupt value to GetShardIterator would cause an API error and potentially + * force a full re-read from TRIM_HORIZON, creating massive duplication. + */ + @Test + void testReadCheckpointIgnoresInvalidSequenceNumber() { + final GetItemResponse response = GetItemResponse.builder() + .item(Map.of( + "streamName", AttributeValue.builder().s("test-stream").build(), + "shardId", AttributeValue.builder().s("shard-1").build(), + "sequenceNumber", AttributeValue.builder().s("NOT_A_NUMBER").build())) + .build(); + when(dynamoDb.getItem(any(GetItemRequest.class))).thenReturn(response); + + assertNull(manager.readCheckpoint("shard-1"), + "Non-numeric sequence number must be treated as missing to avoid API errors"); + } + + /** + * Verifies that a ConditionalCheckFailedException during checkpoint write (indicating + * lease loss) does not crash the processor and does not corrupt the in-memory monotonic + * guard. A subsequent higher-valued checkpoint must still be attempted. 
+ */ + @Test + void testWriteCheckpointHandlesLostLeaseGracefully() { + final ConditionalCheckFailedException lostLease = ConditionalCheckFailedException.builder().message("lost lease").build(); + final UpdateItemResponse emptyResponse = UpdateItemResponse.builder().build(); + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenThrow(lostLease).thenReturn(emptyResponse); + + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("50000"))); + manager.writeCheckpoints(Map.of("shard-1", new BigInteger("70000"))); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(UpdateItemRequest.class); + verify(dynamoDb, times(2)).updateItem(captor.capture()); + + assertEquals("50000", captor.getAllValues().get(0).expressionAttributeValues().get(":seq").s()); + assertEquals("70000", captor.getAllValues().get(1).expressionAttributeValues().get(":seq").s(), + "After a lost-lease failure, the next higher checkpoint must still be attempted"); + } + + /** + * Verifies that when no checkpoint table exists and no orphaned migration table exists, + * a fresh table is created with the configured name (no suffix). 
+ */ + @Test + void testEnsureCheckpointTableCreatesNewTableWhenNotFound() { + final AtomicInteger mainTableDescribeCount = new AtomicInteger(); + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + if (mainTableDescribeCount.incrementAndGet() <= 2) { + throw ResourceNotFoundException.builder().build(); + } + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + when(dynamoDb.createTable(any(CreateTableRequest.class))).thenReturn(CreateTableResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(CreateTableRequest.class); + verify(dynamoDb).createTable(captor.capture()); + assertEquals("test-table", captor.getValue().tableName(), "Fresh table should use the configured name, not a migration suffix"); + } + + /** + * Verifies that when the configured table already exists with the new composite-key + * schema and no migration table is lingering, no creation or migration occurs. + */ + @Test + void testEnsureCheckpointTableUsesExistingNewTable() { + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + + manager.ensureCheckpointTableExists(); + + verify(dynamoDb, never()).createTable(any(CreateTableRequest.class)); + verify(dynamoDb, never()).deleteTable(any(DeleteTableRequest.class)); + } + + /** + * Verifies that when the configured table is NOT_FOUND but an orphaned migration table + * exists (crash during a previous migration), the migration table is renamed to the + * configured name by creating a new table, copying items, and deleting the migration table. 
+ */ + @Test + void testEnsureCheckpointTableRenamesOrphanedMigrationTable() { + final AtomicInteger mainTableDescribeCount = new AtomicInteger(); + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + if (mainTableDescribeCount.incrementAndGet() <= 3) { + throw ResourceNotFoundException.builder().build(); + } + + return newSchemaActiveResponse(); + } + + if ("test-table_migration".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + + throw ResourceNotFoundException.builder().build(); + }); + + when(dynamoDb.updateItem(any(UpdateItemRequest.class))).thenReturn(UpdateItemResponse.builder().build()); + when(dynamoDb.deleteTable(any(DeleteTableRequest.class))).thenAnswer(invocation -> { + final DeleteTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName())) { + throw ResourceNotFoundException.builder().build(); + } + return DeleteTableResponse.builder().build(); + }); + + when(dynamoDb.createTable(any(CreateTableRequest.class))).thenReturn(CreateTableResponse.builder().build()); + + final Map checkpointItem = Map.of( + "streamName", AttributeValue.builder().s("test-stream").build(), + "shardId", AttributeValue.builder().s("shard-1").build(), + "sequenceNumber", AttributeValue.builder().s("12345").build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().items(checkpointItem).build()); + when(dynamoDb.putItem(any(PutItemRequest.class))).thenReturn(PutItemResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + final ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateTableRequest.class); + verify(dynamoDb).createTable(createCaptor.capture()); + assertEquals("test-table", createCaptor.getValue().tableName(), "Renamed table should use the configured name"); + + final ArgumentCaptor putCaptor = 
ArgumentCaptor.forClass(PutItemRequest.class); + verify(dynamoDb).putItem(putCaptor.capture()); + assertEquals("test-table", putCaptor.getValue().tableName(), "Items should be copied to the configured table name"); + assertEquals("12345", putCaptor.getValue().item().get("sequenceNumber").s(), "Checkpoint data should be preserved during rename"); + + final ArgumentCaptor deleteCaptor = ArgumentCaptor.forClass(DeleteTableRequest.class); + verify(dynamoDb, times(2)).deleteTable(deleteCaptor.capture()); + final String deletedMigrationTable = deleteCaptor.getAllValues().stream() + .map(DeleteTableRequest::tableName) + .filter("test-table_migration"::equals) + .findFirst() + .orElseThrow(); + + assertEquals("test-table_migration", deletedMigrationTable, "Migration table should be deleted after rename"); + } + + /** + * Verifies that when the configured table has the new schema but a lingering migration + * table exists (crash after copy but before migration table deletion), the items are + * copied and the migration table is cleaned up. 
+ */ + @Test + void testEnsureCheckpointTableCleansUpLingeringMigrationTable() { + when(dynamoDb.describeTable(any(DescribeTableRequest.class))).thenAnswer(invocation -> { + final DescribeTableRequest req = invocation.getArgument(0); + if ("test-table".equals(req.tableName()) || "test-table_migration".equals(req.tableName())) { + return newSchemaActiveResponse(); + } + throw ResourceNotFoundException.builder().build(); + }); + + when(dynamoDb.deleteTable(any(DeleteTableRequest.class))).thenReturn(DeleteTableResponse.builder().build()); + when(dynamoDb.scan(any(ScanRequest.class))).thenReturn(ScanResponse.builder().build()); + + manager.ensureCheckpointTableExists(); + + verify(dynamoDb, never()).createTable(any(CreateTableRequest.class)); + final ArgumentCaptor deleteCaptor = ArgumentCaptor.forClass(DeleteTableRequest.class); + verify(dynamoDb).deleteTable(deleteCaptor.capture()); + assertEquals("test-table_migration", deleteCaptor.getValue().tableName(), "Lingering migration table should be deleted"); + } + + private static DescribeTableResponse newSchemaActiveResponse() { + final KeySchemaElement hashKey = KeySchemaElement.builder() + .attributeName("streamName") + .keyType(KeyType.HASH) + .build(); + final KeySchemaElement rangeKey = KeySchemaElement.builder() + .attributeName("shardId") + .keyType(KeyType.RANGE) + .build(); + final TableDescription table = TableDescription.builder() + .keySchema(hashKey, rangeKey) + .tableStatus(TableStatus.ACTIVE) + .build(); + return DescribeTableResponse.builder().table(table).build(); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java deleted file mode 100644 index b6005bf8b99d..000000000000 --- 
a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/MemoryBoundRecordBufferTest.java +++ /dev/null @@ -1,792 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.documentation.init.NopComponentLog; -import org.apache.nifi.processors.aws.kinesis.MemoryBoundRecordBuffer.Lease; -import org.apache.nifi.processors.aws.kinesis.RecordBuffer.ShardBufferId; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import software.amazon.awssdk.services.kinesis.model.Record; -import software.amazon.kinesis.exceptions.InvalidStateException; -import software.amazon.kinesis.exceptions.KinesisClientLibDependencyException; -import software.amazon.kinesis.exceptions.ShutdownException; -import software.amazon.kinesis.exceptions.ThrottlingException; -import software.amazon.kinesis.processor.Checkpointer; -import software.amazon.kinesis.processor.PreparedCheckpointer; -import 
software.amazon.kinesis.processor.RecordProcessorCheckpointer; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.IntStream; - -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class MemoryBoundRecordBufferTest { - - private static final long MAX_MEMORY_BYTES = 1024L; - private static final String SHARD_ID_1 = "shard-1"; - private static final String SHARD_ID_2 = "shard-2"; - private static final Duration CHECKPOINT_INTERVAL = Duration.ZERO; - - private MemoryBoundRecordBuffer recordBuffer; - private TestCheckpointer checkpointer1; - private TestCheckpointer checkpointer2; - - @BeforeEach - void setUp() { - recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), MAX_MEMORY_BYTES, CHECKPOINT_INTERVAL); - checkpointer1 = new TestCheckpointer(); - checkpointer2 = new TestCheckpointer(); - } - - @Test - void testCreateBuffer() { - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - assertEquals(SHARD_ID_1, bufferId1.shardId()); - - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - assertEquals(SHARD_ID_2, bufferId2.shardId()); - - final 
ShardBufferId newBufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - assertEquals(SHARD_ID_1, newBufferId1.shardId()); - - assertNotEquals(bufferId1, bufferId2); - assertNotEquals(bufferId1, newBufferId1); - } - - @Test - void testAddRecordsToBuffer() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - // Buffer without records is not available for leasing. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - // Should be able to get buffer ID from pool since buffer has records. - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease.shardId()); - } - - @Test - void testAddEmptyRecordsList() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List emptyRecords = Collections.emptyList(); - - recordBuffer.addRecords(bufferId, emptyRecords, checkpointer1); - - // Should not be able to get buffer ID from pool since no records were added. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testConsumeRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(3); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertEquals(records, consumedRecords); - // Just consuming record should not checkpoint them. 
- assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testCommitConsumedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - assertEquals(records.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testCommitConsumedRecords_withRecordsAddedBeforeCommit() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List originalRecords = createTestRecords(2); - - recordBuffer.addRecords(bufferId, originalRecords, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - - // Simulating new records added in parallel, before a commit. - final List newRecords = createTestRecords(5); - recordBuffer.addRecords(bufferId, newRecords, checkpointer1); - - recordBuffer.commitConsumedRecords(lease); - - // Only originalRecords, which were consumed, are checkpointed. - assertEquals(originalRecords.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - } - - @Test - void testRollbackConsumedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(3); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - final List messages = consumedRecords.stream() - .map(this::readContent) - .toList(); - - recordBuffer.rollbackConsumedRecords(lease); - - // Checkpointer should not be called during rollback. 
- assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - - final List rolledBackMessages = recordBuffer.consumeRecords(lease).stream() - .map(this::readContent) - .toList(); - assertEquals(messages, rolledBackMessages); - } - - private String readContent(final KinesisClientRecord record) { - final ByteBuffer data = record.data(); - final byte[] buffer = new byte[data.remaining()]; - data.get(buffer); - return new String(buffer, StandardCharsets.UTF_8); - } - - @Test - void testReturnBufferIdToPool() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - // Consume some records but don't return buffer id to the pool. - recordBuffer.consumeRecords(lease1); - - recordBuffer.addRecords(bufferId, createTestRecords(1), checkpointer2); - - // The buffer is still unavailable. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - - // After returning id to the pool it's possible to get the buffer from pool again. - recordBuffer.returnBufferLease(lease1); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - } - - @Test - void testReturnBufferIdToPool_multipleReturns() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.returnBufferLease(lease1); - recordBuffer.returnBufferLease(lease1); - - // Can retrieve the id only once. 
- final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testReturnBufferIdToPool_withUncommittedRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - - // Consume some records, but don't commit them. - final List lease1Records = recordBuffer.consumeRecords(lease1); - recordBuffer.returnBufferLease(lease1); - - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease2.shardId()); - final List lease2Records = recordBuffer.consumeRecords(lease2); - - // Until committed, the records stay in the buffer. - assertEquals(lease1Records, lease2Records); - } - - @Test - void testReturnBufferIdToPoolWithNoPendingRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - // Get buffer from pool and consume all records. - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - // Return buffer ID to pool - should not make buffer available since no pending records. - recordBuffer.returnBufferLease(lease); - - // Should not be able to get buffer from pool. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testConsumerLeaseLost() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(2); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - // Before lease lost buffer should be available in the pool. 
- final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_1, lease.shardId()); - - // Simulate lease lost. - recordBuffer.consumerLeaseLost(bufferId); - - // Should not be able to consume records from invalidated buffer. - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - - // Should not be able to commit records for invalidated buffer. - recordBuffer.commitConsumedRecords(lease); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - - // Buffer should not be available in pool. - assertTrue(recordBuffer.acquireBufferLease().isEmpty()); - } - - @Test - void testCheckpointEndedShard() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - recordBuffer.checkpointEndedShard(bufferId, checkpointer2); - assertEquals(TestCheckpointer.LATEST_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - // Buffer should be removed and not available for operations. 
- final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - } - - @Test - void testShutdownShardConsumption_forEmptyBuffer() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - recordBuffer.consumeRecords(lease); - recordBuffer.commitConsumedRecords(lease); - - recordBuffer.shutdownShardConsumption(bufferId, checkpointer2); - assertEquals(TestCheckpointer.LATEST_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - // Buffer should be removed and not available for operations. - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertTrue(consumedRecords.isEmpty()); - } - - @Test - void testShutdownShardConsumption_forBufferWithRecords() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - - recordBuffer.shutdownShardConsumption(bufferId, checkpointer2); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer1.latestCheckpointedSequenceNumber()); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, checkpointer2.latestCheckpointedSequenceNumber()); - - assertTrue(recordBuffer.acquireBufferLease().isEmpty(), "Buffer should not be available after shutdown"); - } - - @Test - void testShutdownShardConsumption_whileOtherShardIsValid() { - final int bufferSize = 100; - - // Create buffer with small memory limit. 
- final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), bufferSize, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - - final List records1 = List.of(createRecordWithSize(bufferSize)); - recordBuffer.addRecords(bufferId1, records1, checkpointer1); - - // Shutting down a buffer with a record. - recordBuffer.shutdownShardConsumption(bufferId1, checkpointer1); - - // Adding records to another buffer. - final List records2 = List.of(createRecordWithSize(bufferSize)); - assertTimeoutPreemptively( - Duration.ofSeconds(1), - () -> recordBuffer.addRecords(bufferId2, records2, checkpointer2), - "Records should be added to a buffer without memory backpressure"); - - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(SHARD_ID_2, lease.shardId(), "Expected to acquire a lease for " + SHARD_ID_2); - assertEquals(records2, recordBuffer.consumeRecords(lease)); - } - - @Test - @Timeout(value = 5, unit = SECONDS) - void testMemoryBackpressure() throws InterruptedException { - // Create buffer with small memory limit. - final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), 100L, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // Still fits into the buffer. - final List initialRecords = List.of(createRecordWithSize(80), createRecordWithSize(20)); - recordBuffer.addRecords(bufferId, initialRecords, checkpointer1); - - final CountDownLatch startLatch = new CountDownLatch(1); - - final List notFittingRecords = List.of(createRecordWithSize(50)); - // Thread that tries to add records (should block due to memory limit). - final Thread addRecordsThread = new Thread(() -> { - startLatch.countDown(); - // Doesn't fit into the buffer. 
- recordBuffer.addRecords(bufferId, notFittingRecords, checkpointer1); - }); - - addRecordsThread.start(); - startLatch.await(); - - // Wait for thread to try to add records and get blocked. - Thread.sleep(200); - assertTrue(addRecordsThread.isAlive(), "Thread should be blocked waiting for memory"); - - // Commit records in the buffer to free memory. - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(initialRecords, recordBuffer.consumeRecords(lease1)); - recordBuffer.commitConsumedRecords(lease1); - recordBuffer.returnBufferLease(lease1); - - // Thread should get unblocked and add the message. - addRecordsThread.join(); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - assertEquals(notFittingRecords, recordBuffer.consumeRecords(lease2)); - } - - @Test - void testMemoryBackpressure_forRecordsLargerThanMaxBuffer() { - // Create buffer with small memory limit. - final int bufferSize = 10; - final MemoryBoundRecordBuffer recordBuffer = new MemoryBoundRecordBuffer(new NopComponentLog(), bufferSize, CHECKPOINT_INTERVAL); - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // A single batch with size double the buffer size. - final List reallyLargeBatch = List.of(createRecordWithSize(bufferSize), createRecordWithSize(bufferSize)); - - // It's possible to insert a batch that exceeds the buffer size. - assertDoesNotThrow(() -> recordBuffer.addRecords(bufferId, reallyLargeBatch, checkpointer1)); - } - - @Test - @Timeout(value = 5, unit = SECONDS) - void testConcurrentBufferAccess() throws InterruptedException { - final int numberOfShards = 10; - final int numberOfRecordConsumers = 10; - final int totalThreads = numberOfShards + numberOfRecordConsumers; - - final int recordsPerShard = 5; - - final ExecutorService executor = Executors.newFixedThreadPool(totalThreads); - // All threads start at the same moment. 
- final CountDownLatch startLatch = new CountDownLatch(totalThreads); - final CountDownLatch finishLatch = new CountDownLatch(totalThreads); - - final List shardIds = IntStream.range(0, numberOfShards).boxed().toList(); - final List bufferIds = shardIds.stream() - .map(shardId -> recordBuffer.createBuffer("shard-" + shardId)) - .toList(); - - // Every producer returns a checkpointer for the records it produced. - final List checkpointers = shardIds.stream() - .map(shardId -> { - final ShardBufferId bufferId = bufferIds.get(shardId); - final List records = createTestRecords(recordsPerShard); - final TestCheckpointer threadCheckpointer = new TestCheckpointer(); - - executor.submit(() -> { - try { - startLatch.countDown(); - startLatch.await(); - - recordBuffer.addRecords(bufferId, records, threadCheckpointer); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } finally { - finishLatch.countDown(); - } - }); - - return threadCheckpointer; - }) - .toList(); - - // Every consumer returns a list of shardId:sequenceNumber strings for the records it consumed. - final List>> processedRecordsFutures = IntStream.range(0, numberOfRecordConsumers) - .mapToObj(__ -> executor.submit(() -> { - try { - startLatch.countDown(); - startLatch.await(); - - Optional maybeLease = Optional.empty(); - while (maybeLease.isEmpty()) { - maybeLease = recordBuffer.acquireBufferLease(); - Thread.sleep(100); // Wait for records to be added by producers. 
- } - - final Lease lease = maybeLease.orElseThrow(); - final List consumedRecords = recordBuffer.consumeRecords(lease); - - recordBuffer.commitConsumedRecords(lease); - - return consumedRecords.stream() - .map(record -> lease.shardId() + ":" + record.sequenceNumber()) - .toList(); - } catch (final InterruptedException e) { - throw new RuntimeException(e); - } finally { - finishLatch.countDown(); - } - })) - .toList(); - - finishLatch.await(); - executor.shutdown(); - - assertAll( - checkpointers.stream().map(it -> () -> - assertNotEquals( - TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, - it.latestCheckpointedSequenceNumber(), - "Every checkpointer should have been called")) - ); - - final long uniqueRecordsProcessed = processedRecordsFutures.stream() - .map(future -> { - try { - return future.get(); - } catch (final Exception e) { - throw new RuntimeException(e); - } - }) - .flatMap(Collection::stream) - .distinct() - .count(); - assertEquals(numberOfShards * recordsPerShard, uniqueRecordsProcessed); - } - - @Test - void testGetMultipleBuffersFromPool() { - final ShardBufferId bufferId1 = recordBuffer.createBuffer(SHARD_ID_1); - final ShardBufferId bufferId2 = recordBuffer.createBuffer(SHARD_ID_2); - - // Add records to both buffers. - recordBuffer.addRecords(bufferId1, createTestRecords(2), checkpointer1); - recordBuffer.addRecords(bufferId2, createTestRecords(3), checkpointer2); - - // Should be able to get both buffer IDs from pool. - final Lease lease1 = recordBuffer.acquireBufferLease().orElseThrow(); - final Lease lease2 = recordBuffer.acquireBufferLease().orElseThrow(); - - assertEquals(SHARD_ID_1, lease1.shardId()); - assertEquals(SHARD_ID_2, lease2.shardId()); - - // Should get different buffer IDs. - assertNotEquals(lease1.shardId(), lease2.shardId()); - } - - @Test - void testCommitRecordsWhileNewRecordsArrive() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - - // Add records to pending queue. 
- final List batch1 = createTestRecords(2); - recordBuffer.addRecords(bufferId, batch1, checkpointer1); - - // Consume batch1 records (moves from pending to in-progress). - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - - final List consumedRecords = recordBuffer.consumeRecords(lease); - assertEquals(batch1, consumedRecords); - - // Add more records while others are in-progress. - final List batch2 = createTestRecords(1); - recordBuffer.addRecords(bufferId, batch2, checkpointer2); - - // Commit in-progress records. - recordBuffer.commitConsumedRecords(lease); - - // Consume batch2 records. - final List remainingRecords = recordBuffer.consumeRecords(lease); - assertEquals(batch2, remainingRecords); - } - - @ParameterizedTest - @ValueSource(classes = { - KinesisClientLibDependencyException.class, - InvalidStateException.class, - ThrottlingException.class, - }) - void testRetriableCheckpointExceptions(final Class exceptionClass) throws Exception { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - final Exception exception = exceptionClass.getDeclaredConstructor(String.class).newInstance("Thrown from test"); - final TestCheckpointer failingCheckpointer = new TestCheckpointer(exception, 2); - - recordBuffer.addRecords(bufferId, records, failingCheckpointer); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - // Should handle all exception types gracefully. 
- recordBuffer.commitConsumedRecords(lease); - - assertEquals(records.getLast().sequenceNumber(), failingCheckpointer.latestCheckpointedSequenceNumber()); - } - - @Test - void testShutdownDuringCheckpoint() { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - final ShutdownException exception = new ShutdownException("Test shutdown exception"); - final TestCheckpointer failingCheckpointer = new TestCheckpointer(exception, 1); - - recordBuffer.addRecords(bufferId, records, failingCheckpointer); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - recordBuffer.commitConsumedRecords(lease); - assertEquals(TestCheckpointer.NO_CHECKPOINT_SEQUENCE_NUMBER, failingCheckpointer.latestCheckpointedSequenceNumber()); - } - - @Test - @Timeout(value = 3, unit = SECONDS) - void testCheckpointEndedShardWaitsForPendingRecords() throws InterruptedException { - final ShardBufferId bufferId = recordBuffer.createBuffer(SHARD_ID_1); - final List records = createTestRecords(1); - - recordBuffer.addRecords(bufferId, records, checkpointer1); - final Lease lease = recordBuffer.acquireBufferLease().orElseThrow(); - recordBuffer.consumeRecords(lease); - - final CountDownLatch finishStarted = new CountDownLatch(1); - - // Start finishConsumption in another thread. - final Thread finishThread = new Thread(() -> { - finishStarted.countDown(); - recordBuffer.checkpointEndedShard(bufferId, checkpointer2); - }); - - finishThread.start(); - finishStarted.await(); - - // Give finishThread time to get blocked. - Thread.sleep(200); - assertTrue(finishThread.isAlive(), "finishConsumption should block until records are committed"); - - // Commit records to unblock finishConsumption. 
- recordBuffer.commitConsumedRecords(lease); - assertEquals(records.getLast().sequenceNumber(), checkpointer1.latestCheckpointedSequenceNumber()); - - finishThread.join(); - assertEquals( - TestCheckpointer.LATEST_SEQUENCE_NUMBER, - checkpointer2.latestCheckpointedSequenceNumber(), - "Checkpointer should be called after finishConsumption unblocks"); - } - - private List createTestRecords(int count) { - return IntStream.range(0, count) - .mapToObj(i -> { - final String data = "test-record-" + i; - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data.getBytes(StandardCharsets.UTF_8)).asReadOnlyBuffer()) - .partitionKey("partition-" + i) - .sequenceNumber(String.valueOf(i)) - .build(); - }) - .toList(); - } - - private KinesisClientRecord createRecordWithSize(int sizeBytes) { - final byte[] data = new byte[sizeBytes]; - Arrays.fill(data, (byte) 'X'); - - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data).asReadOnlyBuffer()) - .partitionKey("partition") - .sequenceNumber("1") - .build(); - } - - /** - * Test implementation of RecordProcessorCheckpointer that tracks checkpoint calls - * and can simulate various exception scenarios. 
- */ - private static class TestCheckpointer implements RecordProcessorCheckpointer { - - static final String NO_CHECKPOINT_SEQUENCE_NUMBER = "NONE"; - static final String LATEST_SEQUENCE_NUMBER = "LATEST"; - - private final Exception exceptionToThrow; - private final AtomicInteger throwsLeft; - - private volatile String latestCheckpointedSequenceNumber = NO_CHECKPOINT_SEQUENCE_NUMBER; - - TestCheckpointer() { - this.exceptionToThrow = null; - this.throwsLeft = new AtomicInteger(0); - } - - TestCheckpointer(final Exception exceptionToThrow, final int maxThrows) { - this.exceptionToThrow = exceptionToThrow; - this.throwsLeft = new AtomicInteger(maxThrows); - } - - @Override - public void checkpoint() throws KinesisClientLibDependencyException, InvalidStateException, ThrottlingException, ShutdownException { - doCheckpoint(LATEST_SEQUENCE_NUMBER); - } - - @Override - public void checkpoint(Record record) { - throw notImplemented(); - } - - @Override - public void checkpoint(String sequenceNumber) { - throw notImplemented(); - } - - @Override - public void checkpoint(String sequenceNumber, long subSequenceNumber) throws ShutdownException, InvalidStateException { - doCheckpoint(sequenceNumber); - } - - private void doCheckpoint(final String sequenceNumber) throws InvalidStateException, ShutdownException { - if (exceptionToThrow != null && throwsLeft.decrementAndGet() == 0) { - switch (exceptionToThrow) { - case KinesisClientLibDependencyException e -> throw e; - case InvalidStateException e -> throw e; - case ThrottlingException e -> throw e; - case ShutdownException e -> throw e; - default -> throw new RuntimeException(exceptionToThrow); - } - } - - latestCheckpointedSequenceNumber = sequenceNumber; - } - - @Override - public PreparedCheckpointer prepareCheckpoint() { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer 
prepareCheckpoint(Record record) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(Record record, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, long subSequenceNumber) { - throw notImplemented(); - } - - @Override - public PreparedCheckpointer prepareCheckpoint(String sequenceNumber, long subSequenceNumber, byte[] applicationState) { - throw notImplemented(); - } - - @Override - public Checkpointer checkpointer() { - throw notImplemented(); - } - - String latestCheckpointedSequenceNumber() { - return latestCheckpointedSequenceNumber; - } - - private static RuntimeException notImplemented() { - return new UnsupportedOperationException("Not implemented for test"); - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java new file mode 100644 index 000000000000..5133adc26d48 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/PollingKinesisClientTest.java @@ -0,0 +1,588 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.aws.kinesis; + +import org.apache.nifi.logging.ComponentLog; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.ExpiredIteratorException; +import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest; +import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest; +import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse; +import software.amazon.awssdk.services.kinesis.model.ProvisionedThroughputExceededException; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.awssdk.services.kinesis.model.Shard; +import software.amazon.awssdk.services.kinesis.model.ShardIteratorType; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; 
+import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class PollingKinesisClientTest { + + private static final GetShardIteratorResponse ITERATOR_RESPONSE = + GetShardIteratorResponse.builder().shardIterator("iter-1").build(); + + private KinesisClient mockKinesisClient; + private KinesisShardManager mockShardManager; + private PollingKinesisClient consumer; + + @BeforeEach + void setUp() { + mockKinesisClient = mock(KinesisClient.class); + mockShardManager = mock(KinesisShardManager.class); + consumer = new PollingKinesisClient(mockKinesisClient, mock(ComponentLog.class), 1L, 1L); + } + + @AfterEach + void tearDown() { + if (consumer != null) { + consumer.close(); + } + } + + /** + * When a shard reaches exhaustion (nextShardIterator is null), the records from the + * final GetRecords response must be queued before the loop exits. A regression would + * silently drop the last batch of every exhausted shard. 
+ */ + @Test + void testExhaustedShardDeliversAllRecords() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B"), record("300", "C")) + .nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(result, "Last batch of an exhausted shard must not be dropped"); + assertEquals(3, result.records().size()); + assertEquals(new BigInteger("100"), result.firstSequenceNumber()); + assertEquals(new BigInteger("300"), result.lastSequenceNumber()); + + assertEventuallyNoPendingFetches(); + } + + /** + * When a fetch loop thread dies from an uncaught Throwable, the next startFetches call + * must detect the dead loop and restart it so records continue flowing. An Error thrown + * from readCheckpoint escapes the Exception catch in getShardIterator, propagates through + * runFetchLoop, and is caught by the Throwable guard in launchFetchLoop. 
+ */ + @Test + void testDeadLoopRecoveryRestoresDataFlow() throws Exception { + when(mockShardManager.readCheckpoint(anyString())) + .thenThrow(new AssertionError("simulated loop death")) + .thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + final List shards = shards("shard-1"); + consumer.startFetches(shards, "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + Thread.sleep(100); + + consumer.startFetches(shards, "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(result, "Dead loop must be restarted and produce records"); + assertEquals(new BigInteger("100"), result.firstSequenceNumber()); + } + + /** + * Validates that after a rollback (session commit failure), the fetch loop re-acquires + * the shard iterator from the checkpoint position, not from where it left off. This + * prevents data loss when in-flight records need to be re-consumed. 
+ */ + @Test + void testRollbackCausesIteratorReAcquisitionFromCheckpoint() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null).thenReturn("500"); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("600", "data"), record("700", "data2")) + .nextShardIterator("iter-next").millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult); + + consumer.rollbackResults(List.of(firstResult)); + + final ArgumentCaptor captor = ArgumentCaptor.forClass(GetShardIteratorRequest.class); + verify(mockKinesisClient, timeout(5000).atLeast(2)).getShardIterator(captor.capture()); + + final GetShardIteratorRequest reAcquisition = captor.getAllValues().get(captor.getAllValues().size() - 1); + assertEquals(ShardIteratorType.AFTER_SEQUENCE_NUMBER, reAcquisition.shardIteratorType(), + "After rollback, iterator must resume from checkpoint, not from where it left off"); + assertEquals("500", reAcquisition.startingSequenceNumber()); + } + + /** + * Validates that an expired iterator is transparently replaced by re-acquiring from + * the checkpoint. Records must still arrive after the transient error. 
+ */ + @Test + void testExpiredIteratorRecovery() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final ExpiredIteratorException expired = ExpiredIteratorException.builder().message("expired").build(); + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenThrow(expired).thenReturn(response); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(result, "Records must arrive after expired iterator recovery"); + + verify(mockKinesisClient, timeout(5000).atLeast(2)).getShardIterator(any(GetShardIteratorRequest.class)); + } + + /** + * Verifies that iterator recovery does not reorder results within a shard when newer data + * is already queued ahead of replay from the persisted checkpoint. 
+ */ + @Test + void testExpiredIteratorRecoveryDoesNotDeliverSameShardOutOfOrder() throws Exception { + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn("100"); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + return switch (request.shardIterator()) { + case "iter-1" -> GetRecordsResponse.builder() + .records(record("200", "A")) + .nextShardIterator("iter-1a") + .millisBehindLatest(0L) + .build(); + case "iter-1a" -> GetRecordsResponse.builder() + .records(record("300", "B")) + .nextShardIterator("iter-1b") + .millisBehindLatest(0L) + .build(); + case "iter-1b" -> throw ExpiredIteratorException.builder().message("expired").build(); + case "iter-2" -> GetRecordsResponse.builder() + .records(record("200", "A-replay")) + .nextShardIterator("iter-2a") + .millisBehindLatest(0L) + .build(); + default -> GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()) + .millisBehindLatest(0L) + .build(); + }; + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "Initial result must be available"); + assertEquals(new BigInteger("200"), firstResult.firstSequenceNumber()); + + final long newerQueuedDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < newerQueuedDeadline && 
getRecordsCallCount.get() < 2) { + Thread.sleep(20); + } + Thread.sleep(50); + + final long replayQueuedDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < replayQueuedDeadline && (getShardIteratorCallCount.get() < 2 || getRecordsCallCount.get() < 4)) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRecovery = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRecovery, "Queue must contain results after iterator recovery"); + assertEquals(new BigInteger("200"), firstAfterRecovery.firstSequenceNumber(), + "Replayed checkpoint data must be delivered before any stale newer batch from the same shard"); + } + + /** + * Validates that throttling (ProvisionedThroughputExceededException) is transient: + * the loop retries and records eventually arrive without data loss. + */ + @Test + void testThrottledFetchRetriesWithoutDataLoss() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final ProvisionedThroughputExceededException throttled = + ProvisionedThroughputExceededException.builder().message("throttled").build(); + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenThrow(throttled).thenReturn(response); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(result, "Throttled fetch must retry and eventually deliver records"); + } + + /** + * Validates that a RuntimeException from readCheckpoint (e.g., DynamoDB throttle on + * startup) does not permanently kill the fetch loop. 
The getShardIterator method + * catches the exception and returns null, the loop backs off and retries. Records + * must eventually arrive. + */ + @Test + void testCheckpointReadFailureRetriesWithoutKillingLoop() throws Exception { + when(mockShardManager.readCheckpoint(anyString())) + .thenThrow(new RuntimeException("DynamoDB throttle")) + .thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult result = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(result, "Checkpoint read failure must not permanently kill the fetch loop"); + } + + /** + * Validates that close() terminates all fetch loops and that the consumer reports + * no pending work, preventing the processor from spin-waiting after shutdown. 
+ */ + @Test + void testCloseTerminatesAllFetchingAndDrainsQueue() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator("iter-next").millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + consumer.startFetches(shards("shard-1", "shard-2"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + verify(mockKinesisClient, timeout(5000).atLeast(1)).getRecords(any(GetRecordsRequest.class)); + + consumer.close(); + + assertFalse(consumer.hasPendingFetches(), "After close, hasPendingFetches must return false"); + + consumer = null; + } + + /** + * Validates that once all shards are exhausted and the queue is drained, + * hasPendingFetches returns false. This ensures the processor's poll loop + * exits promptly rather than spin-waiting indefinitely. + */ + @Test + void testHasPendingFetchesFalseWhenAllShardsExhausted() throws Exception { + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenReturn(ITERATOR_RESPONSE); + + final GetRecordsResponse response = GetRecordsResponse.builder() + .records(record("100", "data")).nextShardIterator(null).millisBehindLatest(0L).build(); + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenReturn(response); + + consumer.startFetches(shards("shard-1", "shard-2"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + drainAllResults(); + assertEventuallyNoPendingFetches(); + } + + /** + * Reproduces out-of-order delivery caused by stale results remaining in the per-shard queue after a rollback. + * The scenario for a single shard: + * + *
+     * <ol>
+     *   <li>Fetch loop enqueues result 1 (sequence 100-200) and result 2 (sequence 300-400).</li>
+     *   <li>Consumer polls result 1 only; result 2 remains in the queue.</li>
+     *   <li>Consumer calls rollbackResults on result 1, which drains the queue synchronously and sets the reset flag.
+     *       The fetch loop detects the flag, drains any stragglers, resets the shard iterator, and re-fetches.</li>
+     *   <li>After the reset, the first result polled must come from the re-fetched sequence (sequence 500),
+     *       not the stale result 2 (sequence 300).</li>
+     * </ol>
+ */ + @Test + void testRollbackDrainsStaleResultsFromQueue() throws Exception { + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + if (request.shardIterator().equals("iter-1")) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (request.shardIterator().equals("iter-1a")) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (request.shardIterator().startsWith("iter-1")) { + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(record("500", "E"), record("600", "F")) + .nextShardIterator("iter-post-reset-next").millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); + + final long enqueueDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < enqueueDeadline 
&& getRecordsCallCount.get() < 2) { + Thread.sleep(20); + } + Thread.sleep(50); + + consumer.rollbackResults(List.of(firstResult)); + + final long resetDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (getShardIteratorCallCount.get() < 2 && System.nanoTime() < resetDeadline) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRollback = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRollback, "Queue must contain results after rollback and re-fetch"); + assertEquals(new BigInteger("500"), firstAfterRollback.firstSequenceNumber(), + "First result after rollback must be re-fetched data, not a stale pre-rollback result"); + } + + /** + * Verifies that rollbackResults drains the queue synchronously so that a concurrent consumer cannot poll stale + * pre-rollback results. The fetch loop is blocked inside a GetRecords call while rollback happens, proving the + * drain must occur in rollbackResults itself rather than being deferred to the fetch loop thread. 
+ */ + @Test + void testRollbackDrainsSynchronouslyPreventingConcurrentStaleRead() throws Exception { + final CountDownLatch fetchLoopBlocked = new CountDownLatch(1); + final CountDownLatch unblockFetchLoop = new CountDownLatch(1); + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))) + .thenReturn(GetShardIteratorResponse.builder().shardIterator("iter-1").build()); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + final int callNumber = getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + if (callNumber == 1) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (callNumber == 2) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (callNumber == 3) { + fetchLoopBlocked.countDown(); + unblockFetchLoop.await(10, TimeUnit.SECONDS); + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator("iter-1c").millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), "test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); + + fetchLoopBlocked.await(5, TimeUnit.SECONDS); + + consumer.rollbackResults(List.of(firstResult)); + + final ShardFetchResult staleResult = 
consumer.pollShardResult("shard-1"); + + unblockFetchLoop.countDown(); + + assertNull(staleResult, "After rollback, stale results must not be pollable — rollbackResults must drain synchronously"); + } + + /** + * Reproduces the race condition where the fetch loop returns records from GetRecords AFTER rollbackResults has + * already drained the queue. Without the per-shard lock, the fetch loop would enqueue the stale result after the + * drain, making it visible to the next consumer poll. With the lock, the enqueue is rejected because the fetch + * loop sees isResetRequested inside the synchronized block. + * + *

+     * <p>Timeline:</p>
+     * <ol>
+     *   <li>Fetch loop enqueues result 1 (sequence 100-200) and result 2 (sequence 300-400).</li>
+     *   <li>Consumer polls result 1.</li>
+     *   <li>GetRecords call 3 blocks, holding a response with records (sequence 500-600).</li>
+     *   <li>Consumer rolls back result 1 — this drains result 2 and sets the reset flag.</li>
+     *   <li>GetRecords call 3 unblocks — fetch loop receives the stale response.</li>
+     *   <li>Fetch loop calls enqueueIfActive under the shard lock, sees isResetRequested, and discards the result.</li>
+     *   <li>Fetch loop processes the reset and re-fetches from the checkpoint (sequence 800).</li>
+     *   <li>The first result available after rollback must be the re-fetched data (sequence 800), not the stale data (sequence 500).</li>
+     * </ol>
+ */ + @Test + void testLockPreventsStaleEnqueueDuringConcurrentRollback() throws Exception { + final CountDownLatch fetchLoopBlockedInsideGetRecords = new CountDownLatch(1); + final CountDownLatch unblockGetRecords = new CountDownLatch(1); + final AtomicInteger getRecordsCallCount = new AtomicInteger(); + final AtomicInteger getShardIteratorCallCount = new AtomicInteger(); + + when(mockShardManager.readCheckpoint(anyString())).thenReturn(null); + when(mockKinesisClient.getShardIterator(any(GetShardIteratorRequest.class))).thenAnswer(invocation -> { + final int callNumber = getShardIteratorCallCount.incrementAndGet(); + return GetShardIteratorResponse.builder().shardIterator("iter-" + callNumber).build(); + }); + + when(mockKinesisClient.getRecords(any(GetRecordsRequest.class))).thenAnswer(invocation -> { + final int callNumber = getRecordsCallCount.incrementAndGet(); + final GetRecordsRequest request = invocation.getArgument(0); + + if (callNumber == 1) { + return GetRecordsResponse.builder() + .records(record("100", "A"), record("200", "B")) + .nextShardIterator("iter-1a").millisBehindLatest(0L).build(); + } + if (callNumber == 2) { + return GetRecordsResponse.builder() + .records(record("300", "C"), record("400", "D")) + .nextShardIterator("iter-1b").millisBehindLatest(0L).build(); + } + if (callNumber == 3) { + fetchLoopBlockedInsideGetRecords.countDown(); + unblockGetRecords.await(10, TimeUnit.SECONDS); + return GetRecordsResponse.builder() + .records(record("500", "E"), record("600", "F")) + .nextShardIterator("iter-1c").millisBehindLatest(0L).build(); + } + if (request.shardIterator().startsWith("iter-2")) { + return GetRecordsResponse.builder() + .records(record("800", "G"), record("900", "H")) + .nextShardIterator("iter-2a").millisBehindLatest(0L).build(); + } + return GetRecordsResponse.builder() + .records(List.of()) + .nextShardIterator(request.shardIterator()).millisBehindLatest(0L).build(); + }); + + consumer.startFetches(shards("shard-1"), 
"test-stream", 1000, "TRIM_HORIZON", mockShardManager); + + final ShardFetchResult firstResult = consumer.pollAnyResult(5, TimeUnit.SECONDS); + assertNotNull(firstResult, "First result must be available"); + assertEquals(new BigInteger("100"), firstResult.firstSequenceNumber()); + + fetchLoopBlockedInsideGetRecords.await(5, TimeUnit.SECONDS); + + consumer.rollbackResults(List.of(firstResult)); + + unblockGetRecords.countDown(); + + final long resetDeadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (getShardIteratorCallCount.get() < 2 && System.nanoTime() < resetDeadline) { + Thread.sleep(20); + } + Thread.sleep(100); + + final ShardFetchResult firstAfterRollback = consumer.pollShardResult("shard-1"); + assertNotNull(firstAfterRollback, "Queue must contain results after rollback and re-fetch"); + assertEquals(new BigInteger("800"), firstAfterRollback.firstSequenceNumber(), + "First result after rollback must be re-fetched data (800), not the stale data (500) that was returned from GetRecords during the rollback"); + } + + private void drainAllResults() throws InterruptedException { + ShardFetchResult discarded; + do { + discarded = consumer.pollAnyResult(500, TimeUnit.MILLISECONDS); + } while (discarded != null); + } + + private void assertEventuallyNoPendingFetches() throws InterruptedException { + final long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(5); + while (System.nanoTime() < deadline) { + if (!consumer.hasPendingFetches()) { + return; + } + Thread.sleep(50); + } + assertFalse(consumer.hasPendingFetches(), "Expected hasPendingFetches to become false"); + } + + private static List shards(final String... 
shardIds) { + return Arrays.stream(shardIds).map(id -> Shard.builder().shardId(id).build()).toList(); + } + + private static Record record(final String sequenceNumber, final String data) { + return Record.builder() + .sequenceNumber(sequenceNumber).partitionKey("pk-" + sequenceNumber) + .approximateArrivalTimestamp(Instant.now()) + .data(SdkBytes.fromString(data, StandardCharsets.UTF_8)).build(); + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java new file mode 100644 index 000000000000..5e12d358a9b5 --- /dev/null +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProducerLibraryDeaggregatorTest.java @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.processors.aws.kinesis; + +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedOutputStream; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.kinesis.model.Record; +import software.amazon.kinesis.retrieval.AggregatorUtil; +import software.amazon.kinesis.retrieval.KinesisClientRecord; +import software.amazon.kinesis.retrieval.kpl.Messages; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.time.Instant; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ProducerLibraryDeaggregatorTest { + + private static final Instant ARRIVAL = Instant.parse("2025-06-15T12:00:00Z"); + private static final String TEST_SHARD_ID = "shardId-000000000000"; + + @Test + void testNonAggregatedPassthrough() { + final byte[] payload = "hello".getBytes(StandardCharsets.UTF_8); + final Record record = buildKinesisRecord("seq-001", "pk-1", payload); + + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + final UserRecord dr = result.getFirst(); + assertEquals(TEST_SHARD_ID, dr.shardId()); + assertEquals("seq-001", dr.sequenceNumber()); + assertEquals(0, dr.subSequenceNumber()); + assertEquals("pk-1", dr.partitionKey()); + assertArrayEquals(payload, dr.data()); + assertEquals(ARRIVAL, dr.approximateArrivalTimestamp()); + } + + @Test + void testSingleSubRecord() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-A"), + List.of(new SubRecord(0, "data-A".getBytes(StandardCharsets.UTF_8)))); + final Record record = 
buildKinesisRecord("seq-100", "agg-pk", aggregated); + + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + final UserRecord dr = result.getFirst(); + assertEquals("seq-100", dr.sequenceNumber()); + assertEquals(0, dr.subSequenceNumber()); + assertEquals("pk-A", dr.partitionKey()); + assertArrayEquals("data-A".getBytes(StandardCharsets.UTF_8), dr.data()); + } + + @Test + void testMultipleSubRecords() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-X", "pk-Y"), + List.of( + new SubRecord(0, "first".getBytes(StandardCharsets.UTF_8)), + new SubRecord(1, "second".getBytes(StandardCharsets.UTF_8)), + new SubRecord(0, "third".getBytes(StandardCharsets.UTF_8)))); + final Record record = buildKinesisRecord("seq-200", "agg-pk", aggregated); + + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(3, result.size()); + + assertEquals("pk-X", result.get(0).partitionKey()); + assertEquals(0, result.get(0).subSequenceNumber()); + assertArrayEquals("first".getBytes(StandardCharsets.UTF_8), result.get(0).data()); + + assertEquals("pk-Y", result.get(1).partitionKey()); + assertEquals(1, result.get(1).subSequenceNumber()); + assertArrayEquals("second".getBytes(StandardCharsets.UTF_8), result.get(1).data()); + + assertEquals("pk-X", result.get(2).partitionKey()); + assertEquals(2, result.get(2).subSequenceNumber()); + assertArrayEquals("third".getBytes(StandardCharsets.UTF_8), result.get(2).data()); + + for (final UserRecord dr : result) { + assertEquals("seq-200", dr.sequenceNumber()); + assertEquals(ARRIVAL, dr.approximateArrivalTimestamp()); + } + } + + @Test + void testMixedAggregatedAndNonAggregated() throws Exception { + final byte[] plainPayload = "plain-data".getBytes(StandardCharsets.UTF_8); + final Record plainRecord = buildKinesisRecord("seq-001", "pk-plain", plainPayload); + + final byte[] aggregated 
= buildAggregatedPayload( + List.of("pk-agg"), + List.of(new SubRecord(0, "agg-data".getBytes(StandardCharsets.UTF_8)))); + final Record aggRecord = buildKinesisRecord("seq-002", "pk-outer", aggregated); + + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + + assertEquals(2, result.size()); + assertEquals("seq-001", result.get(0).sequenceNumber()); + assertArrayEquals(plainPayload, result.get(0).data()); + assertEquals("seq-002", result.get(1).sequenceNumber()); + assertArrayEquals("agg-data".getBytes(StandardCharsets.UTF_8), result.get(1).data()); + } + + @Test + void testCorruptedProtobufFallsBackToPassthrough() { + final byte[] corrupted = new byte[ProducerLibraryDeaggregator.KPL_MAGIC.length + 20 + 16]; + System.arraycopy(ProducerLibraryDeaggregator.KPL_MAGIC, 0, corrupted, 0, ProducerLibraryDeaggregator.KPL_MAGIC.length); + final byte[] protobufPart = new byte[20]; + protobufPart[0] = (byte) 0xFF; + System.arraycopy(protobufPart, 0, corrupted, ProducerLibraryDeaggregator.KPL_MAGIC.length, 20); + try { + final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufPart); + System.arraycopy(md5, 0, corrupted, ProducerLibraryDeaggregator.KPL_MAGIC.length + 20, 16); + } catch (final Exception e) { + throw new RuntimeException(e); + } + + final Record record = buildKinesisRecord("seq-bad", "pk-bad", corrupted); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + assertEquals("seq-bad", result.get(0).sequenceNumber()); + assertEquals(0, result.get(0).subSequenceNumber()); + assertArrayEquals(corrupted, result.get(0).data()); + } + + @Test + void testMd5MismatchFallsBackToPassthrough() throws Exception { + final byte[] aggregated = buildAggregatedPayload( + List.of("pk-1"), + List.of(new SubRecord(0, "data".getBytes(StandardCharsets.UTF_8)))); + + aggregated[aggregated.length - 1] ^= 0xFF; + + final Record record = 
buildKinesisRecord("seq-md5", "pk-md5", aggregated); + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(record)); + + assertEquals(1, result.size()); + assertEquals(0, result.get(0).subSequenceNumber()); + assertArrayEquals(aggregated, result.get(0).data()); + } + + @Test + void testIsAggregatedDetection() { + assertFalse(ProducerLibraryDeaggregator.isAggregated(new byte[0])); + assertFalse(ProducerLibraryDeaggregator.isAggregated(new byte[]{0x01, 0x02})); + assertFalse(ProducerLibraryDeaggregator.isAggregated("regular data".getBytes(StandardCharsets.UTF_8))); + + final byte[] withMagic = new byte[ProducerLibraryDeaggregator.KPL_MAGIC.length + 16 + 1]; + System.arraycopy(ProducerLibraryDeaggregator.KPL_MAGIC, 0, withMagic, 0, ProducerLibraryDeaggregator.KPL_MAGIC.length); + assertTrue(ProducerLibraryDeaggregator.isAggregated(withMagic)); + } + + @Test + void testEmptyRecordList() { + final List result = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of()); + assertTrue(result.isEmpty()); + } + + // ---- KCL cross-validation tests ---- + // These tests use the KCL's own protobuf Messages class to create aggregated records, + // then verify that our ProducerLibraryDeaggregator produces the same results as the KCL's AggregatorUtil. 
+ + @Test + void testKclAggregatedSingleRecord() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-kcl-1") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("hello from KPL"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-kcl-1", "outer-pk", payload); + + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(1, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + assertDeaggregatedMatchesKcl(ourResult.getFirst(), kclResult.getFirst()); + assertEquals("pk-kcl-1", ourResult.getFirst().partitionKey()); + assertArrayEquals("hello from KPL".getBytes(StandardCharsets.UTF_8), ourResult.getFirst().data()); + } + + @Test + void testKclAggregatedMultipleRecords() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-alpha") + .addPartitionKeyTable("pk-beta") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-0"))) + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(1) + .setData(ByteString.copyFromUtf8("record-1"))) + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-2"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-kcl-multi", "outer-pk", payload); + + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(3, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + for (int i = 0; i < ourResult.size(); 
i++) { + assertDeaggregatedMatchesKcl(ourResult.get(i), kclResult.get(i)); + } + + assertEquals("pk-alpha", ourResult.get(0).partitionKey()); + assertEquals(0, ourResult.get(0).subSequenceNumber()); + assertArrayEquals("record-0".getBytes(StandardCharsets.UTF_8), ourResult.get(0).data()); + + assertEquals("pk-beta", ourResult.get(1).partitionKey()); + assertEquals(1, ourResult.get(1).subSequenceNumber()); + assertArrayEquals("record-1".getBytes(StandardCharsets.UTF_8), ourResult.get(1).data()); + + assertEquals("pk-alpha", ourResult.get(2).partitionKey()); + assertEquals(2, ourResult.get(2).subSequenceNumber()); + assertArrayEquals("record-2".getBytes(StandardCharsets.UTF_8), ourResult.get(2).data()); + } + + @Test + void testKclAggregatedMixedWithPlainRecords() { + final Record plainRecord = buildKinesisRecord("seq-plain", "pk-plain", + "plain-data".getBytes(StandardCharsets.UTF_8)); + + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-inner") + .addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("agg-data"))) + .build(); + final Record aggRecord = buildKinesisRecord("seq-agg", "outer-pk", + wrapAsKplPayload(aggProto.toByteArray())); + + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(plainRecord, aggRecord)); + + final List kclPlain = deaggregateViaKcl(plainRecord); + final List kclAgg = deaggregateViaKcl(aggRecord); + + assertEquals(2, ourResult.size()); + assertEquals(1, kclPlain.size()); + assertEquals(1, kclAgg.size()); + + assertDeaggregatedMatchesKcl(ourResult.get(0), kclPlain.getFirst()); + assertDeaggregatedMatchesKcl(ourResult.get(1), kclAgg.getFirst()); + } + + @Test + void testKclAggregatedWithExplicitHashKeys() { + final Messages.AggregatedRecord aggProto = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-0") + .addExplicitHashKeyTable("12345678901234567890") + 
.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setExplicitHashKeyIndex(0) + .setData(ByteString.copyFromUtf8("with-ehk"))) + .build(); + final byte[] payload = wrapAsKplPayload(aggProto.toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-ehk", "outer-pk", payload); + + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(1, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + assertDeaggregatedMatchesKcl(ourResult.getFirst(), kclResult.getFirst()); + assertArrayEquals("with-ehk".getBytes(StandardCharsets.UTF_8), ourResult.getFirst().data()); + } + + @Test + void testKclAggregatedLargeBatch() { + final Messages.AggregatedRecord.Builder builder = Messages.AggregatedRecord.newBuilder() + .addPartitionKeyTable("pk-batch"); + for (int i = 0; i < 100; i++) { + builder.addRecords(Messages.Record.newBuilder() + .setPartitionKeyIndex(0) + .setData(ByteString.copyFromUtf8("record-" + i))); + } + final byte[] payload = wrapAsKplPayload(builder.build().toByteArray()); + final Record kinesisRecord = buildKinesisRecord("seq-batch", "outer-pk", payload); + + final List ourResult = ProducerLibraryDeaggregator.deaggregate(TEST_SHARD_ID, List.of(kinesisRecord)); + final List kclResult = deaggregateViaKcl(kinesisRecord); + + assertEquals(100, ourResult.size()); + assertEquals(kclResult.size(), ourResult.size()); + + for (int i = 0; i < ourResult.size(); i++) { + assertDeaggregatedMatchesKcl(ourResult.get(i), kclResult.get(i)); + assertEquals(i, ourResult.get(i).subSequenceNumber()); + assertArrayEquals(("record-" + i).getBytes(StandardCharsets.UTF_8), ourResult.get(i).data()); + } + } + + // ---- Test helpers ---- + + private record SubRecord(int partitionKeyIndex, byte[] data) { } + + private static Record buildKinesisRecord(final String sequenceNumber, final String partitionKey, final byte[] 
data) { + return Record.builder() + .sequenceNumber(sequenceNumber) + .partitionKey(partitionKey) + .data(SdkBytes.fromByteArray(data)) + .approximateArrivalTimestamp(ARRIVAL) + .build(); + } + + /** + * Wraps raw protobuf bytes in the KPL envelope format (magic + protobuf + MD5). + * + * @param protobufBytes serialized protobuf content + * @return complete KPL aggregated payload + */ + private static byte[] wrapAsKplPayload(final byte[] protobufBytes) { + try { + final byte[] md5 = MessageDigest.getInstance("MD5").digest(protobufBytes); + final ByteArrayOutputStream result = new ByteArrayOutputStream(); + result.write(ProducerLibraryDeaggregator.KPL_MAGIC); + result.write(protobufBytes); + result.write(md5); + return result.toByteArray(); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Deaggregates a single SDK v2 Record using the KCL's own {@link AggregatorUtil}. + * + * @param record the SDK v2 Kinesis record + * @return deaggregated KCL records + */ + private static List deaggregateViaKcl(final Record record) { + final KinesisClientRecord kcr = KinesisClientRecord.fromRecord(record); + return new AggregatorUtil().deaggregate(List.of(kcr)); + } + + private static void assertDeaggregatedMatchesKcl(final UserRecord ours, final KinesisClientRecord kcl) { + assertEquals(kcl.sequenceNumber(), ours.sequenceNumber(), "sequence number mismatch"); + assertEquals(kcl.subSequenceNumber(), ours.subSequenceNumber(), "sub-sequence number mismatch"); + assertEquals(kcl.partitionKey(), ours.partitionKey(), "partition key mismatch"); + assertArrayEquals(toBytes(kcl.data()), ours.data(), "data mismatch"); + } + + private static byte[] toBytes(final ByteBuffer buffer) { + final ByteBuffer dup = buffer.duplicate(); + final byte[] bytes = new byte[dup.remaining()]; + dup.get(bytes); + return bytes; + } + + private static byte[] buildAggregatedPayload(final List partitionKeys, final List subRecords) + throws Exception { + final 
ByteArrayOutputStream protobufBuffer = new ByteArrayOutputStream(); + final CodedOutputStream cos = CodedOutputStream.newInstance(protobufBuffer); + + for (final String pk : partitionKeys) { + cos.writeString(1, pk); + } + + for (final SubRecord sub : subRecords) { + final int innerSize = computeInnerRecordSize(sub); + cos.writeTag(3, 2); + cos.writeUInt32NoTag(innerSize); + cos.writeUInt64(1, sub.partitionKeyIndex); + cos.writeByteArray(3, sub.data); + } + + cos.flush(); + final byte[] protobufBytes = protobufBuffer.toByteArray(); + return wrapAsKplPayload(protobufBytes); + } + + private static int computeInnerRecordSize(final SubRecord sub) { + int size = 0; + size += CodedOutputStream.computeUInt64Size(1, sub.partitionKeyIndex); + size += CodedOutputStream.computeByteArraySize(3, sub.data); + return size; + } +} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormatTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormatTest.java deleted file mode 100644 index 71c9ccd36fa9..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ProvenanceTransitUriFormatTest.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -class ProvenanceTransitUriFormatTest { - - @ParameterizedTest - @CsvSource(""" - streamA,shardId123,kinesis:stream/streamA/shardId123 - stream-name-b,shardId-000000013,kinesis:stream/stream-name-b/shardId-000000013 - kinesis,shardId-00000001,kinesis:stream/kinesis/shardId-00000001 - """) - void toTransitUri(final String streamName, final String shardId, final String expectedTransitUri) { - final String transitUri = ProvenanceTransitUriFormat.toTransitUri(streamName, shardId); - assertEquals(expectedTransitUri, transitUri); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java deleted file mode 100644 index c4023148aa50..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/ReaderRecordProcessorTest.java +++ /dev/null @@ -1,397 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis; - -import org.apache.nifi.flowfile.FlowFile; -import org.apache.nifi.json.JsonRecordSetWriter; -import org.apache.nifi.json.JsonTreeReader; -import org.apache.nifi.logging.ComponentLog; -import org.apache.nifi.processors.aws.kinesis.ReaderRecordProcessor.ProcessingResult; -import org.apache.nifi.processors.aws.kinesis.converter.ValueRecordConverter; -import org.apache.nifi.reporting.InitializationException; -import org.apache.nifi.schema.access.SchemaAccessUtils; -import org.apache.nifi.schema.inference.SchemaInferenceUtil; -import org.apache.nifi.serialization.MalformedRecordException; -import org.apache.nifi.serialization.RecordReader; -import org.apache.nifi.serialization.RecordReaderFactory; -import org.apache.nifi.serialization.record.MockRecordParser; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordSchema; -import org.apache.nifi.util.MockFlowFile; -import org.apache.nifi.util.MockProcessSession; -import org.apache.nifi.util.SharedSessionState; -import org.apache.nifi.util.TestRunner; -import org.apache.nifi.util.TestRunners; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import 
java.time.Instant; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; - -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.APPROXIMATE_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.FIRST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.LAST_SUB_SEQUENCE_NUMBER; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.MIME_TYPE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.PARTITION_KEY; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_COUNT; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.RECORD_ERROR_MESSAGE; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.ConsumeKinesisAttributes.STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.JsonRecordAssert.assertFlowFileRecords; -import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -class ReaderRecordProcessorTest { - - private static final String TEST_STREAM_NAME = "stream-test"; - private static final String TEST_SHARD_ID = "shardId-test"; - - private static final String USER_JSON_1 = "{\"name\":\"John\",\"age\":30}"; - private static final String USER_JSON_2 = "{\"name\":\"Jane\",\"age\":25}"; - private static final String USER_JSON_3 = 
"{\"name\":\"Bob\",\"age\":35}"; - - private static final String CITY_JSON_1 = "{\"name\":\"Seattle\",\"country\":\"US\"}"; - private static final String CITY_JSON_2 = "{\"name\":\"Warsaw\",\"country\":\"PL\"}"; - - private static final String INVALID_JSON = "{invalid json}"; - - private MockProcessSession session; - private ComponentLog logger; - - private JsonRecordSetWriter jsonWriter; - private ReaderRecordProcessor processor; - - @BeforeEach - void setUp() throws InitializationException { - final TestRunner runner = TestRunners.newTestRunner(ConsumeKinesis.class); - final SharedSessionState sharedState = new SharedSessionState(runner.getProcessor(), new AtomicLong(0)); - session = new MockProcessSession(sharedState, runner.getProcessor()); - logger = runner.getLogger(); - - final JsonTreeReader jsonReader = new JsonTreeReader(); - runner.addControllerService("json-reader", jsonReader); - runner.setProperty(jsonReader, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaInferenceUtil.INFER_SCHEMA.getValue()); - runner.enableControllerService(jsonReader); - - jsonWriter = new JsonRecordSetWriter(); - runner.addControllerService("json-writer", jsonWriter); - runner.setProperty(jsonWriter, SchemaAccessUtils.SCHEMA_ACCESS_STRATEGY, SchemaAccessUtils.INHERIT_RECORD_SCHEMA.getValue()); - runner.enableControllerService(jsonWriter); - - processor = new ReaderRecordProcessor(jsonReader, new ValueRecordConverter(), jsonWriter, logger); - } - - @Test - void testProcessSingleRecord() { - final KinesisClientRecord record = KinesisClientRecord.builder() - .data(ByteBuffer.wrap(USER_JSON_1.getBytes(UTF_8))) - .sequenceNumber("1") - .subSequenceNumber(2) - .approximateArrivalTimestamp(Instant.now()) - .partitionKey("key-123") - .build(); - final List records = List.of(record); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(0, 
result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - - assertEquals(TEST_STREAM_NAME, successFlowFile.getAttribute(STREAM_NAME)); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - - assertEquals(record.sequenceNumber(), successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(record.subSequenceNumber()), successFlowFile.getAttribute(FIRST_SUB_SEQUENCE_NUMBER)); - assertEquals(record.sequenceNumber(), successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(record.subSequenceNumber()), successFlowFile.getAttribute(LAST_SUB_SEQUENCE_NUMBER)); - - assertEquals(record.partitionKey(), successFlowFile.getAttribute(PARTITION_KEY)); - assertEquals(String.valueOf(record.approximateArrivalTimestamp().toEpochMilli()), successFlowFile.getAttribute(APPROXIMATE_ARRIVAL_TIMESTAMP)); - - assertEquals("application/json", successFlowFile.getAttribute(MIME_TYPE)); - assertEquals("1", successFlowFile.getAttribute(RECORD_COUNT)); - - assertFlowFileRecords(successFlowFile, records); - } - - @Test - void testProcessMultipleRecordsWithSameSchema() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(USER_JSON_2, "2"), - createKinesisRecord(USER_JSON_3, "3") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - assertEquals(TEST_STREAM_NAME, successFlowFile.getAttribute(STREAM_NAME)); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - assertEquals("3", successFlowFile.getAttribute(RECORD_COUNT)); - - assertEquals(records.getFirst().sequenceNumber(), successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - 
assertEquals(String.valueOf(records.getFirst().subSequenceNumber()), successFlowFile.getAttribute(FIRST_SUB_SEQUENCE_NUMBER)); - assertEquals(records.getLast().sequenceNumber(), successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertEquals(String.valueOf(records.getLast().subSequenceNumber()), successFlowFile.getAttribute(LAST_SUB_SEQUENCE_NUMBER)); - - assertFlowFileRecords(successFlowFile, records); - } - - @Test - void testEmptyRecordsList() { - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, Collections.emptyList()); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(0, result.parseFailureFlowFiles().size()); - } - - @Test - void testSchemaChangeCreatesNewFlowFile() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(CITY_JSON_1, "2") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(2, result.successFlowFiles().size()); // Two different schemas = two FlowFiles - assertEquals(0, result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("1", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(firstFlowFile, records.getFirst()); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("1", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(secondFlowFile, records.get(1)); - } - - @Test - void testSchemaChangeWithMultipleRecordsInBetween() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(USER_JSON_2, "2"), - createKinesisRecord(CITY_JSON_1, "3"), - createKinesisRecord(CITY_JSON_2, "4") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(2, result.successFlowFiles().size()); - assertEquals(0, 
result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(firstFlowFile, records.subList(0, 2)); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("2", secondFlowFile.getAttribute(RECORD_COUNT)); - assertFlowFileRecords(secondFlowFile, records.subList(2, 4)); - } - - @Test - void testSingleMalformedRecord() { - final List records = List.of( - createKinesisRecord(INVALID_JSON, "1") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertEquals(TEST_SHARD_ID, failureFlowFile.getAttribute(SHARD_ID)); - assertEquals("1", failureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("1", failureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - - failureFlowFile.assertContentEquals(INVALID_JSON, UTF_8); - } - - @Test - void testMalformedRecordBetweenValid() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), - createKinesisRecord(INVALID_JSON, "2"), - createKinesisRecord(USER_JSON_2, "3"), - createKinesisRecord(INVALID_JSON, "4"), - createKinesisRecord(USER_JSON_3, "5") - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(1, result.successFlowFiles().size()); - assertEquals(2, result.parseFailureFlowFiles().size()); - - final FlowFile successFlowFile = result.successFlowFiles().getFirst(); - assertEquals(TEST_SHARD_ID, successFlowFile.getAttribute(SHARD_ID)); - assertEquals("3", successFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("1", 
successFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("5", successFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(successFlowFile, records.get(0), records.get(2), records.get(4)); - - assertAll(result.parseFailureFlowFiles().stream().map( - failureFlowFile -> () -> { - assertNotNull(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, failureFlowFile.getAttribute(SHARD_ID)); - } - )); - } - - @Test - void testIOExceptionDuringReaderCreation() { - final RecordReaderFactory failingReaderFactory = new MockRecordParser() { - @Override - public RecordReader createRecordReader(Map variables, InputStream in, long inputLength, ComponentLog logger) throws IOException { - throw new IOException("Failed to create reader"); - } - }; - - final ReaderRecordProcessor processor = new ReaderRecordProcessor(failingReaderFactory, new ValueRecordConverter(), jsonWriter, logger); - - final KinesisClientRecord record = createKinesisRecord(USER_JSON_1, "1"); - final List records = List.of(record); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertTrue(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE).contains("Failed to create reader")); - failureFlowFile.assertContentEquals(KinesisRecordPayload.extract(record), UTF_8); - } - - @Test - void testMalformedRecordExceptionDuringReading() { - final ReaderRecordProcessor processor = new ReaderRecordProcessor(getMalformedRecordExceptionReader(), new ValueRecordConverter(), jsonWriter, logger); - - final KinesisClientRecord record = createKinesisRecord(USER_JSON_1, "1"); - final List records = Collections.singletonList(record); - - final ProcessingResult result = processor.processRecords(session, 
TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(0, result.successFlowFiles().size()); - assertEquals(1, result.parseFailureFlowFiles().size()); - - final MockFlowFile failureFlowFile = (MockFlowFile) result.parseFailureFlowFiles().getFirst(); - assertTrue(failureFlowFile.getAttribute(RECORD_ERROR_MESSAGE).contains("Test exception")); - failureFlowFile.assertContentEquals(KinesisRecordPayload.extract(record), UTF_8); - } - - @Test - void testInvalidRecordsWithSchemaEvolution() { - final List records = List.of( - createKinesisRecord(USER_JSON_1, "1"), // Schema A - createKinesisRecord(USER_JSON_2, "2"), // Schema A - createKinesisRecord(INVALID_JSON, "3"), - createKinesisRecord(CITY_JSON_1, "4"), // Schema B - createKinesisRecord(INVALID_JSON, "5"), - createKinesisRecord(CITY_JSON_2, "6"), // Schema B - createKinesisRecord("{\"category\":\"electronics\",\"price\":99.99}", "7") // Schema C - ); - - final ProcessingResult result = processor.processRecords(session, TEST_STREAM_NAME, TEST_SHARD_ID, records); - - assertEquals(3, result.successFlowFiles().size()); - assertEquals(2, result.parseFailureFlowFiles().size()); - - final FlowFile firstFlowFile = result.successFlowFiles().getFirst(); - assertEquals("2", firstFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("1", firstFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("2", firstFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(firstFlowFile, records.subList(0, 2)); - - final FlowFile secondFlowFile = result.successFlowFiles().get(1); - assertEquals("2", secondFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("4", secondFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("6", secondFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(secondFlowFile, records.get(3), records.get(5)); - - final FlowFile thirdFlowFile = result.successFlowFiles().get(2); - assertEquals("1", thirdFlowFile.getAttribute(RECORD_COUNT)); - assertEquals("7", 
thirdFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("7", thirdFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertFlowFileRecords(thirdFlowFile, records.get(6)); - - final List failureFlowFiles = result.parseFailureFlowFiles(); - - final MockFlowFile firstFailureFlowFile = (MockFlowFile) failureFlowFiles.getFirst(); - assertEquals("3", firstFailureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("3", firstFailureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(firstFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, firstFailureFlowFile.getAttribute(SHARD_ID)); - firstFailureFlowFile.assertContentEquals(KinesisRecordPayload.extract(records.get(2)), UTF_8); - - final MockFlowFile secondFailureFlowFile = (MockFlowFile) failureFlowFiles.get(1); - assertEquals("5", secondFailureFlowFile.getAttribute(FIRST_SEQUENCE_NUMBER)); - assertEquals("5", secondFailureFlowFile.getAttribute(LAST_SEQUENCE_NUMBER)); - assertNotNull(secondFailureFlowFile.getAttribute(RECORD_ERROR_MESSAGE)); - assertEquals(TEST_SHARD_ID, secondFailureFlowFile.getAttribute(SHARD_ID)); - secondFailureFlowFile.assertContentEquals(KinesisRecordPayload.extract(records.get(4)), UTF_8); - } - - private static KinesisClientRecord createKinesisRecord(final String data, final String sequenceNumber) { - return KinesisClientRecord.builder() - .data(ByteBuffer.wrap(data.getBytes(UTF_8))) - .sequenceNumber(sequenceNumber) - .partitionKey("key") - .approximateArrivalTimestamp(Instant.now()) - .build(); - } - - private static RecordReaderFactory getMalformedRecordExceptionReader() { - return new MockRecordParser() { - @Override - public RecordReader createRecordReader(Map variables, InputStream in, long inputLength, ComponentLog logger) { - return new RecordReader() { - @Override - public void close() { - } - - @Override - public Record nextRecord(boolean coerceTypes, boolean dropUnknownFields) throws MalformedRecordException { - throw new 
MalformedRecordException("Test exception"); - } - - @Override - public RecordSchema getSchema() { - return null; - } - }; - } - }; - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java deleted file mode 100644 index c4261d3087c9..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/InjectMetadataRecordConverterTest.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_RECORD; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.KINESIS_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.SCHEMA_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.createTestKinesisRecord; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.verifyMetadata; -import static org.junit.jupiter.api.Assertions.assertEquals; - -class InjectMetadataRecordConverterTest { - - private static final RecordSchema EXPECTED_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField("name", RecordFieldType.STRING.getDataType()), - new RecordField("age", RecordFieldType.INT.getDataType()), - new RecordField(KINESIS_METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)) - )); - - private static final InjectMetadataRecordConverter CONVERTER = new InjectMetadataRecordConverter(); - - @Test - void 
testConvertWithApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(TEST_ARRIVAL_TIMESTAMP); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.remove(KINESIS_METADATA); - assertEquals(INPUT_RECORD.toMap(), recordValues); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = true; - verifyMetadata(metadata, expectTimestamp); - } - - @Test - void testConvertWithoutApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(null); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - - final Map recordValues = new HashMap<>(record.toMap()); - recordValues.remove(KINESIS_METADATA); - assertEquals(INPUT_RECORD.toMap(), recordValues); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = false; - verifyMetadata(metadata, expectTimestamp); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java deleted file mode 100644 index 12f736656af2..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/KinesisRecordConverterTestUtil.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import jakarta.annotation.Nullable; -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.MapRecord; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.nio.ByteBuffer; -import java.time.Instant; -import java.util.List; -import java.util.Map; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordMetadata.APPROX_ARRIVAL_TIMESTAMP; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; - -public class KinesisRecordConverterTestUtil { - - static final String KINESIS_METADATA = "kinesisMetadata"; - - static final String TEST_STREAM_NAME = "test-stream"; - static final String TEST_SHARD_ID = "shardId-000000000001"; - static final String TEST_SEQUENCE_NUMBER = "49590338271490256608559692538361571095921575989136588801"; - static final long TEST_SUB_SEQUENCE_NUMBER = 2; - static final String TEST_PARTITION_KEY = "test-partition-key"; 
- static final Instant TEST_ARRIVAL_TIMESTAMP = Instant.ofEpochMilli(1640995200000L); - - static final String EXPECTED_SHARDED_SEQUENCE_NUMBER = "4959033827149025660855969253836157109592157598913658880100000000000000000002"; - - static final RecordSchema INPUT_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField("name", RecordFieldType.STRING.getDataType()), - new RecordField("age", RecordFieldType.INT.getDataType()) - )); - - static final Record INPUT_RECORD = new MapRecord(INPUT_SCHEMA, Map.of( - "name", "John Doe", - "age", 30 - )); - - static final RecordSchema SCHEMA_METADATA = new SimpleRecordSchema(List.of( - new RecordField("stream", RecordFieldType.STRING.getDataType()), - new RecordField("shardId", RecordFieldType.STRING.getDataType()), - new RecordField("sequenceNumber", RecordFieldType.STRING.getDataType()), - new RecordField("subSequenceNumber", RecordFieldType.LONG.getDataType()), - new RecordField("shardedSequenceNumber", RecordFieldType.STRING.getDataType()), - new RecordField("partitionKey", RecordFieldType.STRING.getDataType()), - new RecordField(APPROX_ARRIVAL_TIMESTAMP, RecordFieldType.TIMESTAMP.getDataType()) - )); - - private KinesisRecordConverterTestUtil() { - // Utility class - } - - static KinesisClientRecord createTestKinesisRecord(final @Nullable Instant arrivalTimestamp) { - return KinesisClientRecord.builder() - .data(ByteBuffer.allocate(0)) - .sequenceNumber(TEST_SEQUENCE_NUMBER) - .subSequenceNumber(TEST_SUB_SEQUENCE_NUMBER) - .partitionKey(TEST_PARTITION_KEY) - .approximateArrivalTimestamp(arrivalTimestamp) - .build(); - } - - static void verifyMetadata(final Record metadata, final boolean expectTimestamp) { - assertEquals(TEST_STREAM_NAME, metadata.getValue("stream")); - assertEquals(TEST_SHARD_ID, metadata.getValue("shardId")); - assertEquals(TEST_SEQUENCE_NUMBER, metadata.getValue("sequenceNumber")); - assertEquals(TEST_SUB_SEQUENCE_NUMBER, metadata.getValue("subSequenceNumber")); - 
assertEquals(EXPECTED_SHARDED_SEQUENCE_NUMBER, metadata.getValue("shardedSequenceNumber")); - assertEquals(TEST_PARTITION_KEY, metadata.getValue("partitionKey")); - - if (expectTimestamp) { - assertEquals(TEST_ARRIVAL_TIMESTAMP.toEpochMilli(), metadata.getValue(APPROX_ARRIVAL_TIMESTAMP)); - } else { - assertNull(metadata.getValue(APPROX_ARRIVAL_TIMESTAMP)); - } - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java deleted file mode 100644 index b15e1efc5c45..000000000000 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-kinesis/src/test/java/org/apache/nifi/processors/aws/kinesis/converter/WrapperRecordConverterTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.nifi.processors.aws.kinesis.converter; - -import org.apache.nifi.serialization.SimpleRecordSchema; -import org.apache.nifi.serialization.record.Record; -import org.apache.nifi.serialization.record.RecordField; -import org.apache.nifi.serialization.record.RecordFieldType; -import org.apache.nifi.serialization.record.RecordSchema; -import org.junit.jupiter.api.Test; -import software.amazon.kinesis.retrieval.KinesisClientRecord; - -import java.util.List; - -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_RECORD; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.INPUT_SCHEMA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.KINESIS_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.SCHEMA_METADATA; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_ARRIVAL_TIMESTAMP; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_SHARD_ID; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.TEST_STREAM_NAME; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.createTestKinesisRecord; -import static org.apache.nifi.processors.aws.kinesis.converter.KinesisRecordConverterTestUtil.verifyMetadata; -import static org.junit.jupiter.api.Assertions.assertEquals; - -class WrapperRecordConverterTest { - - private static final RecordSchema EXPECTED_SCHEMA = new SimpleRecordSchema(List.of( - new RecordField(KINESIS_METADATA, RecordFieldType.RECORD.getRecordDataType(SCHEMA_METADATA)), - new RecordField("value", RecordFieldType.RECORD.getRecordDataType(INPUT_SCHEMA)) - )); - - private static final WrapperRecordConverter CONVERTER = new WrapperRecordConverter(); - - @Test - void 
testConvertWithApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(TEST_ARRIVAL_TIMESTAMP); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - assertEquals(INPUT_RECORD, record.getValue("value")); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = true; - verifyMetadata(metadata, expectTimestamp); - } - - @Test - void testConvertWithoutApproximateArrivalTimestamp() { - final KinesisClientRecord kinesisRecord = createTestKinesisRecord(null); - - final Record record = CONVERTER.convert(INPUT_RECORD, kinesisRecord, TEST_STREAM_NAME, TEST_SHARD_ID); - - assertEquals(EXPECTED_SCHEMA, record.getSchema()); - assertEquals(INPUT_RECORD, record.getValue("value")); - - final Record metadata = record.getAsRecord(KINESIS_METADATA, SCHEMA_METADATA); - final boolean expectTimestamp = false; - verifyMetadata(metadata, expectTimestamp); - } -} diff --git a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-service-api-nar/pom.xml b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-service-api-nar/pom.xml index 8c000576d0de..125aea297a7a 100644 --- a/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-service-api-nar/pom.xml +++ b/nifi-extension-bundles/nifi-aws-bundle/nifi-aws-service-api-nar/pom.xml @@ -37,6 +37,13 @@ nifi-aws-service-api compile
+ + + software.amazon.awssdk + apache5-client + ${software.amazon.awssdk.version} + compile + software.amazon.awssdk