diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go index d76108a2e..5848b93c6 100644 --- a/pulsar/consumer_impl.go +++ b/pulsar/consumer_impl.go @@ -25,6 +25,7 @@ import ( "math/rand" "strconv" "sync" + "sync/atomic" "time" "github.com/apache/pulsar-client-go/pulsar/crypto" @@ -49,18 +50,11 @@ type acker interface { } type consumer struct { - sync.Mutex topic string client *client options ConsumerOptions - // When accessing `consumers`, the lock must be acquired in case partitions are being added - // in the background by `internalTopicSubscribeToPartitions`. Currently, when a new sub-consumer - // is created, the current consumer can immediately receive messages from the new partition. However, - // before the new sub-consumers are visible in `consumers`, the Ack related methods cannot find the - // sub-consumer for the message's message ID, so we cannot simply change `consumers` to `atomic.Value` - // and perform copy-on-write when partitions are added. - consumers []*partitionConsumer + consumers atomic.Value consumerName string disableForceTopicCreation bool @@ -356,14 +350,10 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { return err } - oldNumPartitions := 0 newNumPartitions := len(partitions) - c.Lock() - defer c.Unlock() - - oldConsumers := c.consumers - oldNumPartitions = len(oldConsumers) + oldConsumers := c.partitionConsumers() + oldNumPartitions := len(oldConsumers) if oldConsumers != nil { if oldNumPartitions == newNumPartitions { @@ -376,14 +366,14 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { Info("Changed number of partitions in topic") } - c.consumers = make([]*partitionConsumer, newNumPartitions) + newConsumers := make([]*partitionConsumer, newNumPartitions) // When for some reason (eg: forced deletion of sub partition) causes oldNumPartitions> newNumPartitions, // we need to rebuild the cache of new consumers, otherwise the array will be out of bounds. if oldConsumers != nil && oldNumPartitions < newNumPartitions { // Copy over the existing consumer instances for i := 0; i < oldNumPartitions; i++ { - c.consumers[i] = oldConsumers[i] + newConsumers[i] = oldConsumers[i] } } @@ -408,16 +398,16 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { for partitionIdx := startPartition; partitionIdx < newNumPartitions; partitionIdx++ { partitionTopic := partitions[partitionIdx] - go func() { + go func(partitionIdx int, partitionTopic string) { defer wg.Done() opts := newPartitionConsumerOpts(partitionTopic, c.consumerName, partitionIdx, c.options) - cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics) + cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics, false) ch <- ConsumerError{ err: err, partition: partitionIdx, consumer: cons, } - }() + }(partitionIdx, partitionTopic) } go func() { @@ -429,14 +419,14 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { if ce.err != nil { err = ce.err } else { - c.consumers[ce.partition] = ce.consumer + newConsumers[ce.partition] = ce.consumer } } if err != nil { // Since there were some failures, // cleanup all the partitions that succeeded in creating the consumer - for _, c := range c.consumers { + for _, c := range newConsumers { if c != nil { c.Close() } @@ -444,6 +434,10 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { return err } + c.consumers.Store(append([]*partitionConsumer(nil), newConsumers...)) + for partitionIdx := startPartition; partitionIdx < newNumPartitions; partitionIdx++ { + newConsumers[partitionIdx].startDispatcher() + } if newNumPartitions < oldNumPartitions { c.metrics.ConsumersPartitions.Set(float64(newNumPartitions)) } else { @@ -510,11 +504,9 @@ func (c *consumer) UnsubscribeForce() error { } func (c *consumer) unsubscribe(force bool) error { - c.Lock() - defer c.Unlock() - + consumers := c.partitionConsumers() var errMsg string - for _, consumer := range c.consumers { + for _, consumer := range consumers { if err := consumer.unsubscribe(force); err != nil { errMsg += fmt.Sprintf("topic %s, subscription %s: %s", consumer.topic, c.Subscription(), err) } @@ -526,8 +518,9 @@ func (c *consumer) unsubscribe(force bool) error { } func (c *consumer) GetLastMessageIDs() ([]TopicMessageID, error) { + consumers := c.partitionConsumers() ids := make([]TopicMessageID, 0) - for _, pc := range c.consumers { + for _, pc := range consumers { id, err := pc.getLastMessageID() tm := &topicMessageID{topic: pc.topic, track: id} if err != nil { @@ -556,7 +549,7 @@ func (c *consumer) Receive(ctx context.Context) (message Message, err error) { func (c *consumer) AckWithTxn(msg Message, txn Transaction) error { msgID := msg.ID() - consumer, err := c.findPartitionConsumer(msgID) + consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID) if err != nil { return err } @@ -575,7 +568,7 @@ func (c *consumer) Ack(msg Message) error { // AckID the consumption of a single message, identified by its MessageID func (c *consumer) AckID(msgID MessageID) error { - consumer, err := c.findPartitionConsumer(msgID) + consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID) if err != nil { return err } @@ -587,7 +580,7 @@ func (c *consumer) AckID(msgID MessageID) error { func (c *consumer) AckIDList(msgIDs []MessageID) error { return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) { - return c.findPartitionConsumer(msgID) + return findPartitionConsumer(c.partitionConsumers(), msgID) }) } @@ -600,7 +593,7 @@ func (c *consumer) AckCumulative(msg Message) error { // AckIDCumulative the reception of all the messages in the stream up to (and including) // the provided message, identified by its MessageID func (c *consumer) AckIDCumulative(msgID MessageID) error { - consumer, err := c.findPartitionConsumer(msgID) + consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID) if err != nil { return err } @@ -697,7 +690,7 @@ func (c *consumer) Nack(msg Message) { mid.NackByMsg(msg) return } - if consumer, err := c.findPartitionConsumer(mid); err == nil { + if consumer, err := findPartitionConsumer(c.partitionConsumers(), mid); err == nil { consumer.NackMsg(msg) } return @@ -707,7 +700,7 @@ func (c *consumer) Nack(msg Message) { } func (c *consumer) NackID(msgID MessageID) { - if consumer, err := c.findPartitionConsumer(msgID); err == nil { + if consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID); err == nil { consumer.NackID(msgID) } } @@ -724,16 +717,14 @@ func (c *consumer) closeWithCause(err error) { c.closeOnce.Do(func() { c.stopDiscovery() - c.Lock() - defer c.Unlock() - var wg sync.WaitGroup - for i := range c.consumers { + consumers := c.partitionConsumers() + for i := range consumers { wg.Add(1) go func(pc *partitionConsumer) { defer wg.Done() pc.Close() - }(c.consumers[i]) + }(consumers[i]) } wg.Wait() close(c.closeCh) @@ -741,20 +732,19 @@ func (c *consumer) closeWithCause(err error) { c.dlq.close() c.rlq.close() c.metrics.ConsumersClosed.Inc() - c.metrics.ConsumersPartitions.Sub(float64(len(c.consumers))) + c.metrics.ConsumersPartitions.Sub(float64(len(consumers))) c.options.Interceptors.OnConsumerClose(c, err) }) } func (c *consumer) Seek(msgID MessageID) error { - c.Lock() - defer c.Unlock() + consumers := c.partitionConsumers() - if len(c.consumers) > 1 { + if len(consumers) > 1 { return newError(SeekFailed, "for partition topic, seek command should perform on the individual partitions") } - consumer, err := c.unsafeFindPartitionConsumer(msgID) + consumer, err := findPartitionConsumer(consumers, msgID) if err != nil { return err } @@ -768,11 +758,10 @@ func (c *consumer) Seek(msgID MessageID) error { } func (c *consumer) SeekByTime(time time.Time) error { - c.Lock() - defer c.Unlock() var errs error + consumers := c.partitionConsumers() - for _, cons := range c.consumers { + for _, cons := range consumers { cons.pauseDispatchMessage() } // clear messageCh @@ -781,7 +770,7 @@ func (c *consumer) SeekByTime(time time.Time) error { } // run SeekByTime on every partition of topic - for _, cons := range c.consumers { + for _, cons := range consumers { if err := cons.SeekByTime(time); err != nil { msg := fmt.Sprintf("unable to SeekByTime for topic=%s subscription=%s", c.topic, c.Subscription()) errs = pkgerrors.Wrap(newError(SeekFailed, err.Error()), msg) @@ -791,35 +780,30 @@ func (c *consumer) SeekByTime(time time.Time) error { return errs } -func (c *consumer) findPartitionConsumer(msgID MessageID) (*partitionConsumer, error) { - c.Lock() - defer c.Unlock() - return c.unsafeFindPartitionConsumer(msgID) -} - -// NOTE: This method must be called when c.Lock is held -func (c *consumer) unsafeFindPartitionConsumer(msgID MessageID) (*partitionConsumer, error) { +func findPartitionConsumer(consumers []*partitionConsumer, msgID MessageID) (*partitionConsumer, error) { partition := int(msgID.PartitionIdx()) - if partition < 0 || partition >= len(c.consumers) { - c.log.Errorf("invalid partition index %d expected a partition between [0-%d]", - partition, len(c.consumers)) + if partition < 0 || partition >= len(consumers) { return nil, fmt.Errorf("invalid partition index %d expected a partition between [0-%d]", - partition, len(c.consumers)) + partition, len(consumers)-1) + } + return consumers[partition], nil +} + +func (c *consumer) partitionConsumers() []*partitionConsumer { + v := c.consumers.Load() + if v == nil { + return nil } - return c.consumers[partition], nil + // The slice stored in c.consumers is published via copy-on-write. + // Callers must treat the returned slice as immutable. + return v.([]*partitionConsumer) } func (c *consumer) hasNext() bool { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Make sure all paths cancel the context to avoid context leak - // We have to make a snapshot consumers, because we have to iterate over all consumers in - // other goroutines. But when this method returns, there might be still other consumers - // not completing the `hasNext` call, so we cannot just call defer `c.Unlock()` after acquiring the lock. - c.Lock() - consumers := make([]*partitionConsumer, len(c.consumers)) - copy(consumers, c.consumers) - c.Unlock() + consumers := c.partitionConsumers() var wg sync.WaitGroup wg.Add(len(consumers)) @@ -853,7 +837,7 @@ func (c *consumer) hasNext() bool { } func (c *consumer) setLastDequeuedMsg(msgID MessageID) error { - consumer, err := c.findPartitionConsumer(msgID) + consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID) if err != nil { return err } @@ -920,7 +904,7 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) pb.CommandSubscribe_I } func (c *consumer) messageID(msgID MessageID) *trackingMessageID { - if _, err := c.findPartitionConsumer(msgID); err != nil { + if _, err := findPartitionConsumer(c.partitionConsumers(), msgID); err != nil { return nil } return toTrackingMessageID(msgID) diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go index d7aba5244..08e1b36f2 100644 --- a/pulsar/consumer_partition.go +++ b/pulsar/consumer_partition.go @@ -401,8 +401,8 @@ func (s *schemaInfoCache) add(schemaVersionHash string, schema Schema) { } func newPartitionConsumer(parent Consumer, client *client, options *partitionConsumerOpts, - messageCh chan ConsumerMessage, dlq *dlqRouter, - metrics *internal.LeveledMetrics) (*partitionConsumer, error) { + messageCh chan ConsumerMessage, dlq *dlqRouter, metrics *internal.LeveledMetrics, + startDispatcher bool) (*partitionConsumer, error) { var boFunc func() backoff.Policy if options.backOffPolicyFunc != nil { boFunc = options.backOffPolicyFunc @@ -425,7 +425,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon queueCh: make(chan []*message, options.receiverQueueSize), startMessageID: atomicMessageID{msgID: options.startMessageID}, seekMessageID: atomicMessageID{msgID: nil}, - connectedCh: make(chan struct{}), + connectedCh: make(chan struct{}, 1), messageCh: messageCh, connectClosedCh: make(chan *connectionClosed, 1), closeCh: make(chan struct{}), @@ -512,13 +512,18 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon } } - go pc.dispatcher() - go pc.runEventsLoop() + if startDispatcher { + pc.startDispatcher() + } return pc, nil } +func (pc *partitionConsumer) startDispatcher() { + go pc.dispatcher() +} + func (pc *partitionConsumer) unsubscribe(force bool) error { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to unsubscribe closing or closed consumer") diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go index e19642871..c3fc82d4d 100644 --- a/pulsar/consumer_test.go +++ b/pulsar/consumer_test.go @@ -22,7 +22,9 @@ import ( "errors" "fmt" "log" + "log/slog" "net/http" + "net/url" "os" "regexp" "strconv" @@ -4723,7 +4725,7 @@ func TestConsumerWithBackoffPolicy(t *testing.T) { assert.Nil(t, err) defer _consumer.Close() - partitionConsumerImp := _consumer.(*consumer).consumers[0] + partitionConsumerImp := _consumer.(*consumer).partitionConsumers()[0] // 1 s startTime := time.Now() partitionConsumerImp.reconnectToBroker(nil) @@ -4946,7 +4948,7 @@ func TestConsumerWithAutoScaledQueueReceive(t *testing.T) { EnableAutoScaledReceiverQueueSize: true, }) assert.Nil(t, err) - pc := c.(*consumer).consumers[0] + pc := c.(*consumer).partitionConsumers()[0] assert.Equal(t, int32(1), pc.currentQueueSize.Load()) defer c.Close() @@ -5161,7 +5163,7 @@ func TestConsumerMemoryLimit(t *testing.T) { }) assert.Nil(t, err) defer c1.Close() - pc1 := c1.(*consumer).consumers[0] + pc1 := c1.(*consumer).partitionConsumers()[0] // Fill up the messageCh of c1 for i := 0; i < 10; i++ { @@ -5201,7 +5203,7 @@ func TestConsumerMemoryLimit(t *testing.T) { }) assert.Nil(t, err) defer c2.Close() - pc2 := c2.(*consumer).consumers[0] + pc2 := c2.(*consumer).partitionConsumers()[0] // Try to induce c2 receiver queue size expansion for i := 0; i < 10; i++ { @@ -5273,7 +5275,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) { }) assert.Nil(t, err) defer c1.Close() - pc1 := c1.(*consumer).consumers[0] + pc1 := c1.(*consumer).partitionConsumers()[0] // Use mem-limited client 2 to create consumer c1 c2, err := cli2.Subscribe(ConsumerOptions{ @@ -5284,7 +5286,7 @@ func TestMultiConsumerMemoryLimit(t *testing.T) { }) assert.Nil(t, err) defer c2.Close() - pc2 := c2.(*consumer).consumers[0] + pc2 := c2.(*consumer).partitionConsumers()[0] // Fill up the messageCh of c1 nad c2 for i := 0; i < 10; i++ { @@ -5918,7 +5920,7 @@ func TestSelectConnectionForSameConsumer(t *testing.T) { assert.NoError(t, err) defer _consumer.Close() - partitionConsumerImpl := _consumer.(*consumer).consumers[0] + partitionConsumerImpl := _consumer.(*consumer).partitionConsumers()[0] conn := partitionConsumerImpl._getConn() for i := 0; i < 5; i++ { @@ -5928,6 +5930,382 @@ func TestSelectConnectionForSameConsumer(t *testing.T) { } } +func TestInternalTopicSubscribeToPartitionsDoesNotBlockExistingPartitionLookup(t *testing.T) { + lookupURL, err := url.Parse("pulsar://localhost:6650") + require.NoError(t, err) + + allowSubscribe := make(chan struct{}) + subscribeStarted := make(chan struct{}) + var releaseSubscribe sync.Once + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) + log := plog.NewLoggerWithSlog(logger) + + rpcClient := &blockingSubscribeRPCClient{ + lookupResult: &internal.LookupResult{LogicalAddr: lookupURL, PhysicalAddr: lookupURL}, + subscribeStarted: subscribeStarted, + allowSubscribe: allowSubscribe, + subscribeErr: errors.New("stop subscribe after lookup check"), + nextConsumerID: 1, + } + + c := newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{ + conn: dummyConnection{}, + rpcClient: rpcClient, + partitions: 2, + log: log, + consumerOptions: ConsumerOptions{SubscriptionName: "test-sub", NackPrecisionBit: ptr(defaultNackPrecisionBit)}, + initialConsumers: []*partitionConsumer{{topic: "persistent://public/default/test-topic-partition-0"}}, + }) + + go func() { + c.internalTopicSubscribeToPartitions() + }() + + select { + case <-subscribeStarted: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for partition discovery to start subscribing the new partition") + } + + lookupErrCh := make(chan error, 1) + go func() { + _, err := findPartitionConsumer(c.partitionConsumers(), &messageID{partitionIdx: 0}) + lookupErrCh <- err + }() + + select { + case err := <-lookupErrCh: + require.NoError(t, err) + case <-time.After(3 * time.Second): + releaseSubscribe.Do(func() { close(allowSubscribe) }) + select { + case <-lookupErrCh: + case <-time.After(time.Second): + t.Fatal("existing partition lookup stayed blocked even after partition discovery stopped") + } + t.Fatal("existing partition lookup blocked while a new partition was being added") + } + + releaseSubscribe.Do(func() { close(allowSubscribe) }) +} + +func TestInternalTopicSubscribeToPartitionsPublishesConsumersBeforeDispatchingMessages(t *testing.T) { + lookupURL, err := url.Parse("pulsar://localhost:6650") + require.NoError(t, err) + + partitionOneSubscribed := make(chan struct{}) + partitionOneFlowed := make(chan struct{}) + partitionTwoBlocked := make(chan struct{}) + allowPartitionTwo := make(chan struct{}) + cnx := newPartitionExpansionRaceConnection() + rpcClient := &partitionExpansionRaceRPCClient{ + lookupResult: &internal.LookupResult{LogicalAddr: lookupURL, PhysicalAddr: lookupURL}, + cnx: cnx, + partitionOneSubscribed: partitionOneSubscribed, + partitionOneFlowed: partitionOneFlowed, + partitionTwoBlocked: partitionTwoBlocked, + allowPartitionTwo: allowPartitionTwo, + } + + c := newInternalTopicPartitionTestConsumer(internalTopicPartitionTestConsumerOptions{ + conn: cnx, + rpcClient: rpcClient, + partitions: 3, + log: plog.DefaultNopLogger(), + consumerOptions: ConsumerOptions{ + SubscriptionName: "test-sub", + ReceiverQueueSize: 1, + NackPrecisionBit: ptr(defaultNackPrecisionBit), + AckWithResponse: true, + }, + initialConsumers: []*partitionConsumer{{topic: "persistent://public/default/test-topic-partition-0"}}, + dlq: &dlqRouter{}, + }) + + errCh := make(chan error, 1) + go func() { + errCh <- c.internalTopicSubscribeToPartitions() + }() + + select { + case <-partitionOneSubscribed: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for partition 1 to subscribe") + } + + select { + case <-partitionTwoBlocked: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for partition 2 subscribe to block") + } + + require.Len(t, c.partitionConsumers(), 1) + select { + case <-partitionOneFlowed: + t.Fatal("new partition dispatcher requested permits before c.consumers contained the new partition") + case <-time.After(200 * time.Millisecond): + } + + close(allowPartitionTwo) + + select { + case err := <-errCh: + require.NoError(t, err) + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for partition discovery to finish") + } + require.Len(t, c.partitionConsumers(), 3) + + select { + case <-partitionOneFlowed: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for partition 1 dispatcher to request permits") + } + + handler := cnx.handler(rpcClient.partitionOneConsumerID.Load()) + require.NotNil(t, handler) + err = handler.MessageReceived(&pb.CommandMessage{ + MessageId: &pb.MessageIdData{ + LedgerId: proto.Uint64(1), + EntryId: proto.Uint64(1), + }, + }, internal.NewBufferWrapper(rawCompatSingleMessage)) + require.NoError(t, err) + + var cm ConsumerMessage + select { + case cm = <-c.messageCh: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for the queued partition 1 message to dispatch") + } + require.Equal(t, int32(1), cm.Message.ID().PartitionIdx()) + require.NoError(t, c.AckID(cm.Message.ID())) + + for _, pc := range c.partitionConsumers()[1:] { + pc.Close() + } +} + +type internalTopicPartitionTestConsumerOptions struct { + conn internal.Connection + rpcClient internal.RPCClient + partitions int + log plog.Logger + consumerOptions ConsumerOptions + initialConsumers []*partitionConsumer + dlq *dlqRouter +} + +func newInternalTopicPartitionTestConsumer(opts internalTopicPartitionTestConsumerOptions) *consumer { + var consumers atomic.Value + consumers.Store(append([]*partitionConsumer(nil), opts.initialConsumers...)) + + return &consumer{ + topic: "persistent://public/default/test-topic", + client: &client{ + cnxPool: &blockingConnPool{cnx: opts.conn}, + rpcClient: opts.rpcClient, + lookupService: &partitionMetadataLookup{partitions: opts.partitions}, + log: opts.log, + }, + options: opts.consumerOptions, + consumers: consumers, + messageCh: make(chan ConsumerMessage, 1), + closeCh: make(chan struct{}), + errorCh: make(chan error, 1), + consumerName: "test-consumer", + dlq: opts.dlq, + log: opts.log, + metrics: newTestMetrics(), + } +} + +type partitionMetadataLookup struct { + internal.LookupService + partitions int +} + +func (l *partitionMetadataLookup) GetPartitionedTopicMetadata(_ string) (*internal.PartitionedTopicMetadata, error) { + return &internal.PartitionedTopicMetadata{Partitions: l.partitions}, nil +} + +type blockingConnPool struct { + internal.ConnectionPool + cnx internal.Connection +} + +func (p *blockingConnPool) GetConnection(_ *url.URL, _ *url.URL, _ int32) (internal.Connection, error) { + return p.cnx, nil +} + +func (p *blockingConnPool) GetConnections() map[string]internal.Connection { + return map[string]internal.Connection{} +} + +func (p *blockingConnPool) GenerateRoundRobinIndex() int32 { + return 0 +} + +func (p *blockingConnPool) Close() {} + +type partitionExpansionRaceConnection struct { + dummyConnection + mu sync.Mutex + handlers map[uint64]internal.ConsumerHandler +} + +func newPartitionExpansionRaceConnection() *partitionExpansionRaceConnection { + return &partitionExpansionRaceConnection{handlers: make(map[uint64]internal.ConsumerHandler)} +} + +func (c *partitionExpansionRaceConnection) AddConsumeHandler(id uint64, handler internal.ConsumerHandler) error { + c.mu.Lock() + defer c.mu.Unlock() + c.handlers[id] = handler + return nil +} + +func (c *partitionExpansionRaceConnection) DeleteConsumeHandler(id uint64) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.handlers, id) +} + +func (c *partitionExpansionRaceConnection) handler(id uint64) internal.ConsumerHandler { + c.mu.Lock() + defer c.mu.Unlock() + return c.handlers[id] +} + +type partitionExpansionRaceRPCClient struct { + internal.RPCClient + lookupResult *internal.LookupResult + cnx *partitionExpansionRaceConnection + partitionOneSubscribed chan struct{} + partitionOneFlowed chan struct{} + partitionTwoBlocked chan struct{} + allowPartitionTwo chan struct{} + requestID atomic.Uint64 + consumerID atomic.Uint64 + partitionOneConsumerID atomic.Uint64 + partitionOneOnce sync.Once + partitionOneFlowOnce sync.Once + partitionTwoOnce sync.Once +} + +func (r *partitionExpansionRaceRPCClient) NewRequestID() uint64 { + return r.requestID.Add(1) +} + +func (r *partitionExpansionRaceRPCClient) NewProducerID() uint64 { + return r.requestID.Add(1) +} + +func (r *partitionExpansionRaceRPCClient) NewConsumerID() uint64 { + return r.consumerID.Add(1) +} + +func (r *partitionExpansionRaceRPCClient) RequestOnCnxNoWait( + _ internal.Connection, cmdType pb.BaseCommand_Type, msg proto.Message, +) error { + if cmdType == pb.BaseCommand_FLOW { + flow := msg.(*pb.CommandFlow) + if flow.GetConsumerId() == r.partitionOneConsumerID.Load() { + r.partitionOneFlowOnce.Do(func() { close(r.partitionOneFlowed) }) + } + } + return nil +} + +func (r *partitionExpansionRaceRPCClient) RequestOnCnx( + _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, msg proto.Message, +) (*internal.RPCResult, error) { + switch cmdType { + case pb.BaseCommand_SUBSCRIBE: + return r.handleSubscribe(msg.(*pb.CommandSubscribe)) + case pb.BaseCommand_ACK, pb.BaseCommand_CLOSE_CONSUMER: + return r.success(), nil + default: + return nil, fmt.Errorf("unexpected command type %v", cmdType) + } +} + +func (r *partitionExpansionRaceRPCClient) handleSubscribe(cmd *pb.CommandSubscribe) (*internal.RPCResult, error) { + switch { + case strings.HasSuffix(cmd.GetTopic(), "-partition-1"): + r.partitionOneConsumerID.Store(cmd.GetConsumerId()) + r.partitionOneOnce.Do(func() { close(r.partitionOneSubscribed) }) + return r.success(), nil + case strings.HasSuffix(cmd.GetTopic(), "-partition-2"): + r.partitionTwoOnce.Do(func() { close(r.partitionTwoBlocked) }) + <-r.allowPartitionTwo + return r.success(), nil + default: + return nil, fmt.Errorf("unexpected subscribe topic %s", cmd.GetTopic()) + } +} + +func (r *partitionExpansionRaceRPCClient) success() *internal.RPCResult { + successType := pb.BaseCommand_SUCCESS + return &internal.RPCResult{ + Response: &pb.BaseCommand{Type: &successType}, + Cnx: r.cnx, + } +} + +func (r *partitionExpansionRaceRPCClient) LookupService(_ string) (internal.LookupService, error) { + return &grabConnMockLookup{result: r.lookupResult}, nil +} + +type blockingSubscribeRPCClient struct { + internal.RPCClient + lookupResult *internal.LookupResult + subscribeStarted chan struct{} + allowSubscribe chan struct{} + subscribeErr error + nextConsumerID uint64 + startOnce sync.Once +} + +func (r *blockingSubscribeRPCClient) NewRequestID() uint64 { + return 1 +} + +func (r *blockingSubscribeRPCClient) NewProducerID() uint64 { + return 1 +} + +func (r *blockingSubscribeRPCClient) NewConsumerID() uint64 { + id := r.nextConsumerID + r.nextConsumerID++ + return id +} + +func (r *blockingSubscribeRPCClient) RequestOnCnxNoWait( + _ internal.Connection, _ pb.BaseCommand_Type, _ proto.Message) error { + return nil +} + +func (r *blockingSubscribeRPCClient) RequestOnCnx( + _ internal.Connection, _ uint64, cmdType pb.BaseCommand_Type, _ proto.Message, +) (*internal.RPCResult, error) { + switch cmdType { + case pb.BaseCommand_SUBSCRIBE: + r.startOnce.Do(func() { close(r.subscribeStarted) }) + <-r.allowSubscribe + return nil, r.subscribeErr + case pb.BaseCommand_CLOSE_CONSUMER: + return nil, nil + default: + return nil, fmt.Errorf("unexpected command type %v", cmdType) + } +} + +func (r *blockingSubscribeRPCClient) LookupService(_ string) (internal.LookupService, error) { + return &grabConnMockLookup{result: r.lookupResult}, nil +} + // closeInterceptor captures the (consumer, err) pair delivered to // ConsumerCloseInterceptor.OnConsumerClose and signals via fired. type closeInterceptor struct { @@ -6006,7 +6384,7 @@ func TestConsumerOnCloseInterceptorOnMaxReconnect(t *testing.T) { assert.NotNil(t, interceptor.err, "interceptor should receive the cause of the close") assert.Equal(t, testConsumer, interceptor.consumer, "interceptor should receive the parent consumer") - pc := testConsumer.(*consumer).consumers[0] + pc := testConsumer.(*consumer).partitionConsumers()[0] require.Eventually(t, func() bool { return pc.getConsumerState() == consumerClosed }, 30*time.Second, 100*time.Millisecond, "consumer should be closed after exhausting max reconnect retries") diff --git a/pulsar/consumer_zero_queue.go b/pulsar/consumer_zero_queue.go index 4978fae25..20a0944e3 100644 --- a/pulsar/consumer_zero_queue.go +++ b/pulsar/consumer_zero_queue.go @@ -80,7 +80,7 @@ func newZeroConsumer(client *client, options ConsumerOptions, topic string, pc.availablePermits.inc() } } - pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh, zc.dlq, zc.metrics) + pc, err := newPartitionConsumer(zc, zc.client, opts, zc.messageCh, zc.dlq, zc.metrics, true) if err != nil { return nil, err } diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go index 12f0517c4..2d6462e78 100644 --- a/pulsar/message_chunking_test.go +++ b/pulsar/message_chunking_test.go @@ -178,7 +178,7 @@ func TestMaxPendingChunkMessages(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, c) defer c.Close() - pc := c.(*consumer).consumers[0] + pc := c.(*consumer).partitionConsumers()[0] sendSingleChunk(producer, "0", 0, 2) // MaxPendingChunkedMessage is 1, the chunked message with uuid 0 will be discarded @@ -228,7 +228,7 @@ func TestExpireIncompleteChunks(t *testing.T) { defer c.Close() uuid := "test-uuid" - chunkCtxMap := c.(*consumer).consumers[0].chunkedMsgCtxMap + chunkCtxMap := c.(*consumer).partitionConsumers()[0].chunkedMsgCtxMap chunkCtxMap.addIfAbsent(uuid, 2, 100) ctx := chunkCtxMap.get(uuid) assert.NotNil(t, ctx) diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go index 55b05037f..ad65c2855 100644 --- a/pulsar/reader_impl.go +++ b/pulsar/reader_impl.go @@ -224,8 +224,9 @@ func (r *reader) SeekByTime(time time.Time) error { } func (r *reader) GetLastMessageID() (MessageID, error) { - if len(r.c.consumers) > 1 { + consumers := r.c.partitionConsumers() + if len(consumers) > 1 { return nil, fmt.Errorf("GetLastMessageID is not supported for multi-topics reader") } - return r.c.consumers[0].getLastMessageID() + return consumers[0].getLastMessageID() } diff --git a/pulsar/reader_test.go b/pulsar/reader_test.go index 29d61b2bc..d8ebc240d 100644 --- a/pulsar/reader_test.go +++ b/pulsar/reader_test.go @@ -952,7 +952,7 @@ func TestReaderWithBackoffPolicy(t *testing.T) { assert.NotNil(t, _reader) assert.Nil(t, err) - partitionConsumerImp := _reader.(*reader).c.consumers[0] + partitionConsumerImp := _reader.(*reader).c.partitionConsumers()[0] // 1 s startTime := time.Now() partitionConsumerImp.reconnectToBroker(nil) @@ -1061,7 +1061,7 @@ func TestReaderHasNextFailed(t *testing.T) { StartMessageID: EarliestMessageID(), }) assert.Nil(t, err) - r.(*reader).c.consumers[0].state.Store(consumerClosing) + r.(*reader).c.partitionConsumers()[0].state.Store(consumerClosing) assert.False(t, r.HasNext()) } @@ -1082,7 +1082,7 @@ func TestReaderHasNextRetryFailed(t *testing.T) { defer close(c) // Close the consumer events loop and assign a mock eventsCh - pc := r.(*reader).c.consumers[0] + pc := r.(*reader).c.partitionConsumers()[0] pc.Close() pc.state.Store(consumerReady) pc.eventsCh = c