diff --git a/pulsar/consumer.go b/pulsar/consumer.go index 574d6af213..ffef7860df 100644 --- a/pulsar/consumer.go +++ b/pulsar/consumer.go @@ -198,6 +198,16 @@ type ConsumerOptions struct { // error information of the Ack method only contains errors that may occur in the Go SDK's own processing. // Default: false AckWithResponse bool + + // MaxPendingChunkedMessage sets the maximum pending chunked messages. (default: 100) + MaxPendingChunkedMessage int + + // ExpireTimeOfIncompleteChunk sets the expiry time of discarding incomplete chunked message. (default: 60 seconds) + ExpireTimeOfIncompleteChunk time.Duration + + // AutoAckIncompleteChunk sets whether consumer auto acknowledges incomplete chunked message when it should + // be removed (e.g.the chunked message pending queue is full). (default: false) + AutoAckIncompleteChunk bool } // Consumer is an interface that abstracts behavior of Pulsar's consumer diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go index 517824d504..a402a4d9f9 100644 --- a/pulsar/consumer_impl.go +++ b/pulsar/consumer_impl.go @@ -19,7 +19,6 @@ package pulsar import ( "context" - "errors" "fmt" "math/rand" "strconv" @@ -37,9 +36,9 @@ const defaultNackRedeliveryDelay = 1 * time.Minute type acker interface { // AckID does not handle errors returned by the Broker side, so no need to wait for doneCh to finish. 
- AckID(id trackingMessageID) error - AckIDWithResponse(id trackingMessageID) error - NackID(id trackingMessageID) + AckID(id MessageID) error + AckIDWithResponse(id MessageID) error + NackID(id MessageID) NackMsg(msg Message) } @@ -93,6 +92,14 @@ func newConsumer(client *client, options ConsumerOptions) (Consumer, error) { } } + if options.MaxPendingChunkedMessage == 0 { + options.MaxPendingChunkedMessage = 100 + } + + if options.ExpireTimeOfIncompleteChunk == 0 { + options.ExpireTimeOfIncompleteChunk = time.Minute + } + if options.NackBackoffPolicy == nil && options.EnableDefaultNackBackoffPolicy { options.NackBackoffPolicy = new(defaultNackBackoffPolicy) } @@ -344,28 +351,31 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { nackRedeliveryDelay = c.options.NackRedeliveryDelay } opts := &partitionConsumerOpts{ - topic: pt, - consumerName: c.consumerName, - subscription: c.options.SubscriptionName, - subscriptionType: c.options.Type, - subscriptionInitPos: c.options.SubscriptionInitialPosition, - partitionIdx: idx, - receiverQueueSize: receiverQueueSize, - nackRedeliveryDelay: nackRedeliveryDelay, - nackBackoffPolicy: c.options.NackBackoffPolicy, - metadata: metadata, - subProperties: subProperties, - replicateSubscriptionState: c.options.ReplicateSubscriptionState, - startMessageID: trackingMessageID{}, - subscriptionMode: durable, - readCompacted: c.options.ReadCompacted, - interceptors: c.options.Interceptors, - maxReconnectToBroker: c.options.MaxReconnectToBroker, - backoffPolicy: c.options.BackoffPolicy, - keySharedPolicy: c.options.KeySharedPolicy, - schema: c.options.Schema, - decryption: c.options.Decryption, - ackWithResponse: c.options.AckWithResponse, + topic: pt, + consumerName: c.consumerName, + subscription: c.options.SubscriptionName, + subscriptionType: c.options.Type, + subscriptionInitPos: c.options.SubscriptionInitialPosition, + partitionIdx: idx, + receiverQueueSize: receiverQueueSize, + nackRedeliveryDelay: nackRedeliveryDelay, 
+ nackBackoffPolicy: c.options.NackBackoffPolicy, + metadata: metadata, + subProperties: subProperties, + replicateSubscriptionState: c.options.ReplicateSubscriptionState, + startMessageID: trackingMessageID{}, + subscriptionMode: durable, + readCompacted: c.options.ReadCompacted, + interceptors: c.options.Interceptors, + maxReconnectToBroker: c.options.MaxReconnectToBroker, + backoffPolicy: c.options.BackoffPolicy, + keySharedPolicy: c.options.KeySharedPolicy, + schema: c.options.Schema, + decryption: c.options.Decryption, + ackWithResponse: c.options.AckWithResponse, + maxPendingChunkedMessage: c.options.MaxPendingChunkedMessage, + expireTimeOfIncompleteChunk: c.options.ExpireTimeOfIncompleteChunk, + autoAckIncompleteChunk: c.options.AutoAckIncompleteChunk, } cons, err := newPartitionConsumer(c, c.client, opts, c.messageCh, c.dlq, c.metrics) ch <- ConsumerError{ @@ -456,20 +466,15 @@ func (c *consumer) Ack(msg Message) error { // AckID the consumption of a single message, identified by its MessageID func (c *consumer) AckID(msgID MessageID) error { - mid, ok := c.messageID(msgID) - if !ok { - return errors.New("failed to convert trackingMessageID") - } - - if mid.consumer != nil { - return mid.Ack() + if err := c.checkMsgIDPartition(msgID); err != nil { + return err } if c.options.AckWithResponse { - return c.consumers[mid.partitionIdx].AckIDWithResponse(mid) + return c.consumers[msgID.PartitionIdx()].AckIDWithResponse(msgID) } - return c.consumers[mid.partitionIdx].AckID(mid) + return c.consumers[msgID.PartitionIdx()].AckID(msgID) } // ReconsumeLater mark a message for redelivery after custom delay @@ -529,7 +534,7 @@ func (c *consumer) Nack(msg Message) { } if mid.consumer != nil { - mid.Nack() + mid.consumer.NackID(msg.ID()) return } c.consumers[mid.partitionIdx].NackMsg(msg) @@ -540,17 +545,11 @@ func (c *consumer) Nack(msg Message) { } func (c *consumer) NackID(msgID MessageID) { - mid, ok := c.messageID(msgID) - if !ok { - return - } - - if mid.consumer != 
nil { - mid.Nack() + if err := c.checkMsgIDPartition(msgID); err != nil { return } - c.consumers[mid.partitionIdx].NackID(mid) + c.consumers[msgID.PartitionIdx()].NackID(msgID) } func (c *consumer) Close() { @@ -586,12 +585,11 @@ func (c *consumer) Seek(msgID MessageID) error { return newError(SeekFailed, "for partition topic, seek command should perform on the individual partitions") } - mid, ok := c.messageID(msgID) - if !ok { - return nil + if err := c.checkMsgIDPartition(msgID); err != nil { + return err } - return c.consumers[mid.partitionIdx].Seek(mid) + return c.consumers[msgID.PartitionIdx()].Seek(msgID) } func (c *consumer) SeekByTime(time time.Time) error { @@ -608,6 +606,17 @@ func (c *consumer) SeekByTime(time time.Time) error { return errs } +func (c *consumer) checkMsgIDPartition(msgID MessageID) error { + partition := msgID.PartitionIdx() + if partition < 0 || int(partition) >= len(c.consumers) { + c.log.Errorf("invalid partition index %d expected a partition between [0-%d]", + partition, len(c.consumers)) + return fmt.Errorf("invalid partition index %d expected a partition between [0-%d]", + partition, len(c.consumers)) + } + return nil +} + var r = &random{ R: rand.New(rand.NewSource(time.Now().UnixNano())), } diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go index 380dd75379..d32ce31c50 100644 --- a/pulsar/consumer_multitopic.go +++ b/pulsar/consumer_multitopic.go @@ -137,10 +137,10 @@ func (c *multiTopicConsumer) AckID(msgID MessageID) error { } if c.options.AckWithResponse { - return mid.AckWithResponse() + return mid.consumer.AckIDWithResponse(msgID) } - return mid.Ack() + return mid.consumer.AckID(msgID) } func (c *multiTopicConsumer) ReconsumeLater(msg Message, delay time.Duration) { @@ -200,7 +200,7 @@ func (c *multiTopicConsumer) NackID(msgID MessageID) { return } - mid.Nack() + mid.consumer.NackID(msgID) } func (c *multiTopicConsumer) Close() { diff --git a/pulsar/consumer_partition.go 
b/pulsar/consumer_partition.go index 5b61e7dcb4..ebaa48b989 100644 --- a/pulsar/consumer_partition.go +++ b/pulsar/consumer_partition.go @@ -18,6 +18,7 @@ package pulsar import ( + "container/list" "encoding/hex" "errors" "fmt" @@ -92,30 +93,33 @@ const ( ) type partitionConsumerOpts struct { - topic string - consumerName string - subscription string - subscriptionType SubscriptionType - subscriptionInitPos SubscriptionInitialPosition - partitionIdx int - receiverQueueSize int - nackRedeliveryDelay time.Duration - nackBackoffPolicy NackBackoffPolicy - metadata map[string]string - subProperties map[string]string - replicateSubscriptionState bool - startMessageID trackingMessageID - startMessageIDInclusive bool - subscriptionMode subscriptionMode - readCompacted bool - disableForceTopicCreation bool - interceptors ConsumerInterceptors - maxReconnectToBroker *uint - backoffPolicy internal.BackoffPolicy - keySharedPolicy *KeySharedPolicy - schema Schema - decryption *MessageDecryptionInfo - ackWithResponse bool + topic string + consumerName string + subscription string + subscriptionType SubscriptionType + subscriptionInitPos SubscriptionInitialPosition + partitionIdx int + receiverQueueSize int + nackRedeliveryDelay time.Duration + nackBackoffPolicy NackBackoffPolicy + metadata map[string]string + subProperties map[string]string + replicateSubscriptionState bool + startMessageID trackingMessageID + startMessageIDInclusive bool + subscriptionMode subscriptionMode + readCompacted bool + disableForceTopicCreation bool + interceptors ConsumerInterceptors + maxReconnectToBroker *uint + backoffPolicy internal.BackoffPolicy + keySharedPolicy *KeySharedPolicy + schema Schema + decryption *MessageDecryptionInfo + ackWithResponse bool + maxPendingChunkedMessage int + expireTimeOfIncompleteChunk time.Duration + autoAckIncompleteChunk bool } type partitionConsumer struct { @@ -161,6 +165,9 @@ type partitionConsumer struct { metrics *internal.LeveledMetrics decryptor 
cryptointernal.Decryptor schemaInfoCache *schemaInfoCache + + chunkedMsgCtxMap *chunkedMsgCtxMap + unAckChunksTracker *unAckChunksTracker } type schemaInfoCache struct { @@ -236,6 +243,8 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon schemaInfoCache: newSchemaInfoCache(client, options.topic), availablePermitsCh: make(chan permitsReq, 10), } + pc.chunkedMsgCtxMap = newChunkedMsgCtxMap(options.maxPendingChunkedMessage, pc) + pc.unAckChunksTracker = newUnAckChunksTracker(pc) pc.setConsumerState(consumerInit) pc.log = client.log.SubLogger(log.Fields{ "name": pc.name, @@ -378,18 +387,27 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error return convertToMessageID(id), nil } -func (pc *partitionConsumer) AckIDWithResponse(msgID trackingMessageID) error { +func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer") return errors.New("consumer state is closed") } + if cmid, ok := toChunkedMessageID(msgID); ok { + return pc.unAckChunksTracker.ack(cmid) + } + + trackingID, ok := toTrackingMessageID(msgID) + if !ok { + return errors.New("failed to convert trackingMessageID") + } + ackReq := new(ackRequest) ackReq.doneCh = make(chan struct{}) - if !msgID.Undefined() && msgID.ack() { + if !trackingID.Undefined() && trackingID.ack() { pc.metrics.AcksCounter.Inc() - pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-msgID.receivedTime.UnixNano()) / 1.0e9) - ackReq.msgID = msgID + pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-trackingID.receivedTime.UnixNano()) / 1.0e9) + ackReq.msgID = trackingID // send ack request to eventsCh pc.eventsCh <- ackReq // wait for the request to complete @@ -401,18 +419,27 @@ func (pc *partitionConsumer) AckIDWithResponse(msgID trackingMessageID) error { return 
ackReq.err } -func (pc *partitionConsumer) AckID(msgID trackingMessageID) error { +func (pc *partitionConsumer) AckID(msgID MessageID) error { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to ack by closing or closed consumer") return errors.New("consumer state is closed") } + if cmid, ok := toChunkedMessageID(msgID); ok { + return pc.unAckChunksTracker.ack(cmid) + } + + trackingID, ok := toTrackingMessageID(msgID) + if !ok { + return errors.New("failed to convert trackingMessageID") + } + ackReq := new(ackRequest) ackReq.doneCh = make(chan struct{}) - if !msgID.Undefined() && msgID.ack() { + if !trackingID.Undefined() && trackingID.ack() { pc.metrics.AcksCounter.Inc() - pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-msgID.receivedTime.UnixNano()) / 1.0e9) - ackReq.msgID = msgID + pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-trackingID.receivedTime.UnixNano()) / 1.0e9) + ackReq.msgID = trackingID // send ack request to eventsCh pc.eventsCh <- ackReq // No need to wait for ackReq.doneCh to finish @@ -423,8 +450,18 @@ func (pc *partitionConsumer) AckID(msgID trackingMessageID) error { return ackReq.err } -func (pc *partitionConsumer) NackID(msgID trackingMessageID) { - pc.nackTracker.Add(msgID.messageID) +func (pc *partitionConsumer) NackID(msgID MessageID) { + if cmid, ok := toChunkedMessageID(msgID); ok { + pc.unAckChunksTracker.nack(cmid) + return + } + + trackingID, ok := toTrackingMessageID(msgID) + if !ok { + return + } + + pc.nackTracker.Add(trackingID.messageID) pc.metrics.NacksCounter.Inc() } @@ -487,6 +524,9 @@ func (pc *partitionConsumer) Close() { return } + // close chunkedMsgCtxMap + pc.chunkedMsgCtxMap.Close() + req := &closeRequest{doneCh: make(chan struct{})} pc.eventsCh <- req @@ -494,15 +534,23 @@ func (pc *partitionConsumer) Close() { <-req.doneCh } -func (pc *partitionConsumer) Seek(msgID trackingMessageID) error { +func 
(pc *partitionConsumer) Seek(msgID MessageID) error { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to seek by closing or closed consumer") return errors.New("failed to seek by closing or closed consumer") } req := &seekRequest{ doneCh: make(chan struct{}), - msgID: msgID, } + if cmid, ok := toChunkedMessageID(msgID); ok { + req.msgID = cmid.firstChunkID + } else if tmid, ok := toTrackingMessageID(msgID); ok { + req.msgID = tmid.messageID + } else { + // will never reach + return errors.New("unhandled messageID type") + } + pc.eventsCh <- req // wait for the request to complete @@ -512,7 +560,7 @@ func (pc *partitionConsumer) Seek(msgID trackingMessageID) error { func (pc *partitionConsumer) internalSeek(seek *seekRequest) { defer close(seek.doneCh) - seek.err = pc.requestSeek(seek.msgID.messageID) + seek.err = pc.requestSeek(seek.msgID) } func (pc *partitionConsumer) requestSeek(msgID messageID) error { if err := pc.requestSeekWithoutClear(msgID); err != nil { @@ -698,8 +746,21 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header } } + isChunkedMsg := false + if msgMeta.GetNumChunksFromMsg() > 1 { + isChunkedMsg = true + } + + processedPayloadBuffer := internal.NewBufferWrapper(decryptedPayload) + if isChunkedMsg { + processedPayloadBuffer = pc.processMessageChunk(processedPayloadBuffer, msgMeta, pbMsgID) + if processedPayloadBuffer == nil { + return nil + } + } + // decryption is success, decompress the payload - uncompressedHeadersAndPayload, err := pc.Decompress(msgMeta, internal.NewBufferWrapper(decryptedPayload)) + uncompressedHeadersAndPayload, err := pc.Decompress(msgMeta, processedPayloadBuffer) if err != nil { pc.discardCorruptedMessage(pbMsgID, pb.CommandAck_DecompressionError) return err @@ -733,17 +794,40 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header 
pc.metrics.BytesReceived.Add(float64(len(payload))) pc.metrics.PrefetchedBytes.Add(float64(len(payload))) - msgID := newTrackingMessageID( + trackingMsgID := newTrackingMessageID( int64(pbMsgID.GetLedgerId()), int64(pbMsgID.GetEntryId()), int32(i), pc.partitionIdx, ackTracker) + // set the consumer so we know how to ack the message id + trackingMsgID.consumer = pc - if pc.messageShouldBeDiscarded(msgID) { - pc.AckID(msgID) + if pc.messageShouldBeDiscarded(trackingMsgID) { + pc.AckID(trackingMsgID) continue } + + var msgID MessageID + if isChunkedMsg { + ctx := pc.chunkedMsgCtxMap.get(msgMeta.GetUuid()) + if ctx == nil { + // chunkedMsgCtxMap has closed because of consumer closed + pc.log.Warnf("get chunkedMsgCtx for chunk with uuid %s failed because consumer has closed", + msgMeta.Uuid) + return nil + } + cmid := newChunkMessageID(ctx.firstChunkID(), ctx.lastChunkID()) + // set the consumer so we know how to ack the message id + cmid.consumer = pc + // clean chunkedMsgCtxMap + pc.chunkedMsgCtxMap.remove(msgMeta.GetUuid()) + pc.unAckChunksTracker.add(cmid, ctx.chunkedMsgIDs) + msgID = cmid + } else { + msgID = trackingMsgID + } + var messageIndex *uint64 var brokerPublishTime *time.Time if brokerMetadata != nil { @@ -756,8 +840,7 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header brokerPublishTime = &aux } } - // set the consumer so we know how to ack the message id - msgID.consumer = pc + var msg *message if smm != nil { msg = &message{ @@ -813,6 +896,55 @@ func (pc *partitionConsumer) MessageReceived(response *pb.CommandMessage, header return nil } +func (pc *partitionConsumer) processMessageChunk(compressedPayload internal.Buffer, + msgMeta *pb.MessageMetadata, + pbMsgID *pb.MessageIdData) internal.Buffer { + uuid := msgMeta.GetUuid() + numChunks := msgMeta.GetNumChunksFromMsg() + totalChunksSize := int(msgMeta.GetTotalChunkMsgSize()) + chunkID := msgMeta.GetChunkId() + msgID := messageID{ + ledgerID: int64(pbMsgID.GetLedgerId()), 
+ entryID: int64(pbMsgID.GetEntryId()), + batchIdx: -1, + partitionIdx: pc.partitionIdx, + } + + if msgMeta.GetChunkId() == 0 { + pc.chunkedMsgCtxMap.addIfAbsent(uuid, + numChunks, + totalChunksSize, + ) + } + + ctx := pc.chunkedMsgCtxMap.get(uuid) + + if ctx == nil || ctx.chunkedMsgBuffer == nil || chunkID != ctx.lastChunkedMsgID+1 { + lastChunkedMsgID := -1 + totalChunks := -1 + if ctx != nil { + lastChunkedMsgID = int(ctx.lastChunkedMsgID) + totalChunks = int(ctx.totalChunks) + ctx.chunkedMsgBuffer.Clear() + } + pc.log.Warnf(fmt.Sprintf( + "Received unexpected chunk messageId %s, last-chunk-id %d, chunkId = %d, total-chunks %d", + msgID.String(), lastChunkedMsgID, chunkID, totalChunks)) + pc.chunkedMsgCtxMap.remove(uuid) + pc.availablePermitsCh <- permitsInc + return nil + } + + ctx.append(chunkID, msgID, compressedPayload) + + if msgMeta.GetChunkId() != msgMeta.GetNumChunksFromMsg()-1 { + pc.availablePermitsCh <- permitsInc + return nil + } + + return ctx.chunkedMsgBuffer +} + func (pc *partitionConsumer) messageShouldBeDiscarded(msgID trackingMessageID) bool { if pc.startMessageID.Undefined() { return false @@ -1045,7 +1177,7 @@ type getLastMsgIDRequest struct { type seekRequest struct { doneCh chan struct{} - msgID trackingMessageID + msgID messageID err error } @@ -1508,3 +1640,227 @@ func convertToMessageID(id *pb.MessageIdData) trackingMessageID { return msgID } + +type chunkedMsgCtx struct { + totalChunks int32 + chunkedMsgBuffer internal.Buffer + lastChunkedMsgID int32 + chunkedMsgIDs []messageID + receivedTime int64 + + mu sync.Mutex +} + +func newChunkedMsgCtx(numChunksFromMsg int32, totalChunkMsgSize int) *chunkedMsgCtx { + return &chunkedMsgCtx{ + totalChunks: numChunksFromMsg, + chunkedMsgBuffer: internal.NewBuffer(totalChunkMsgSize), + lastChunkedMsgID: -1, + chunkedMsgIDs: make([]messageID, numChunksFromMsg), + receivedTime: time.Now().Unix(), + } +} + +func (c *chunkedMsgCtx) append(chunkID int32, msgID messageID, partPayload internal.Buffer) { + 
c.mu.Lock() + defer c.mu.Unlock() + c.chunkedMsgIDs[chunkID] = msgID + c.chunkedMsgBuffer.Write(partPayload.ReadableSlice()) + c.lastChunkedMsgID = chunkID +} + +func (c *chunkedMsgCtx) firstChunkID() messageID { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.chunkedMsgIDs) == 0 { + return messageID{} + } + return c.chunkedMsgIDs[0] +} + +func (c *chunkedMsgCtx) lastChunkID() messageID { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.chunkedMsgIDs) == 0 { + return messageID{} + } + return c.chunkedMsgIDs[len(c.chunkedMsgIDs)-1] +} + +func (c *chunkedMsgCtx) discard(pc *partitionConsumer) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, mid := range c.chunkedMsgIDs { + pc.log.Info("Removing chunk message-id", mid.String()) + tmid, _ := toTrackingMessageID(mid) + pc.AckID(tmid) + } +} + +type chunkedMsgCtxMap struct { + chunkedMsgCtxs map[string]*chunkedMsgCtx + pendingQueue *list.List + maxPending int + pc *partitionConsumer + mu sync.Mutex + closed bool +} + +func newChunkedMsgCtxMap(maxPending int, pc *partitionConsumer) *chunkedMsgCtxMap { + return &chunkedMsgCtxMap{ + chunkedMsgCtxs: make(map[string]*chunkedMsgCtx, maxPending), + pendingQueue: list.New(), + maxPending: maxPending, + pc: pc, + mu: sync.Mutex{}, + } +} + +func (c *chunkedMsgCtxMap) addIfAbsent(uuid string, totalChunks int32, totalChunkMsgSize int) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + if _, ok := c.chunkedMsgCtxs[uuid]; !ok { + c.chunkedMsgCtxs[uuid] = newChunkedMsgCtx(totalChunks, totalChunkMsgSize) + c.pendingQueue.PushBack(uuid) + go c.discardChunkIfExpire(uuid, true, c.pc.options.expireTimeOfIncompleteChunk) + } + if c.maxPending > 0 && c.pendingQueue.Len() > c.maxPending { + go c.discardOldestChunkMessage(c.pc.options.autoAckIncompleteChunk) + } +} + +func (c *chunkedMsgCtxMap) get(uuid string) *chunkedMsgCtx { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil + } + return c.chunkedMsgCtxs[uuid] +} + +func (c *chunkedMsgCtxMap) remove(uuid string) 
{ + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + delete(c.chunkedMsgCtxs, uuid) + e := c.pendingQueue.Front() + for ; e != nil; e = e.Next() { + if e.Value.(string) == uuid { + c.pendingQueue.Remove(e) + break + } + } +} + +func (c *chunkedMsgCtxMap) discardOldestChunkMessage(autoAck bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed || (c.maxPending > 0 && c.pendingQueue.Len() <= c.maxPending) { + return + } + oldest := c.pendingQueue.Front().Value.(string) + ctx, ok := c.chunkedMsgCtxs[oldest] + if !ok { + return + } + if autoAck { + ctx.discard(c.pc) + } + delete(c.chunkedMsgCtxs, oldest) + c.pc.log.Infof("Chunked message [%s] has been removed from chunkedMsgCtxMap", oldest) +} + +func (c *chunkedMsgCtxMap) discardChunkMessage(uuid string, autoAck bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return + } + ctx, ok := c.chunkedMsgCtxs[uuid] + if !ok { + return + } + if autoAck { + ctx.discard(c.pc) + } + delete(c.chunkedMsgCtxs, uuid) + e := c.pendingQueue.Front() + for ; e != nil; e = e.Next() { + if e.Value.(string) == uuid { + c.pendingQueue.Remove(e) + break + } + } + c.pc.log.Infof("Chunked message [%s] has been removed from chunkedMsgCtxMap", uuid) +} + +func (c *chunkedMsgCtxMap) discardChunkIfExpire(uuid string, autoAck bool, expire time.Duration) { + timer := time.NewTimer(expire) + <-timer.C + c.discardChunkMessage(uuid, autoAck) +} + +func (c *chunkedMsgCtxMap) Close() { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true +} + +type unAckChunksTracker struct { + chunkIDs map[chunkMessageID][]messageID + pc *partitionConsumer + mu sync.Mutex +} + +func newUnAckChunksTracker(pc *partitionConsumer) *unAckChunksTracker { + return &unAckChunksTracker{ + chunkIDs: make(map[chunkMessageID][]messageID), + pc: pc, + } +} + +func (u *unAckChunksTracker) add(cmid chunkMessageID, ids []messageID) { + u.mu.Lock() + defer u.mu.Unlock() + + u.chunkIDs[cmid] = ids +} + +func (u *unAckChunksTracker) get(cmid chunkMessageID) 
[]messageID { + u.mu.Lock() + defer u.mu.Unlock() + + return u.chunkIDs[cmid] +} + +func (u *unAckChunksTracker) remove(cmid chunkMessageID) { + u.mu.Lock() + defer u.mu.Unlock() + + delete(u.chunkIDs, cmid) +} + +func (u *unAckChunksTracker) ack(cmid chunkMessageID) error { + ids := u.get(cmid) + for _, id := range ids { + if err := u.pc.AckID(id); err != nil { + return err + } + } + u.remove(cmid) + return nil +} + +func (u *unAckChunksTracker) nack(cmid chunkMessageID) { + ids := u.get(cmid) + for _, id := range ids { + u.pc.NackID(id) + } + u.remove(cmid) +} diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go index c55a1c1143..55f3d7a37b 100644 --- a/pulsar/consumer_regex.go +++ b/pulsar/consumer_regex.go @@ -181,10 +181,10 @@ func (c *regexConsumer) AckID(msgID MessageID) error { } if c.options.AckWithResponse { - return mid.AckWithResponse() + return mid.consumer.AckIDWithResponse(msgID) } - return mid.Ack() + return mid.consumer.AckID(msgID) } func (c *regexConsumer) Nack(msg Message) { @@ -219,7 +219,7 @@ func (c *regexConsumer) NackID(msgID MessageID) { return } - mid.Nack() + mid.consumer.NackID(msgID) } func (c *regexConsumer) Close() { diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go index d155ae776c..11ecb40f21 100644 --- a/pulsar/impl_message.go +++ b/pulsar/impl_message.go @@ -210,11 +210,25 @@ func toTrackingMessageID(msgID MessageID) (trackingMessageID, bool) { }, true } else if mid, ok := msgID.(trackingMessageID); ok { return mid, true + } else if cmid, ok := msgID.(chunkMessageID); ok { + return trackingMessageID{ + messageID: cmid.messageID, + receivedTime: cmid.receivedTime, + consumer: cmid.consumer, + }, true } else { return trackingMessageID{}, false } } +func toChunkedMessageID(msgID MessageID) (chunkMessageID, bool) { + cid, ok := msgID.(chunkMessageID) + if ok { + return cid, true + } + return chunkMessageID{}, false +} + func timeFromUnixTimestampMillis(timestamp uint64) time.Time { ts := int64(timestamp) * 
int64(time.Millisecond) seconds := ts / int64(time.Second) @@ -372,3 +386,41 @@ func (t *ackTracker) completed() bool { defer t.Unlock() return len(t.batchIDs.Bits()) == 0 } + +type chunkMessageID struct { + messageID + + firstChunkID messageID + receivedTime time.Time + + consumer acker +} + +func newChunkMessageID(firstChunkID messageID, lastChunkID messageID) chunkMessageID { + return chunkMessageID{ + messageID: lastChunkID, + firstChunkID: firstChunkID, + receivedTime: time.Now(), + } +} + +func (id chunkMessageID) String() string { + return fmt.Sprintf("%s;%s", id.firstChunkID.String(), id.messageID.String()) +} + +func (id chunkMessageID) Serialize() []byte { + msgID := &pb.MessageIdData{ + LedgerId: proto.Uint64(uint64(id.ledgerID)), + EntryId: proto.Uint64(uint64(id.entryID)), + BatchIndex: proto.Int32(id.batchIdx), + Partition: proto.Int32(id.partitionIdx), + FirstChunkMessageId: &pb.MessageIdData{ + LedgerId: proto.Uint64(uint64(id.firstChunkID.ledgerID)), + EntryId: proto.Uint64(uint64(id.firstChunkID.entryID)), + BatchIndex: proto.Int32(id.firstChunkID.batchIdx), + Partition: proto.Int32(id.firstChunkID.partitionIdx), + }, + } + data, _ := proto.Marshal(msgID) + return data +} diff --git a/pulsar/internal/batch_builder.go b/pulsar/internal/batch_builder.go index fe8f62808a..ca19a3ed71 100644 --- a/pulsar/internal/batch_builder.go +++ b/pulsar/internal/batch_builder.go @@ -35,7 +35,7 @@ type BuffersPool interface { // BatcherBuilderProvider defines func which returns the BatchBuilder. type BatcherBuilderProvider func( - maxMessages uint, maxBatchSize uint, producerName string, producerID uint64, + maxMessages uint, maxBatchSize uint, maxMessageSize uint32, producerName string, producerID uint64, compressionType pb.CompressionType, level compression.Level, bufferPool BuffersPool, logger log.Logger, encryptor crypto.Encryptor, ) (BatchBuilder, error) @@ -85,6 +85,8 @@ type batchContainer struct { // without needing costly re-allocations. 
maxBatchSize uint + maxMessageSize uint32 + producerName string producerID uint64 @@ -102,18 +104,19 @@ type batchContainer struct { // newBatchContainer init a batchContainer func newBatchContainer( - maxMessages uint, maxBatchSize uint, producerName string, producerID uint64, + maxMessages uint, maxBatchSize uint, maxMessageSize uint32, producerName string, producerID uint64, compressionType pb.CompressionType, level compression.Level, bufferPool BuffersPool, logger log.Logger, encryptor crypto.Encryptor, ) batchContainer { bc := batchContainer{ - buffer: NewBuffer(4096), - numMessages: 0, - maxMessages: maxMessages, - maxBatchSize: maxBatchSize, - producerName: producerName, - producerID: producerID, + buffer: NewBuffer(4096), + numMessages: 0, + maxMessages: maxMessages, + maxBatchSize: maxBatchSize, + maxMessageSize: maxMessageSize, + producerName: producerName, + producerID: producerID, cmdSend: baseCommand( pb.BaseCommand_SEND, &pb.CommandSend{ @@ -124,7 +127,7 @@ func newBatchContainer( ProducerName: &producerName, }, callbacks: []interface{}{}, - compressionProvider: getCompressionProvider(compressionType, level), + compressionProvider: GetCompressionProvider(compressionType, level), buffersPool: bufferPool, log: logger, encryptor: encryptor, @@ -139,13 +142,13 @@ func newBatchContainer( // NewBatchBuilder init batch builder and return BatchBuilder pointer. Build a new batch message container. 
func NewBatchBuilder( - maxMessages uint, maxBatchSize uint, producerName string, producerID uint64, + maxMessages uint, maxBatchSize uint, maxMessageSize uint32, producerName string, producerID uint64, compressionType pb.CompressionType, level compression.Level, bufferPool BuffersPool, logger log.Logger, encryptor crypto.Encryptor, ) (BatchBuilder, error) { bc := newBatchContainer( - maxMessages, maxBatchSize, producerName, producerID, compressionType, + maxMessages, maxBatchSize, maxMessageSize, producerName, producerID, compressionType, level, bufferPool, logger, encryptor, ) @@ -164,7 +167,9 @@ func (bc *batchContainer) hasSpace(payload []byte) bool { return true } msgSize := uint32(len(payload)) - return bc.numMessages+1 <= bc.maxMessages && bc.buffer.ReadableBytes()+msgSize <= uint32(bc.maxBatchSize) + expectedSize := bc.buffer.ReadableBytes() + msgSize + return bc.numMessages+1 <= bc.maxMessages && + expectedSize <= uint32(bc.maxBatchSize) && expectedSize <= bc.maxMessageSize } func (bc *batchContainer) hasSameSchema(schemaVersion []byte) bool { @@ -258,8 +263,9 @@ func (bc *batchContainer) Flush() ( buffer = NewBuffer(int(uncompressedSize * 3 / 2)) } - if err = serializeBatch( - buffer, bc.cmdSend, bc.msgMetadata, bc.buffer, bc.compressionProvider, bc.encryptor, + if err = serializeMessage( + buffer, bc.cmdSend, bc.msgMetadata, bc.buffer, bc.compressionProvider, + bc.encryptor, bc.maxMessageSize, true, ); err == nil { // no error in serializing Batch sequenceID = bc.cmdSend.Send.GetSequenceId() } @@ -285,7 +291,7 @@ func (bc *batchContainer) Close() error { return bc.compressionProvider.Close() } -func getCompressionProvider( +func GetCompressionProvider( compressionType pb.CompressionType, level compression.Level, ) compression.Provider { diff --git a/pulsar/internal/commands.go b/pulsar/internal/commands.go index 94217a50b4..c28593c6b8 100644 --- a/pulsar/internal/commands.go +++ b/pulsar/internal/commands.go @@ -50,6 +50,8 @@ var ErrEOM = 
errors.New("EOF") var ErrConnectionClosed = errors.New("connection closed") +var ErrExceedMaxMessageSize = errors.New("encryptedPayload exceeds MaxMessageSize") + func NewMessageReader(headersAndPayload Buffer) *MessageReader { return &MessageReader{ buffer: headersAndPayload, @@ -237,17 +239,24 @@ func addSingleMessageToBatch(wb Buffer, smm *pb.SingleMessageMetadata, payload [ wb.Write(payload) } -func serializeBatch(wb Buffer, +func serializeMessage(wb Buffer, cmdSend *pb.BaseCommand, msgMetadata *pb.MessageMetadata, - uncompressedPayload Buffer, + payload Buffer, compressionProvider compression.Provider, - encryptor crypto.Encryptor) error { + encryptor crypto.Encryptor, + maxMessageSize uint32, + doCompress bool) error { // Wire format // [TOTAL_SIZE] [CMD_SIZE][CMD] [MAGIC_NUMBER][CHECKSUM] [METADATA_SIZE][METADATA] [PAYLOAD] // compress the payload - compressedPayload := compressionProvider.Compress(nil, uncompressedPayload.ReadableSlice()) + var compressedPayload []byte + if doCompress { + compressedPayload = compressionProvider.Compress(nil, payload.ReadableSlice()) + } else { + compressedPayload = payload.ReadableSlice() + } // encrypt the compressed payload encryptedPayload, err := encryptor.Encrypt(compressedPayload, msgMetadata) @@ -258,6 +267,13 @@ func serializeBatch(wb Buffer, cmdSize := uint32(proto.Size(cmdSend)) msgMetadataSize := uint32(proto.Size(msgMetadata)) + msgSize := len(encryptedPayload) + int(msgMetadataSize) + + // the maxMessageSize check of batching message is in here + if !(msgMetadata.GetTotalChunkMsgSize() != 0) && msgSize > int(maxMessageSize) { + return fmt.Errorf("%w, size: %d, MaxMessageSize: %d", + ErrExceedMaxMessageSize, msgSize, maxMessageSize) + } frameSizeIdx := wb.WriterIndex() wb.WriteUint32(0) // Skip frame size until we now the size @@ -300,6 +316,28 @@ func serializeBatch(wb Buffer, return nil } +func SingleSend(wb Buffer, + producerID, sequenceID uint64, + msgMetadata *pb.MessageMetadata, + compressedPayload Buffer, 
+ encryptor crypto.Encryptor, + maxMassageSize uint32) error { + cmdSend := baseCommand( + pb.BaseCommand_SEND, + &pb.CommandSend{ + ProducerId: &producerID, + }, + ) + cmdSend.Send.SequenceId = &sequenceID + if msgMetadata.GetTotalChunkMsgSize() > 1 { + isChunk := true + cmdSend.Send.IsChunk = &isChunk + } + // payload has been compressed so compressionProvider can be nil + return serializeMessage(wb, cmdSend, msgMetadata, compressedPayload, + nil, encryptor, maxMassageSize, false) +} + // ConvertFromStringMap convert a string map to a KeyValue []byte func ConvertFromStringMap(m map[string]string) []*pb.KeyValue { list := make([]*pb.KeyValue, len(m)) diff --git a/pulsar/internal/key_based_batch_builder.go b/pulsar/internal/key_based_batch_builder.go index 77fbb8c77a..334e6746ab 100644 --- a/pulsar/internal/key_based_batch_builder.go +++ b/pulsar/internal/key_based_batch_builder.go @@ -84,7 +84,7 @@ func (h *keyBasedBatches) Val(key string) *batchContainer { // NewKeyBasedBatchBuilder init batch builder and return BatchBuilder // pointer. Build a new key based batch message container. 
func NewKeyBasedBatchBuilder( - maxMessages uint, maxBatchSize uint, producerName string, producerID uint64, + maxMessages uint, maxBatchSize uint, maxMessageSize uint32, producerName string, producerID uint64, compressionType pb.CompressionType, level compression.Level, bufferPool BuffersPool, logger log.Logger, encryptor crypto.Encryptor, ) (BatchBuilder, error) { @@ -92,7 +92,7 @@ func NewKeyBasedBatchBuilder( bb := &keyBasedBatchContainer{ batches: newKeyBasedBatches(), batchContainer: newBatchContainer( - maxMessages, maxBatchSize, producerName, producerID, + maxMessages, maxBatchSize, maxMessageSize, producerName, producerID, compressionType, level, bufferPool, logger, encryptor, ), compressionType: compressionType, @@ -151,7 +151,7 @@ func (bc *keyBasedBatchContainer) Add( if batchPart == nil { // create batchContainer for new key t := newBatchContainer( - bc.maxMessages, bc.maxBatchSize, bc.producerName, bc.producerID, + bc.maxMessages, bc.maxBatchSize, bc.maxMessageSize, bc.producerName, bc.producerID, bc.compressionType, bc.level, bc.buffersPool, bc.log, bc.encryptor, ) batchPart = &t diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go new file mode 100644 index 0000000000..b3d64afaec --- /dev/null +++ b/pulsar/message_chunking_test.go @@ -0,0 +1,570 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pulsar + +import ( + "context" + "errors" + "fmt" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var _brokerMaxMessageSize = 1024 * 1024 + +func TestInvalidChunkingConfig(t *testing.T) { + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + // create producer + producer, err := client.CreateProducer(ProducerOptions{ + Topic: newTopicName(), + DisableBatching: false, + EnableChunking: true, + }) + + assert.Error(t, err, "producer creation should have fail") + assert.Nil(t, producer) +} + +func TestLargeMessage(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + + // create producer without ChunkMaxMessageSize + producer1, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + DisableBatching: true, + EnableChunking: true, + }) + assert.NoError(t, err) + assert.NotNil(t, producer1) + defer producer1.Close() + + // create producer with ChunkMaxMessageSize + producer2, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + DisableBatching: true, + EnableChunking: true, + ChunkMaxMessageSize: 5, + }) + assert.NoError(t, err) + assert.NotNil(t, producer2) + defer producer2.Close() + + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: topic, + Type: Exclusive, + SubscriptionName: "chunk-subscriber", + }) + assert.NoError(t, err) + assert.NotNil(t, consumer) + defer 
consumer.Close() + + expectMsgs := make([][]byte, 0, 10) + + // test send chunk with serverMaxMessageSize limit + for i := 0; i < 5; i++ { + msg := createTestMessagePayload(_brokerMaxMessageSize + 1) + expectMsgs = append(expectMsgs, msg) + ID, err := producer1.Send(context.Background(), &ProducerMessage{ + Payload: msg, + }) + assert.NoError(t, err) + assert.NotNil(t, ID) + } + + // test receive chunk with serverMaxMessageSize limit + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + msg, err := consumer.Receive(ctx) + cancel() + assert.NoError(t, err) + + expectMsg := expectMsgs[i] + + assert.Equal(t, expectMsg, msg.Payload()) + // ack message + err = consumer.Ack(msg) + assert.NoError(t, err) + } + + // test send chunk with ChunkMaxMessageSize limit + for i := 0; i < 5; i++ { + msg := createTestMessagePayload(50) + expectMsgs = append(expectMsgs, msg) + ID, err := producer2.Send(context.Background(), &ProducerMessage{ + Payload: msg, + }) + assert.NoError(t, err) + assert.NotNil(t, ID) + } + + // test receive chunk with ChunkMaxMessageSize limit + for i := 5; i < 10; i++ { + msg, err := consumer.Receive(context.Background()) + assert.NoError(t, err) + + expectMsg := expectMsgs[i] + + assert.Equal(t, expectMsg, msg.Payload()) + // ack message + err = consumer.Ack(msg) + assert.NoError(t, err) + } +} + +func TestMaxPendingChunkMessages(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + + totalProducers := 5 + producers := make([]Producer, 0, 20) + defer func() { + for _, p := range producers { + p.Close() + } + }() + + clients := make([]Client, 0, 20) + defer func() { + for _, c := range clients { + c.Close() + } + }() + + for j := 0; j < totalProducers; j++ { + pc, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + assert.Nil(t, err) + clients = append(clients, pc) + 
producer, err := pc.CreateProducer(ProducerOptions{ + Topic: topic, + DisableBatching: true, + EnableChunking: true, + ChunkMaxMessageSize: 10, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + producers = append(producers, producer) + } + + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: topic, + Type: Exclusive, + SubscriptionName: "chunk-subscriber", + MaxPendingChunkedMessage: 1, + }) + assert.NoError(t, err) + assert.NotNil(t, consumer) + defer consumer.Close() + + totalMsgs := 40 + wg := sync.WaitGroup{} + wg.Add(totalMsgs * totalProducers) + for i := 0; i < totalMsgs; i++ { + for j := 0; j < totalProducers; j++ { + p := producers[j] + go func() { + ID, err := p.Send(context.Background(), &ProducerMessage{ + Payload: createTestMessagePayload(50), + }) + assert.NoError(t, err) + assert.NotNil(t, ID) + wg.Done() + }() + } + } + wg.Wait() + + received := 0 + for i := 0; i < totalMsgs*totalProducers; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + msg, err := consumer.Receive(ctx) + cancel() + if msg == nil || (err != nil && errors.Is(err, context.DeadlineExceeded)) { + break + } + + received++ + + err = consumer.Ack(msg) + assert.NoError(t, err) + } + + assert.NotEqual(t, totalMsgs*totalProducers, received) +} + +func TestExpireIncompleteChunks(t *testing.T) { + rand.Seed(time.Now().Unix()) + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + + c, err := client.Subscribe(ConsumerOptions{ + Topic: topic, + Type: Exclusive, + SubscriptionName: "chunk-subscriber", + ExpireTimeOfIncompleteChunk: time.Millisecond * 300, + }) + assert.NoError(t, err) + defer c.Close() + + uuid := "test-uuid" + chunkCtxMap := c.(*consumer).consumers[0].chunkedMsgCtxMap + chunkCtxMap.addIfAbsent(uuid, 2, 100) + ctx := chunkCtxMap.get(uuid) + assert.NotNil(t, ctx) + + time.Sleep(400 * time.Millisecond) + + ctx = chunkCtxMap.get(uuid) + 
assert.Nil(t, ctx) +} + +func TestChunksEnqueueFailed(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + EnableChunking: true, + DisableBatching: true, + MaxPendingMessages: 10, + ChunkMaxMessageSize: 50, + DisableBlockIfQueueFull: true, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + defer producer.Close() + + ID, err := producer.Send(context.Background(), &ProducerMessage{ + Payload: createTestMessagePayload(1000), + }) + assert.Error(t, err) + assert.Nil(t, ID) +} + +func TestSeekChunkMessages(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + totalMessages := 5 + + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + EnableChunking: true, + DisableBatching: true, + ChunkMaxMessageSize: 50, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + defer producer.Close() + + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: topic, + Type: Exclusive, + SubscriptionName: "default-seek", + }) + assert.NoError(t, err) + assert.NotNil(t, consumer) + defer consumer.Close() + + msgIDs := make([]MessageID, 0) + for i := 0; i < totalMessages; i++ { + ID, err := producer.Send(context.Background(), &ProducerMessage{ + Payload: createTestMessagePayload(100), + }) + assert.NoError(t, err) + msgIDs = append(msgIDs, ID) + } + + for i := 0; i < totalMessages; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + msg, err := consumer.Receive(ctx) + cancel() + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, msgIDs[i].Serialize(), msg.ID().Serialize()) + } + + err = consumer.Seek(msgIDs[1]) + assert.NoError(t, err) + + for i := 1; i < 
totalMessages; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + msg, err := consumer.Receive(ctx) + cancel() + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, msgIDs[i].Serialize(), msg.ID().Serialize()) + } + + // todo: add reader seek test when support reader read chunk message +} + +func TestChunkAckAndNAck(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + + assert.Nil(t, err) + defer client.Close() + + topic := newTopicName() + + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + EnableChunking: true, + DisableBatching: true, + ChunkMaxMessageSize: 50, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + defer producer.Close() + + consumer, err := client.Subscribe(ConsumerOptions{ + Topic: topic, + Type: Exclusive, + SubscriptionName: "default-seek", + NackRedeliveryDelay: time.Second, + }) + assert.NoError(t, err) + assert.NotNil(t, consumer) + defer consumer.Close() + + content := createTestMessagePayload(100) + + _, err = producer.Send(context.Background(), &ProducerMessage{ + Payload: content, + }) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + msg, err := consumer.Receive(ctx) + cancel() + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, msg.Payload(), content) + + consumer.Nack(msg) + time.Sleep(time.Second * 2) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + msg, err = consumer.Receive(ctx) + cancel() + assert.NoError(t, err) + assert.NotNil(t, msg) + assert.Equal(t, msg.Payload(), content) +} + +func TestChunkSize(t *testing.T) { + rand.Seed(time.Now().Unix()) + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + assert.Nil(t, err) + defer client.Close() + + // the default message metadata size for string schema + // The proto messageMetaData size as following. 
(all with tag) (maxMessageSize = 1024 * 1024) + // | producerName | sequenceID | publishTime | uncompressedSize | + // | ------------ | ---------- | ----------- | ---------------- | + // | 6 | 2 | 7 | 4 | + payloadChunkSize := _brokerMaxMessageSize - 19 + + topic := newTopicName() + + producer, err := client.CreateProducer(ProducerOptions{ + Name: "test", + Topic: topic, + EnableChunking: true, + DisableBatching: true, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + defer producer.Close() + + for size := payloadChunkSize; size <= _brokerMaxMessageSize; size++ { + msgID, err := producer.Send(context.Background(), &ProducerMessage{ + Payload: createTestMessagePayload(size), + }) + assert.NoError(t, err) + if size <= payloadChunkSize { + _, ok := msgID.(messageID) + assert.Equal(t, true, ok) + } else { + _, ok := msgID.(chunkMessageID) + assert.Equal(t, true, ok) + } + } +} + +func TestChunkMultiTopicConsumerReceive(t *testing.T) { + topic1 := newTopicName() + topic2 := newTopicName() + + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + if err != nil { + t.Fatal(err) + } + topics := []string{topic1, topic2} + consumer, err := client.Subscribe(ConsumerOptions{ + Topics: topics, + SubscriptionName: "multi-topic-sub", + }) + if err != nil { + t.Fatal(err) + } + defer consumer.Close() + + maxSize := 50 + + // produce messages + for i, topic := range topics { + p, err := client.CreateProducer(ProducerOptions{ + Topic: topic, + DisableBatching: true, + EnableChunking: true, + ChunkMaxMessageSize: uint(maxSize), + }) + if err != nil { + t.Fatal(err) + } + err = genMessages(p, 10, func(idx int) string { + return fmt.Sprintf("topic-%d-hello-%d-%s", i+1, idx, string(createTestMessagePayload(100))) + }) + p.Close() + if err != nil { + t.Fatal(err) + } + } + + receivedTopic1 := 0 + receivedTopic2 := 0 + // nolint + for receivedTopic1+receivedTopic2 < 20 { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + select { + case cm, 
ok := <-consumer.Chan(): + if ok { + msg := string(cm.Payload()) + if strings.HasPrefix(msg, "topic-1") { + receivedTopic1++ + } else if strings.HasPrefix(msg, "topic-2") { + receivedTopic2++ + } + consumer.Ack(cm.Message) + } else { + t.Fail() + } + case <-ctx.Done(): + t.Error(ctx.Err()) + } + cancel() + } + assert.Equal(t, receivedTopic1, receivedTopic2) +} + +func TestChunkBlockIfQueueFull(t *testing.T) { + client, err := NewClient(ClientOptions{ + URL: lookupURL, + }) + if err != nil { + t.Fatal(err) + } + + topic := newTopicName() + + producer, err := client.CreateProducer(ProducerOptions{ + Name: "test", + Topic: topic, + EnableChunking: true, + DisableBatching: true, + MaxPendingMessages: 1, + ChunkMaxMessageSize: 10, + }) + assert.NoError(t, err) + assert.NotNil(t, producer) + defer producer.Close() + + // Large messages will be split into 11 chunks, exceeding the length of pending queue + ID, err := producer.Send(context.Background(), &ProducerMessage{ + Payload: createTestMessagePayload(100), + }) + assert.NoError(t, err) + assert.NotNil(t, ID) +} + +func createTestMessagePayload(size int) []byte { + payload := make([]byte, size) + for i := range payload { + payload[i] = byte(rand.Intn(100)) + } + return payload +} diff --git a/pulsar/producer.go b/pulsar/producer.go index b4e43bd624..d088fb2d60 100644 --- a/pulsar/producer.go +++ b/pulsar/producer.go @@ -178,6 +178,15 @@ type ProducerOptions struct { // Encryption specifies the fields required to encrypt a message Encryption *ProducerEncryptionInfo + + // EnableChunking controls whether automatic chunking of messages is enabled for the producer. By default, chunking + // is disabled. + // Chunking can not be enabled when batching is enabled. + EnableChunking bool + + // ChunkMaxMessageSize is the max size of single chunk payload. + // It will actually only take effect if it is smaller than the maxMessageSize from the broker. 
+ ChunkMaxMessageSize uint } // Producer is used to publish messages on a topic diff --git a/pulsar/producer_impl.go b/pulsar/producer_impl.go index 9bbfccbd74..3c45b597d0 100644 --- a/pulsar/producer_impl.go +++ b/pulsar/producer_impl.go @@ -94,6 +94,10 @@ func newProducer(client *client, options *ProducerOptions) (*producer, error) { options.PartitionsAutoDiscoveryInterval = defaultPartitionsAutoDiscoveryInterval } + if !options.DisableBatching && options.EnableChunking { + return nil, fmt.Errorf("batching and chunking can not be enabled together") + } + p := &producer{ options: options, topic: options.Topic, diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go index 922d89dbb9..881451c654 100644 --- a/pulsar/producer_partition.go +++ b/pulsar/producer_partition.go @@ -20,6 +20,9 @@ package pulsar import ( "context" "errors" + "fmt" + "math" + "strconv" "strings" "sync" "sync/atomic" @@ -53,6 +56,7 @@ var ( errSendQueueIsFull = newError(ProducerQueueIsFull, "producer send queue is full") errContextExpired = newError(TimeoutError, "message send context expired") errMessageTooLarge = newError(MessageTooBig, "message size exceeds MaxMessageSize") + errMetaTooLarge = newError(InvalidMessage, "message metadata size exceeds MaxMessageSize") errProducerClosed = newError(ProducerClosed, "producer already been closed") buffersPool sync.Pool @@ -75,6 +79,8 @@ type partitionProducer struct { batchBuilder internal.BatchBuilder sequenceIDGenerator *uint64 batchFlushTicker *time.Ticker + encryptor internalcrypto.Encryptor + compressionProvider compression.Provider // Channel where app is posting messages to be published eventsChan chan interface{} @@ -146,6 +152,8 @@ func newPartitionProducer(client *client, topic string, options *ProducerOptions connectClosedCh: make(chan connectionClosed, 10), closeCh: make(chan struct{}), batchFlushTicker: time.NewTicker(batchingMaxPublishDelay), + compressionProvider: 
internal.GetCompressionProvider(pb.CompressionType(options.CompressionType), + compression.Level(options.CompressionLevel)), publishSemaphore: internal.NewSemaphore(int32(maxPendingMessages)), pendingQueue: internal.NewBlockingQueue(maxPendingMessages), lastSequenceID: -1, @@ -246,42 +254,13 @@ func (p *partitionProducer) grabCnx() error { p.producerName = res.Response.ProducerSuccess.GetProducerName() - var encryptor internalcrypto.Encryptor if p.options.Encryption != nil { - encryptor = internalcrypto.NewProducerEncryptor(p.options.Encryption.Keys, + p.encryptor = internalcrypto.NewProducerEncryptor(p.options.Encryption.Keys, p.options.Encryption.KeyReader, p.options.Encryption.MessageCrypto, p.options.Encryption.ProducerCryptoFailureAction, p.log) } else { - encryptor = internalcrypto.NewNoopEncryptor() - } - - if p.options.DisableBatching { - provider, _ := GetBatcherBuilderProvider(DefaultBatchBuilder) - p.batchBuilder, err = provider(p.options.BatchingMaxMessages, p.options.BatchingMaxSize, - p.producerName, p.producerID, pb.CompressionType(p.options.CompressionType), - compression.Level(p.options.CompressionLevel), - p, - p.log, - encryptor) - if err != nil { - return err - } - } else if p.batchBuilder == nil { - provider, err := GetBatcherBuilderProvider(p.options.BatcherBuilderType) - if err != nil { - provider, _ = GetBatcherBuilderProvider(DefaultBatchBuilder) - } - - p.batchBuilder, err = provider(p.options.BatchingMaxMessages, p.options.BatchingMaxSize, - p.producerName, p.producerID, pb.CompressionType(p.options.CompressionType), - compression.Level(p.options.CompressionLevel), - p, - p.log, - encryptor) - if err != nil { - return err - } + p.encryptor = internalcrypto.NewNoopEncryptor() } if p.sequenceIDGenerator == nil { @@ -299,6 +278,24 @@ func (p *partitionProducer) grabCnx() error { if err != nil { return err } + + if !p.options.DisableBatching && p.batchBuilder == nil { + provider, err := GetBatcherBuilderProvider(p.options.BatcherBuilderType) 
+ if err != nil { + return err + } + maxMessageSize := uint32(p._getConn().GetMaxMessageSize()) + p.batchBuilder, err = provider(p.options.BatchingMaxMessages, p.options.BatchingMaxSize, + maxMessageSize, p.producerName, p.producerID, pb.CompressionType(p.options.CompressionType), + compression.Level(p.options.CompressionLevel), + p, + p.log, + p.encryptor) + if err != nil { + return err + } + } + p.log.WithFields(log.Fields{ "cnx": res.Cnx.ID(), "epoch": atomic.LoadUint64(&p.epoch), @@ -323,7 +320,7 @@ func (p *partitionProducer) grabCnx() error { pi.sentAt = time.Now() pi.Unlock() p.pendingQueue.Put(pi) - p._getConn().WriteData(pi.batchData) + p._getConn().WriteData(pi.buffer) if pi == lastViewItem { break @@ -472,12 +469,12 @@ func (p *partitionProducer) Name() string { } func (p *partitionProducer) internalSend(request *sendRequest) { - p.log.Debug("Received send request: ", *request) + p.log.Debug("Received send request: ", *request.msg) msg := request.msg // read payload from message - payload := msg.Payload + uncompressedPayload := msg.Payload var schemaPayload []byte var err error @@ -486,9 +483,16 @@ func (p *partitionProducer) internalSend(request *sendRequest) { return } + // The block chan must be closed when returned with exception + defer request.stopBlock() + if !p.canAddToQueue(request) { + return + } + if p.options.DisableMultiSchema { if msg.Schema != nil && p.options.Schema != nil && msg.Schema.GetSchemaInfo().hash() != p.options.Schema.GetSchemaInfo().hash() { + p.publishSemaphore.Release() p.log.WithError(err).Errorf("The producer %s of the topic %s is disabled the `MultiSchema`", p.producerName, p.topic) return } @@ -503,7 +507,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) { if msg.Value != nil { // payload and schema are mutually exclusive // try to get payload from schema value only if payload is not set - if payload == nil && schema != nil { + if uncompressedPayload == nil && schema != nil { schemaPayload, err = 
schema.Encode(msg.Value) if err != nil { p.publishSemaphore.Release() @@ -513,14 +517,16 @@ func (p *partitionProducer) internalSend(request *sendRequest) { } } } - if payload == nil { - payload = schemaPayload + if uncompressedPayload == nil { + uncompressedPayload = schemaPayload } + if schema != nil { schemaVersion = p.schemaCache.Get(schema.GetSchemaInfo()) if schemaVersion == nil { schemaVersion, err = p.getOrCreateSchema(schema.GetSchemaInfo()) if err != nil { + p.publishSemaphore.Release() p.log.WithError(err).Error("get schema version fail") return } @@ -528,29 +534,196 @@ func (p *partitionProducer) internalSend(request *sendRequest) { } } - // if msg is too large - if len(payload) > int(p._getConn().GetMaxMessageSize()) { + uncompressedSize := len(uncompressedPayload) + + deliverAt := msg.DeliverAt + if msg.DeliverAfter.Nanoseconds() > 0 { + deliverAt = time.Now().Add(msg.DeliverAfter) + } + + mm := p.genMetadata(msg, uncompressedSize, deliverAt) + + // set default ReplicationClusters when DisableReplication + if msg.DisableReplication { + msg.ReplicationClusters = []string{"__local__"} + } + + sendAsBatch := !p.options.DisableBatching && + msg.ReplicationClusters == nil && + deliverAt.UnixNano() < 0 + + // Once the batching is enabled, it can close blockCh early to make block finish + if sendAsBatch { + request.stopBlock() + } else { + // update sequence id for metadata, make the size of msgMetadata more accurate + // batch sending will update sequence ID in the BatchBuilder + p.updateMetadataSeqID(mm, msg) + } + + maxMessageSize := int(p._getConn().GetMaxMessageSize()) + + // compress payload if not batching + var compressedPayload []byte + var compressedSize int + var checkSize int + if !sendAsBatch { + compressedPayload = p.compressionProvider.Compress(nil, uncompressedPayload) + compressedSize = len(compressedPayload) + checkSize = compressedSize + } else { + // final check for batching message is in serializeMessage + // this is a double check + 
checkSize = uncompressedSize + } + + // if msg is too large and chunking is disabled + if checkSize > maxMessageSize && !p.options.EnableChunking { p.publishSemaphore.Release() request.callback(nil, request.msg, errMessageTooLarge) p.log.WithError(errMessageTooLarge). - WithField("size", len(payload)). + WithField("size", checkSize). WithField("properties", msg.Properties). - Errorf("MaxMessageSize %d", int(p._getConn().GetMaxMessageSize())) + Errorf("MaxMessageSize %d", maxMessageSize) p.metrics.PublishErrorsMsgTooLarge.Inc() return } - deliverAt := msg.DeliverAt - if msg.DeliverAfter.Nanoseconds() > 0 { - deliverAt = time.Now().Add(msg.DeliverAfter) + var totalChunks int + // max chunk payload size + var payloadChunkSize int + if sendAsBatch || !p.options.EnableChunking { + totalChunks = 1 + payloadChunkSize = int(p._getConn().GetMaxMessageSize()) + } else { + payloadChunkSize = int(p._getConn().GetMaxMessageSize()) - mm.Size() + if payloadChunkSize <= 0 { + p.publishSemaphore.Release() + request.callback(nil, msg, errMetaTooLarge) + p.log.WithError(errMetaTooLarge). + WithField("metadata size", mm.Size()). + WithField("properties", msg.Properties). 
+ Errorf("MaxMessageSize %d", int(p._getConn().GetMaxMessageSize())) + p.metrics.PublishErrorsMsgTooLarge.Inc() + return + } + // set ChunkMaxMessageSize + if p.options.ChunkMaxMessageSize != 0 { + payloadChunkSize = int(math.Min(float64(payloadChunkSize), float64(p.options.ChunkMaxMessageSize))) + } + totalChunks = int(math.Max(1, math.Ceil(float64(compressedSize)/float64(payloadChunkSize)))) } - sendAsBatch := !p.options.DisableBatching && - msg.ReplicationClusters == nil && - deliverAt.UnixNano() < 0 + // set total chunks to send request + request.totalChunks = totalChunks + + if !sendAsBatch { + if totalChunks > 1 { + var lhs, rhs int + uuid := fmt.Sprintf("%s-%s", p.producerName, strconv.FormatUint(*mm.SequenceId, 10)) + mm.Uuid = proto.String(uuid) + mm.NumChunksFromMsg = proto.Int(totalChunks) + mm.TotalChunkMsgSize = proto.Int(compressedSize) + cr := newChunkRecorder() + for chunkID := 0; chunkID < totalChunks; chunkID++ { + lhs = chunkID * payloadChunkSize + if rhs = lhs + payloadChunkSize; rhs > compressedSize { + rhs = compressedSize + } + // update chunk id + mm.ChunkId = proto.Int(chunkID) + nsr := &sendRequest{ + ctx: request.ctx, + msg: request.msg, + callback: request.callback, + callbackOnce: request.callbackOnce, + publishTime: request.publishTime, + blockCh: request.blockCh, + closeBlockChOnce: request.closeBlockChOnce, + totalChunks: totalChunks, + chunkID: chunkID, + uuid: uuid, + chunkRecorder: cr, + } + // the permit of first chunk has acquired + if chunkID != 0 && !p.canAddToQueue(nsr) { + return + } + p.internalSingleSend(mm, compressedPayload[lhs:rhs], nsr, uint32(maxMessageSize)) + } + // close the blockCh when all the chunks acquired permits + request.stopBlock() + } else { + // close the blockCh when totalChunks is 1 (it has acquired permits) + request.stopBlock() + p.internalSingleSend(mm, compressedPayload, request, uint32(maxMessageSize)) + } + } else { + smm := p.genSingleMessageMetadataInBatch(msg, uncompressedSize) + 
multiSchemaEnabled := !p.options.DisableMultiSchema + added := p.batchBuilder.Add(smm, p.sequenceIDGenerator, uncompressedPayload, request, + msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled) + if !added { + // The current batch is full.. flush it and retry + + p.internalFlushCurrentBatch() + + // after flushing try again to add the current payload + if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, uncompressedPayload, request, + msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled); !ok { + p.publishSemaphore.Release() + request.callback(nil, request.msg, errFailAddToBatch) + p.log.WithField("size", uncompressedSize). + WithField("properties", msg.Properties). + Error("unable to add message to batch") + return + } + } + if request.flushImmediately { - smm := &pb.SingleMessageMetadata{ - PayloadSize: proto.Int(len(payload)), + p.internalFlushCurrentBatch() + + } + } +} + +func (p *partitionProducer) genMetadata(msg *ProducerMessage, + uncompressedSize int, + deliverAt time.Time) (mm *pb.MessageMetadata) { + mm = &pb.MessageMetadata{ + ProducerName: &p.producerName, + PublishTime: proto.Uint64(internal.TimestampMillis(time.Now())), + ReplicateTo: msg.ReplicationClusters, + UncompressedSize: proto.Uint32(uint32(uncompressedSize)), + } + + if msg.Key != "" { + mm.PartitionKey = proto.String(msg.Key) + } + + if msg.Properties != nil { + mm.Properties = internal.ConvertFromStringMap(msg.Properties) + } + + if deliverAt.UnixNano() > 0 { + mm.DeliverAtTime = proto.Int64(int64(internal.TimestampMillis(deliverAt))) + } + + return +} + +func (p *partitionProducer) updateMetadataSeqID(mm *pb.MessageMetadata, msg *ProducerMessage) { + if msg.SequenceID != nil { + mm.SequenceId = proto.Uint64(uint64(*msg.SequenceID)) + } else { + mm.SequenceId = proto.Uint64(internal.GetAndAdd(p.sequenceIDGenerator, 1)) + } +} + +func (p *partitionProducer) genSingleMessageMetadataInBatch(msg *ProducerMessage, + uncompressedSize int) (smm 
*pb.SingleMessageMetadata) { + smm = &pb.SingleMessageMetadata{ + PayloadSize: proto.Int(uncompressedSize), } if !msg.EventTime.IsZero() { @@ -569,49 +742,61 @@ func (p *partitionProducer) internalSend(request *sendRequest) { smm.Properties = internal.ConvertFromStringMap(msg.Properties) } + var sequenceID uint64 if msg.SequenceID != nil { - sequenceID := uint64(*msg.SequenceID) - smm.SequenceId = proto.Uint64(sequenceID) + sequenceID = uint64(*msg.SequenceID) + } else { + sequenceID = internal.GetAndAdd(p.sequenceIDGenerator, 1) } - if !sendAsBatch { - p.internalFlushCurrentBatch() - } + smm.SequenceId = proto.Uint64(sequenceID) - if msg.DisableReplication { - msg.ReplicationClusters = []string{"__local__"} - } - multiSchemaEnabled := !p.options.DisableMultiSchema - added := p.batchBuilder.Add(smm, p.sequenceIDGenerator, payload, request, - msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled) - if !added { - // The current batch is full.. flush it and retry + return +} - p.internalFlushCurrentBatch() +func (p *partitionProducer) internalSingleSend(mm *pb.MessageMetadata, + compressedPayload []byte, + request *sendRequest, + maxMessageSize uint32) { + msg := request.msg - // after flushing try again to add the current payload - if ok := p.batchBuilder.Add(smm, p.sequenceIDGenerator, payload, request, - msg.ReplicationClusters, deliverAt, schemaVersion, multiSchemaEnabled); !ok { - p.publishSemaphore.Release() - request.callback(nil, request.msg, errFailAddToBatch) - p.log.WithField("size", len(payload)). - WithField("properties", msg.Properties). 
- Error("unable to add message to batch") - return - } - } + payloadBuf := internal.NewBuffer(len(compressedPayload)) + payloadBuf.Write(compressedPayload) - if !sendAsBatch || request.flushImmediately { + buffer := p.GetBuffer() + if buffer == nil { + buffer = internal.NewBuffer(int(payloadBuf.ReadableBytes() * 3 / 2)) + } - p.internalFlushCurrentBatch() + sid := *mm.SequenceId + if err := internal.SingleSend( + buffer, + p.producerID, + sid, + mm, + payloadBuf, + p.encryptor, + maxMessageSize, + ); err != nil { + request.callback(nil, request.msg, err) + p.publishSemaphore.Release() + p.log.WithError(err).Errorf("Single message serialize failed %s", msg.Value) + return } + p.pendingQueue.Put(&pendingItem{ + sentAt: time.Now(), + buffer: buffer, + sequenceID: sid, + sendRequests: []interface{}{request}, + }) + p._getConn().WriteData(buffer) } type pendingItem struct { sync.Mutex - batchData internal.Buffer + buffer internal.Buffer sequenceID uint64 sentAt time.Time sendRequests []interface{} @@ -637,12 +822,18 @@ func (p *partitionProducer) internalFlushCurrentBatch() { sr.callback(nil, sr.msg, err) } } + if errors.Is(err, internal.ErrExceedMaxMessageSize) { + p.log.WithError(errMessageTooLarge). + Errorf("internal err: %s", err) + p.metrics.PublishErrorsMsgTooLarge.Inc() + return + } return } p.pendingQueue.Put(&pendingItem{ sentAt: time.Now(), - batchData: batchData, + buffer: batchData, sequenceID: sequenceID, sendRequests: callbacks, }) @@ -735,8 +926,11 @@ func (p *partitionProducer) failTimeoutMessages() { WithField("size", size). 
WithField("properties", sr.msg.Properties) } + if sr.callback != nil { - sr.callback(nil, sr.msg, errSendTimeout) + sr.callbackOnce.Do(func() { + sr.callback(nil, sr.msg, errSendTimeout) + }) } } @@ -754,7 +948,7 @@ func (p *partitionProducer) failTimeoutMessages() { } func (p *partitionProducer) internalFlushCurrentBatches() { - batchesData, sequenceIDs, callbacks, errors := p.batchBuilder.FlushBatches() + batchesData, sequenceIDs, callbacks, errs := p.batchBuilder.FlushBatches() if batchesData == nil { return } @@ -762,12 +956,18 @@ func (p *partitionProducer) internalFlushCurrentBatches() { for i := range batchesData { // error occurred in processing batch // report it using callback - if errors[i] != nil { + if errs[i] != nil { for _, cb := range callbacks[i] { if sr, ok := cb.(*sendRequest); ok { - sr.callback(nil, sr.msg, errors[i]) + sr.callback(nil, sr.msg, errs[i]) } } + if errors.Is(errs[i], internal.ErrExceedMaxMessageSize) { + p.log.WithError(errMessageTooLarge). + Errorf("internal err: %s", errs[i]) + p.metrics.PublishErrorsMsgTooLarge.Inc() + return + } continue } if batchesData[i] == nil { @@ -775,7 +975,7 @@ func (p *partitionProducer) internalFlushCurrentBatches() { } p.pendingQueue.Put(&pendingItem{ sentAt: time.Now(), - batchData: batchesData[i], + buffer: batchesData[i], sequenceID: sequenceIDs[i], sendRequests: callbacks[i], }) @@ -836,6 +1036,7 @@ func (p *partitionProducer) Send(ctx context.Context, msg *ProducerMessage) (Mes // wait for send request to finish <-doneCh + return msgID, err } @@ -852,33 +1053,30 @@ func (p *partitionProducer) internalSendAsync(ctx context.Context, msg *Producer return } + // bc only works when DisableBlockIfQueueFull is false + bc := make(chan struct{}) + + // callbackOnce makes sure the callback is only invoked once in chunking + callbackOnce := &sync.Once{} + sr := &sendRequest{ ctx: ctx, msg: msg, callback: callback, + callbackOnce: callbackOnce, flushImmediately: flushImmediately, publishTime: time.Now(), + 
blockCh: bc, + closeBlockChOnce: &sync.Once{}, } p.options.Interceptors.BeforeSend(p, msg) - if p.options.DisableBlockIfQueueFull { - if !p.publishSemaphore.TryAcquire() { - if callback != nil { - callback(nil, msg, errSendQueueIsFull) - } - return - } - } else { - if !p.publishSemaphore.Acquire(ctx) { - callback(nil, msg, errContextExpired) - return - } - } - - p.metrics.MessagesPending.Inc() - p.metrics.BytesPending.Add(float64(len(sr.msg.Payload))) - p.eventsChan <- sr + + if !p.options.DisableBlockIfQueueFull { + // block if queue full + <-bc + } } func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) { @@ -933,11 +1131,34 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) p.partitionIdx, ) - if sr.callback != nil { - sr.callback(msgID, sr.msg, nil) + if sr.totalChunks > 1 { + if sr.chunkID == 0 { + sr.chunkRecorder.setFirstChunkID( + messageID{ + int64(response.MessageId.GetLedgerId()), + int64(response.MessageId.GetEntryId()), + -1, + p.partitionIdx, + }) + } else if sr.chunkID == sr.totalChunks-1 { + sr.chunkRecorder.setLastChunkID( + messageID{ + int64(response.MessageId.GetLedgerId()), + int64(response.MessageId.GetEntryId()), + -1, + p.partitionIdx, + }) + // use chunkMsgID to set msgID + msgID = sr.chunkRecorder.chunkedMsgID + } } - p.options.Interceptors.OnSendAcknowledgement(p, sr.msg, msgID) + if sr.totalChunks <= 1 || sr.chunkID == sr.totalChunks-1 { + if sr.callback != nil { + sr.callback(msgID, sr.msg, nil) + } + p.options.Interceptors.OnSendAcknowledgement(p, sr.msg, msgID) + } } } @@ -966,8 +1187,10 @@ func (p *partitionProducer) internalClose(req *closeProducer) { p.log.Info("Closed producer") } - if err = p.batchBuilder.Close(); err != nil { - p.log.WithError(err).Warn("Failed to close batch builder") + if p.batchBuilder != nil { + if err = p.batchBuilder.Close(); err != nil { + p.log.WithError(err).Warn("Failed to close batch builder") + } } p.setProducerState(producerClosed) @@ 
-1024,8 +1247,22 @@ type sendRequest struct { ctx context.Context msg *ProducerMessage callback func(MessageID, *ProducerMessage, error) + callbackOnce *sync.Once publishTime time.Time flushImmediately bool + blockCh chan struct{} + closeBlockChOnce *sync.Once + totalChunks int + chunkID int + uuid string + chunkRecorder *chunkRecorder +} + +// stopBlock can be invoked multiple times safety +func (sr *sendRequest) stopBlock() { + sr.closeBlockChOnce.Do(func() { + close(sr.blockCh) + }) } type closeProducer struct { @@ -1042,7 +1279,7 @@ func (i *pendingItem) Complete() { return } i.completed = true - buffersPool.Put(i.batchData) + buffersPool.Put(i.buffer) } // _setConn sets the internal connection field of this partition producer atomically. @@ -1059,3 +1296,38 @@ func (p *partitionProducer) _getConn() internal.Connection { // invariant is broken return p.conn.Load().(internal.Connection) } + +func (p *partitionProducer) canAddToQueue(sr *sendRequest) bool { + if p.options.DisableBlockIfQueueFull { + if !p.publishSemaphore.TryAcquire() { + if sr.callback != nil { + sr.callback(nil, sr.msg, errSendQueueIsFull) + } + return false + } + } else if !p.publishSemaphore.Acquire(sr.ctx) { + sr.callback(nil, sr.msg, errContextExpired) + return false + } + p.metrics.MessagesPending.Inc() + p.metrics.BytesPending.Add(float64(len(sr.msg.Payload))) + return true +} + +type chunkRecorder struct { + chunkedMsgID chunkMessageID +} + +func newChunkRecorder() *chunkRecorder { + return &chunkRecorder{ + chunkedMsgID: chunkMessageID{}, + } +} + +func (c *chunkRecorder) setFirstChunkID(msgID messageID) { + c.chunkedMsgID.firstChunkID = msgID +} + +func (c *chunkRecorder) setLastChunkID(msgID messageID) { + c.chunkedMsgID.messageID = msgID +} diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go index dc13f50dcb..f193ffd4ec 100644 --- a/pulsar/producer_test.go +++ b/pulsar/producer_test.go @@ -19,6 +19,7 @@ package pulsar import ( "context" + "errors" "fmt" "net/http" 
"strconv" @@ -934,14 +935,45 @@ func TestMaxMessageSize(t *testing.T) { assert.NotNil(t, producer) defer producer.Close() + // producer2 disable batching + producer2, err := client.CreateProducer(ProducerOptions{ + Topic: newTopicName(), + DisableBatching: true, + }) + assert.NoError(t, err) + assert.NotNil(t, producer2) + defer producer2.Close() + + // When serverMaxMessageSize=1024, the batch payload=1041 + // The totalSize includes: + // | singleMsgMetadataLength | singleMsgMetadata | payload | + // | ----------------------- | ----------------- | ------- | + // | 4 | 13 | 1024 | + // So when bias <= 0, the uncompressed payload will not exceed maxMessageSize, + // but encryptedPayloadSize exceeds maxMessageSize, Send() will return an internal error. + // When bias = 1, the first check of maxMessageSize (for uncompressed payload) is valid, + // Send() will return errMessageTooLarge for bias := -1; bias <= 1; bias++ { payload := make([]byte, serverMaxMessageSize+bias) ID, err := producer.Send(context.Background(), &ProducerMessage{ Payload: payload, }) if bias <= 0 { - assert.NoError(t, err) - assert.NotNil(t, ID) + assert.Equal(t, true, errors.Is(err, internal.ErrExceedMaxMessageSize)) + assert.Nil(t, ID) + } else { + assert.Equal(t, errMessageTooLarge, err) + } + } + + for bias := -1; bias <= 1; bias++ { + payload := make([]byte, serverMaxMessageSize+bias) + ID, err := producer2.Send(context.Background(), &ProducerMessage{ + Payload: payload, + }) + if bias <= 0 { + assert.Equal(t, true, errors.Is(err, internal.ErrExceedMaxMessageSize)) + assert.Nil(t, ID) } else { assert.Equal(t, errMessageTooLarge, err) }