Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 52 additions & 68 deletions pulsar/consumer_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"math/rand"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/apache/pulsar-client-go/pulsar/crypto"
Expand All @@ -49,18 +50,11 @@ type acker interface {
}

type consumer struct {
sync.Mutex
topic string
client *client
options ConsumerOptions

// When accessing `consumers`, the lock must be acquired in case partitions are being added
// in the background by `internalTopicSubscribeToPartitions`. Currently, when a new sub-consumer
// is created, the current consumer can immediately receive messages from the new partition. However,
// before the new sub-consumers are visible in `consumers`, the Ack related methods cannot find the
// sub-consumer for the message's message ID, so we cannot simply change `consumers` to `atomic.Value`
// and perform copy-on-write when partitions are added.
consumers []*partitionConsumer
consumers atomic.Value
consumerName string
disableForceTopicCreation bool

Expand Down Expand Up @@ -356,14 +350,10 @@ func (c *consumer) internalTopicSubscribeToPartitions() error {
return err
}

oldNumPartitions := 0
newNumPartitions := len(partitions)

c.Lock()
defer c.Unlock()

oldConsumers := c.consumers
oldNumPartitions = len(oldConsumers)
oldConsumers := c.partitionConsumers()
oldNumPartitions := len(oldConsumers)

if oldConsumers != nil {
if oldNumPartitions == newNumPartitions {
Expand All @@ -376,14 +366,14 @@ func (c *consumer) internalTopicSubscribeToPartitions() error {
Info("Changed number of partitions in topic")
}

c.consumers = make([]*partitionConsumer, newNumPartitions)
newConsumers := make([]*partitionConsumer, newNumPartitions)

// When oldNumPartitions > newNumPartitions for some reason (e.g. forced deletion of a sub-partition),
// we need to rebuild the cache of new consumers; otherwise indexing the array would go out of bounds.
if oldConsumers != nil && oldNumPartitions < newNumPartitions {
// Copy over the existing consumer instances
for i := 0; i < oldNumPartitions; i++ {
c.consumers[i] = oldConsumers[i]
newConsumers[i] = oldConsumers[i]
}
}

Expand All @@ -408,16 +398,16 @@ func (c *consumer) internalTopicSubscribeToPartitions() error {
for partitionIdx := startPartition; partitionIdx < newNumPartitions; partitionIdx++ {
partitionTopic := partitions[partitionIdx]

go func() {
go func(partitionIdx int, partitionTopic string) {
defer wg.Done()
opts := newPartitionConsumerOpts(partitionTopic, c.consumerName, partitionIdx, c.options)
cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics)
cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics, false)
ch <- ConsumerError{
err: err,
partition: partitionIdx,
consumer: cons,
}
}()
}(partitionIdx, partitionTopic)
}

go func() {
Expand All @@ -429,21 +419,25 @@ func (c *consumer) internalTopicSubscribeToPartitions() error {
if ce.err != nil {
err = ce.err
} else {
c.consumers[ce.partition] = ce.consumer
newConsumers[ce.partition] = ce.consumer
}
}

if err != nil {
// Since there were some failures,
// clean up all the partitions that succeeded in creating the consumer
for _, c := range c.consumers {
for _, c := range newConsumers {
if c != nil {
c.Close()
}
}
return err
}

c.consumers.Store(append([]*partitionConsumer(nil), newConsumers...))
for partitionIdx := startPartition; partitionIdx < newNumPartitions; partitionIdx++ {
newConsumers[partitionIdx].startDispatcher()
}
if newNumPartitions < oldNumPartitions {
c.metrics.ConsumersPartitions.Set(float64(newNumPartitions))
} else {
Expand Down Expand Up @@ -510,11 +504,9 @@ func (c *consumer) UnsubscribeForce() error {
}

func (c *consumer) unsubscribe(force bool) error {
c.Lock()
defer c.Unlock()

consumers := c.partitionConsumers()
var errMsg string
for _, consumer := range c.consumers {
for _, consumer := range consumers {
if err := consumer.unsubscribe(force); err != nil {
errMsg += fmt.Sprintf("topic %s, subscription %s: %s", consumer.topic, c.Subscription(), err)
}
Expand All @@ -526,8 +518,9 @@ func (c *consumer) unsubscribe(force bool) error {
}

func (c *consumer) GetLastMessageIDs() ([]TopicMessageID, error) {
consumers := c.partitionConsumers()
ids := make([]TopicMessageID, 0)
for _, pc := range c.consumers {
for _, pc := range consumers {
id, err := pc.getLastMessageID()
tm := &topicMessageID{topic: pc.topic, track: id}
if err != nil {
Expand Down Expand Up @@ -556,7 +549,7 @@ func (c *consumer) Receive(ctx context.Context) (message Message, err error) {

func (c *consumer) AckWithTxn(msg Message, txn Transaction) error {
msgID := msg.ID()
consumer, err := c.findPartitionConsumer(msgID)
consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
Expand All @@ -575,7 +568,7 @@ func (c *consumer) Ack(msg Message) error {

// AckID acknowledges the consumption of a single message, identified by its MessageID
func (c *consumer) AckID(msgID MessageID) error {
consumer, err := c.findPartitionConsumer(msgID)
consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
Expand All @@ -587,7 +580,7 @@ func (c *consumer) AckID(msgID MessageID) error {

func (c *consumer) AckIDList(msgIDs []MessageID) error {
return ackIDListFromMultiTopics(c.log, msgIDs, func(msgID MessageID) (acker, error) {
return c.findPartitionConsumer(msgID)
return findPartitionConsumer(c.partitionConsumers(), msgID)
})
}

Expand All @@ -600,7 +593,7 @@ func (c *consumer) AckCumulative(msg Message) error {
// AckIDCumulative acknowledges the reception of all the messages in the stream up to (and including)
// the provided message, identified by its MessageID
func (c *consumer) AckIDCumulative(msgID MessageID) error {
consumer, err := c.findPartitionConsumer(msgID)
consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
Expand Down Expand Up @@ -697,7 +690,7 @@ func (c *consumer) Nack(msg Message) {
mid.NackByMsg(msg)
return
}
if consumer, err := c.findPartitionConsumer(mid); err == nil {
if consumer, err := findPartitionConsumer(c.partitionConsumers(), mid); err == nil {
consumer.NackMsg(msg)
}
return
Expand All @@ -707,7 +700,7 @@ func (c *consumer) Nack(msg Message) {
}

func (c *consumer) NackID(msgID MessageID) {
if consumer, err := c.findPartitionConsumer(msgID); err == nil {
if consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID); err == nil {
consumer.NackID(msgID)
}
}
Expand All @@ -724,37 +717,34 @@ func (c *consumer) closeWithCause(err error) {
c.closeOnce.Do(func() {
c.stopDiscovery()

c.Lock()
defer c.Unlock()

var wg sync.WaitGroup
for i := range c.consumers {
consumers := c.partitionConsumers()
for i := range consumers {
wg.Add(1)
go func(pc *partitionConsumer) {
defer wg.Done()
pc.Close()
}(c.consumers[i])
}(consumers[i])
}
wg.Wait()
close(c.closeCh)
c.client.handlers.Del(c)
c.dlq.close()
c.rlq.close()
c.metrics.ConsumersClosed.Inc()
c.metrics.ConsumersPartitions.Sub(float64(len(c.consumers)))
c.metrics.ConsumersPartitions.Sub(float64(len(consumers)))
c.options.Interceptors.OnConsumerClose(c, err)
})
}

func (c *consumer) Seek(msgID MessageID) error {
c.Lock()
defer c.Unlock()
consumers := c.partitionConsumers()

if len(c.consumers) > 1 {
if len(consumers) > 1 {
return newError(SeekFailed, "for partition topic, seek command should perform on the individual partitions")
}

consumer, err := c.unsafeFindPartitionConsumer(msgID)
consumer, err := findPartitionConsumer(consumers, msgID)
if err != nil {
return err
}
Expand All @@ -768,11 +758,10 @@ func (c *consumer) Seek(msgID MessageID) error {
}

func (c *consumer) SeekByTime(time time.Time) error {
c.Lock()
defer c.Unlock()
var errs error
consumers := c.partitionConsumers()

for _, cons := range c.consumers {
for _, cons := range consumers {
cons.pauseDispatchMessage()
}
// clear messageCh
Expand All @@ -781,7 +770,7 @@ func (c *consumer) SeekByTime(time time.Time) error {
}

// run SeekByTime on every partition of topic
for _, cons := range c.consumers {
for _, cons := range consumers {
if err := cons.SeekByTime(time); err != nil {
msg := fmt.Sprintf("unable to SeekByTime for topic=%s subscription=%s", c.topic, c.Subscription())
errs = pkgerrors.Wrap(newError(SeekFailed, err.Error()), msg)
Expand All @@ -791,35 +780,30 @@ func (c *consumer) SeekByTime(time time.Time) error {
return errs
}

func (c *consumer) findPartitionConsumer(msgID MessageID) (*partitionConsumer, error) {
c.Lock()
defer c.Unlock()
return c.unsafeFindPartitionConsumer(msgID)
}

// NOTE: This method must be called when c.Lock is held
func (c *consumer) unsafeFindPartitionConsumer(msgID MessageID) (*partitionConsumer, error) {
func findPartitionConsumer(consumers []*partitionConsumer, msgID MessageID) (*partitionConsumer, error) {
partition := int(msgID.PartitionIdx())
if partition < 0 || partition >= len(c.consumers) {
c.log.Errorf("invalid partition index %d expected a partition between [0-%d]",
partition, len(c.consumers))
if partition < 0 || partition >= len(consumers) {
return nil, fmt.Errorf("invalid partition index %d expected a partition between [0-%d]",
partition, len(c.consumers))
partition, len(consumers)-1)
}
return consumers[partition], nil
Comment thread
BewareMyPower marked this conversation as resolved.
}

func (c *consumer) partitionConsumers() []*partitionConsumer {
v := c.consumers.Load()
if v == nil {
return nil
}
return c.consumers[partition], nil
// The slice stored in c.consumers is published via copy-on-write.
// Callers must treat the returned slice as immutable.
return v.([]*partitionConsumer)
}

func (c *consumer) hasNext() bool {
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Make sure all paths cancel the context to avoid context leak

// We have to make a snapshot consumers, because we have to iterate over all consumers in
// other goroutines. But when this method returns, there might be still other consumers
// not completing the `hasNext` call, so we cannot just call defer `c.Unlock()` after acquiring the lock.
c.Lock()
consumers := make([]*partitionConsumer, len(c.consumers))
copy(consumers, c.consumers)
c.Unlock()
consumers := c.partitionConsumers()

var wg sync.WaitGroup
wg.Add(len(consumers))
Expand Down Expand Up @@ -853,7 +837,7 @@ func (c *consumer) hasNext() bool {
}

func (c *consumer) setLastDequeuedMsg(msgID MessageID) error {
consumer, err := c.findPartitionConsumer(msgID)
consumer, err := findPartitionConsumer(c.partitionConsumers(), msgID)
if err != nil {
return err
}
Expand Down Expand Up @@ -920,7 +904,7 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) pb.CommandSubscribe_I
}

func (c *consumer) messageID(msgID MessageID) *trackingMessageID {
if _, err := c.findPartitionConsumer(msgID); err != nil {
if _, err := findPartitionConsumer(c.partitionConsumers(), msgID); err != nil {
return nil
}
return toTrackingMessageID(msgID)
Expand Down
15 changes: 10 additions & 5 deletions pulsar/consumer_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ func (s *schemaInfoCache) add(schemaVersionHash string, schema Schema) {
}

func newPartitionConsumer(parent Consumer, client *client, options *partitionConsumerOpts,
messageCh chan ConsumerMessage, dlq *dlqRouter,
metrics *internal.LeveledMetrics) (*partitionConsumer, error) {
messageCh chan ConsumerMessage, dlq *dlqRouter, metrics *internal.LeveledMetrics,
startDispatcher bool) (*partitionConsumer, error) {
var boFunc func() backoff.Policy
if options.backOffPolicyFunc != nil {
boFunc = options.backOffPolicyFunc
Expand All @@ -425,7 +425,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
queueCh: make(chan []*message, options.receiverQueueSize),
startMessageID: atomicMessageID{msgID: options.startMessageID},
seekMessageID: atomicMessageID{msgID: nil},
connectedCh: make(chan struct{}),
connectedCh: make(chan struct{}, 1),
messageCh: messageCh,
connectClosedCh: make(chan *connectionClosed, 1),
closeCh: make(chan struct{}),
Expand Down Expand Up @@ -512,13 +512,18 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon
}
}

go pc.dispatcher()

go pc.runEventsLoop()
if startDispatcher {
pc.startDispatcher()
}

return pc, nil
}

// startDispatcher launches pc.dispatcher in its own goroutine.
//
// newPartitionConsumer calls it only when its startDispatcher argument is true;
// a caller that passes false has to invoke it explicitly afterwards.
// Every call starts another goroutine, so it should be called once per
// partition consumer.
func (pc *partitionConsumer) startDispatcher() {
go pc.dispatcher()
}

func (pc *partitionConsumer) unsubscribe(force bool) error {
if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing {
pc.log.WithField("state", state).Error("Failed to unsubscribe closing or closed consumer")
Expand Down
Loading
Loading