diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go index 78ae0d7759..0f9493788f 100644 --- a/pulsar/consumer_impl.go +++ b/pulsar/consumer_impl.go @@ -30,7 +30,9 @@ import ( "github.com/apache/pulsar-client-go/pulsar/log" ) -const defaultNackRedeliveryDelay = 1 * time.Minute +const ( + defaultNackRedeliveryDelay = 1 * time.Minute +) type acker interface { AckID(id trackingMessageID) diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go index cf9294978d..8c24b7e452 100644 --- a/pulsar/consumer_partition.go +++ b/pulsar/consumer_partition.go @@ -175,7 +175,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon "subscription": options.subscription, "consumerID": pc.consumerID, }) - pc.nackTracker = newNegativeAcksTracker(pc, options.nackRedeliveryDelay, pc.log) + pc.nackTracker = newNegativeAcksTracker(pc, pc.log) err := pc.grabConn() if err != nil { @@ -301,7 +301,7 @@ func (pc *partitionConsumer) AckID(msgID trackingMessageID) { } func (pc *partitionConsumer) NackID(msgID trackingMessageID) { - pc.nackTracker.Add(msgID.messageID) + pc.nackTracker.Add(msgID.messageID, pc.options.nackRedeliveryDelay) pc.metrics.NacksCounter.Inc() } diff --git a/pulsar/internal/time_wheel.go b/pulsar/internal/time_wheel.go new file mode 100644 index 0000000000..4785b9882d --- /dev/null +++ b/pulsar/internal/time_wheel.go @@ -0,0 +1,153 @@ +package internal + +import ( + "errors" + "time" +) + +// Task is one entry scheduled in a TimeWheel; key identifies it and callback is run when it fires +type Task struct { + delay time.Duration + key interface{} + round int // full wheel rotations still to wait before firing, so a delay may exceed bucketsNum * tick + callback func() +} + +// TimeWheel schedules keyed tasks into buckets that are visited one per tick +type TimeWheel struct { + tick time.Duration + ticker *time.Ticker + + bucketsNum int + buckets []map[interface{}]*Task // one map per bucket; key: task key, value: *Task + bucketIndexes map[interface{}]int // key: task key, value: index of the bucket holding the task + + currentIndex int + + addC chan *Task + removeC chan interface{} + 
stopC chan struct{} +} + +// NewTimeWheel creates a time wheel with bucketsNum buckets; it returns an error if bucketsNum is not positive or tick is shorter than one second +func NewTimeWheel(tick time.Duration, bucketsNum int) (*TimeWheel, error) { + if bucketsNum <= 0 { + return nil, errors.New("bucket number must be greater than 0") + } + if int(tick.Seconds()) < 1 { + return nil, errors.New("tick cannot be less than 1s") + } + + tw := &TimeWheel{ + tick: tick, + bucketsNum: bucketsNum, + bucketIndexes: make(map[interface{}]int, 1024), + buckets: make([]map[interface{}]*Task, bucketsNum), + currentIndex: 0, + addC: make(chan *Task, 1024), + removeC: make(chan interface{}, 1024), + stopC: make(chan struct{}), + } + + for i := 0; i < bucketsNum; i++ { + tw.buckets[i] = make(map[interface{}]*Task, 16) + } + + return tw, nil +} + +// Start creates the ticker and launches the goroutine that serves ticks, additions, removals and the stop signal +func (tw *TimeWheel) Start() { + tw.ticker = time.NewTicker(tw.tick) + go tw.start() +} + +func (tw *TimeWheel) start() { + for { + select { + case <-tw.ticker.C: + tw.handleTick() + case task := <-tw.addC: + tw.add(task) + case key := <-tw.removeC: + tw.remove(key) + case <-tw.stopC: + tw.ticker.Stop() + return + } + } +} + +// Stop tells the serving goroutine to stop the ticker and exit; the send on the unbuffered stopC blocks until that goroutine receives it +func (tw *TimeWheel) Stop() { + tw.stopC <- struct{}{} +} + +func (tw *TimeWheel) handleTick() { + bucket := tw.buckets[tw.currentIndex] + for k := range bucket { + if bucket[k].round > 0 { + bucket[k].round-- + continue + } + go bucket[k].callback() + delete(bucket, k) + delete(tw.bucketIndexes, k) + } + if tw.currentIndex == tw.bucketsNum-1 { + tw.currentIndex = 0 + return + } + tw.currentIndex++ +} + +// Add queues a task that replaces any task already stored under key; delay is truncated to whole seconds when its slot is computed, and an error is returned if delay is not positive or key is nil +func (tw *TimeWheel) Add(delay time.Duration, key interface{}, callback func()) error { + if delay <= 0 || key == nil { + return errors.New("invalid params") + } + tw.addC <- &Task{delay: delay, key: key, callback: callback} + return nil +} + +func (tw *TimeWheel) add(task *Task) { + round := tw.calculateRound(task.delay) + index := tw.calculateIndex(task.delay) + task.round = round + if originIndex, ok := 
tw.bucketIndexes[task.key]; ok { + delete(tw.buckets[originIndex], task.key) + } + tw.bucketIndexes[task.key] = index + tw.buckets[index][task.key] = task +} + +func (tw *TimeWheel) calculateRound(delay time.Duration) (round int) { + delaySeconds := int(delay.Seconds()) + tickSeconds := int(tw.tick.Seconds()) + round = delaySeconds / tickSeconds / tw.bucketsNum + return +} + +func (tw *TimeWheel) calculateIndex(delay time.Duration) (index int) { + delaySeconds := int(delay.Seconds()) + tickSeconds := int(tw.tick.Seconds()) + index = (tw.currentIndex + delaySeconds/tickSeconds) % tw.bucketsNum + return +} + +// Remove queues the removal of the task stored under key; it returns an error if key is nil +func (tw *TimeWheel) Remove(key interface{}) error { + if key == nil { + return errors.New("invalid params") + } + tw.removeC <- key + return nil +} + +// remove deletes the task stored under key, if present, without calling its callback +func (tw *TimeWheel) remove(key interface{}) { + if index, ok := tw.bucketIndexes[key]; ok { + delete(tw.bucketIndexes, key) + delete(tw.buckets[index], key) + } +} diff --git a/pulsar/negative_acks_tracker.go b/pulsar/negative_acks_tracker.go index e10ab49c2d..1d50828929 100644 --- a/pulsar/negative_acks_tracker.go +++ b/pulsar/negative_acks_tracker.go @@ -21,6 +21,7 @@ import ( "sync" "time" + "github.com/apache/pulsar-client-go/pulsar/internal" log "github.com/apache/pulsar-client-go/pulsar/log" ) @@ -31,30 +32,36 @@ type redeliveryConsumer interface { type negativeAcksTracker struct { sync.Mutex - doneCh chan interface{} - doneOnce sync.Once - negativeAcks map[messageID]time.Time - rc redeliveryConsumer - tick *time.Ticker - delay time.Duration - log log.Logger + doneOnce sync.Once + rc redeliveryConsumer + log log.Logger + msgIds []messageID + tw *internal.TimeWheel } -func newNegativeAcksTracker(rc redeliveryConsumer, delay time.Duration, logger log.Logger) *negativeAcksTracker { +const ( + defaultCheckNegativeAcksBatchKey = "negative_acks_check_batch_key" + + negativeAcksbatchSize = 1024 + checkNegativeAcksBatchinterval = time.Second 
* 5 +) + +func newNegativeAcksTracker(rc redeliveryConsumer, logger log.Logger) *negativeAcksTracker { + tw, _ := internal.NewTimeWheel(time.Second*1, 1024) t := &negativeAcksTracker{ - doneCh: make(chan interface{}), - negativeAcks: make(map[messageID]time.Time), - rc: rc, - tick: time.NewTicker(delay / 3), - delay: delay, - log: logger, + rc: rc, + log: logger, + msgIds: make([]messageID, 0), + tw: tw, } - go t.track() + t.tw.Start() + t.tw.Add(checkNegativeAcksBatchinterval, defaultCheckNegativeAcksBatchKey, t.checkBatch) + return t } -func (t *negativeAcksTracker) Add(msgID messageID) { +func (t *negativeAcksTracker) Add(msgID messageID, negativeAckDelay time.Duration) { // Always clear up the batch index since we want to track the nack // for the entire batch batchMsgID := messageID{ @@ -63,57 +70,41 @@ func (t *negativeAcksTracker) Add(msgID messageID) { batchIdx: 0, } - t.Lock() - defer t.Unlock() + t.tw.Add(negativeAckDelay, batchMsgID, func() { + t.Lock() + t.msgIds = append(t.msgIds, batchMsgID) + if len(t.msgIds) >= negativeAcksbatchSize { + t.rc.Redeliver(t.msgIds) + t.msgIds = make([]messageID, 0) + } + t.Unlock() + }) +} - _, present := t.negativeAcks[batchMsgID] - if present { - // The batch is already being tracked - return +func (t *negativeAcksTracker) Remove(msgID messageID) { + batchMsgID := messageID{ + ledgerID: msgID.ledgerID, + entryID: msgID.entryID, + batchIdx: 0, } - targetTime := time.Now().Add(t.delay) - t.negativeAcks[batchMsgID] = targetTime + t.tw.Remove(batchMsgID) } -func (t *negativeAcksTracker) track() { - for { - select { - case <-t.doneCh: - t.log.Debug("Closing nack tracker") - return - - case <-t.tick.C: - { - now := time.Now() - msgIds := make([]messageID, 0) - - t.Lock() - - for msgID, targetTime := range t.negativeAcks { - t.log.Debugf("MsgId: %v -- targetTime: %v -- now: %v", msgID, targetTime, now) - if targetTime.Before(now) { - t.log.Debugf("Adding MsgId: %v", msgID) - msgIds = append(msgIds, msgID) - 
delete(t.negativeAcks, msgID) - } - } - - t.Unlock() - - if len(msgIds) > 0 { - t.rc.Redeliver(msgIds) - } - } - - } +func (t *negativeAcksTracker) checkBatch() { + t.Lock() + if len(t.msgIds) > 0 { + t.rc.Redeliver(t.msgIds) + t.msgIds = make([]messageID, 0) } + t.Unlock() + + t.tw.Add(checkNegativeAcksBatchinterval, defaultCheckNegativeAcksBatchKey, t.checkBatch) } func (t *negativeAcksTracker) Close() { // allow Close() to be invoked multiple times by consumer_partition to avoid panic t.doneOnce.Do(func() { - t.tick.Stop() - t.doneCh <- nil + t.tw.Stop() }) } diff --git a/pulsar/negative_acks_tracker_test.go b/pulsar/negative_acks_tracker_test.go index e587f3f19e..bec58a260b 100644 --- a/pulsar/negative_acks_tracker_test.go +++ b/pulsar/negative_acks_tracker_test.go @@ -75,19 +75,19 @@ func (nmc *nackMockedConsumer) Wait() <-chan messageID { func TestNacksTracker(t *testing.T) { nmc := newNackMockedConsumer() - nacks := newNegativeAcksTracker(nmc, testNackDelay, log.DefaultNopLogger()) + nacks := newNegativeAcksTracker(nmc, log.DefaultNopLogger()) nacks.Add(messageID{ ledgerID: 1, entryID: 1, batchIdx: 1, - }) + }, testNackDelay) nacks.Add(messageID{ ledgerID: 2, entryID: 2, batchIdx: 1, - }) + }, testNackDelay) msgIds := make([]messageID, 0) for id := range nmc.Wait() { @@ -108,31 +108,31 @@ func TestNacksTracker(t *testing.T) { func TestNacksWithBatchesTracker(t *testing.T) { nmc := newNackMockedConsumer() - nacks := newNegativeAcksTracker(nmc, testNackDelay, log.DefaultNopLogger()) + nacks := newNegativeAcksTracker(nmc, log.DefaultNopLogger()) nacks.Add(messageID{ ledgerID: 1, entryID: 1, batchIdx: 1, - }) + }, testNackDelay) nacks.Add(messageID{ ledgerID: 1, entryID: 1, batchIdx: 2, - }) + }, testNackDelay) nacks.Add(messageID{ ledgerID: 1, entryID: 1, batchIdx: 3, - }) + }, testNackDelay) nacks.Add(messageID{ ledgerID: 2, entryID: 2, batchIdx: 1, - }) + }, testNackDelay) msgIds := make([]messageID, 0) for id := range nmc.Wait() {