diff --git a/cmd/broker/acl_test.go b/cmd/broker/acl_test.go index 025237c5..3150aac0 100644 --- a/cmd/broker/acl_test.go +++ b/cmd/broker/acl_test.go @@ -16,10 +16,7 @@ package main import ( - "bytes" "context" - "encoding/binary" - "io" "net" "testing" @@ -37,13 +34,13 @@ func TestACLProduceDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ {Partition: 0, Records: testBatchBytes(0, 0, 1)}, }, }, @@ -53,7 +50,7 @@ func TestACLProduceDenied(t *testing.T) { if err != nil { t.Fatalf("handleProduce: %v", err) } - resp := decodeProduceResponse(t, payload, 0) + resp := decodeKmsgResponse(t, 0, payload, kmsg.NewPtrProduceResponse) if len(resp.Topics) != 1 || len(resp.Topics[0].Partitions) != 1 { t.Fatalf("expected single topic/partition response") } @@ -77,12 +74,12 @@ func TestACLJoinGroupDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - joinReq := &protocol.JoinGroupRequest{GroupID: "group-a"} + joinReq := &kmsg.JoinGroupRequest{Group: "group-a"} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 5, APIVersion: 4, ClientID: &clientID}, joinReq) if err != nil { t.Fatalf("Handle JoinGroup: %v", err) } - resp := decodeJoinGroupResponse(t, payload, 4) + resp := decodeKmsgResponse(t, 4, payload, kmsg.NewPtrJoinGroupResponse) if resp.ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { t.Fatalf("expected group auth failed, got %d", resp.ErrorCode) } @@ -96,11 +93,11 @@ func TestACLListOffsetsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.ListOffsetsRequest{ - Topics: []protocol.ListOffsetsTopic{ + req := &kmsg.ListOffsetsRequest{ + Topics: []kmsg.ListOffsetsRequestTopic{ { - Name: "orders", - Partitions: []protocol.ListOffsetsPartition{ + Topic: "orders", + Partitions: []kmsg.ListOffsetsRequestTopicPartition{ {Partition: 0, Timestamp: -1, MaxNumOffsets: 1}, }, }, @@ -110,7 +107,7 @@ func TestACLListOffsetsDenied(t *testing.T) { if err != nil { t.Fatalf("Handle ListOffsets: %v", err) } - resp := decodeListOffsetsResponse(t, 4, payload) + resp := decodeKmsgResponse(t, 4, payload, kmsg.NewPtrListOffsetsResponse) if len(resp.Topics) != 1 || len(resp.Topics[0].Partitions) != 1 { t.Fatalf("expected single topic/partition response") } @@ -127,14 +124,12 @@ func TestACLOffsetFetchDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.OffsetFetchRequest{ - GroupID: "group-a", - Topics: []protocol.OffsetFetchTopic{ + req := &kmsg.OffsetFetchRequest{ + Group: "group-a", + Topics: []kmsg.OffsetFetchRequestTopic{ { - Name: "orders", - Partitions: []protocol.OffsetFetchPartition{ - {Partition: 0}, - }, + Topic: "orders", + Partitions: []int32{0}, }, }, } @@ -142,7 +137,7 @@ func TestACLOffsetFetchDenied(t *testing.T) { if err != nil { t.Fatalf("Handle OffsetFetch: %v", err) } - resp := decodeOffsetFetchResponse(t, payload, 5) + resp := decodeKmsgResponse(t, 5, payload, kmsg.NewPtrOffsetFetchResponse) if len(resp.Topics) != 1 || len(resp.Topics[0].Partitions) != 1 { t.Fatalf("expected single topic/partition response") } @@ -162,12 +157,12 @@ func TestACLDescribeGroupsMixed(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.DescribeGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} + req := &kmsg.DescribeGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 20, APIVersion: 5, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle DescribeGroups: %v", err) } - resp := decodeDescribeGroupsResponse(t, payload, 5) + resp := decodeKmsgResponse(t, 5, payload, kmsg.NewPtrDescribeGroupsResponse) if len(resp.Groups) != 2 { t.Fatalf("expected 2 groups, got %d", len(resp.Groups)) } @@ -190,12 +185,12 @@ func TestACLDeleteGroupsMixed(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.DeleteGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} + req := &kmsg.DeleteGroupsRequest{Groups: []string{"group-allowed", "group-denied"}} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 21, APIVersion: 2, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle DeleteGroups: %v", err) } - resp := decodeDeleteGroupsResponse(t, payload, 2) + resp := decodeKmsgResponse(t, 2, payload, kmsg.NewPtrDeleteGroupsResponse) if len(resp.Groups) != 2 { t.Fatalf("expected 2 groups, got %d", len(resp.Groups)) } @@ -234,13 +229,13 @@ func TestACLProxyAddrProduceAllowed(t *testing.T) { ctx := broker.ContextWithConnInfo(context.Background(), info) clientID := "spoofed-client" - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ {Partition: 0, Records: testBatchBytes(0, 0, 1)}, }, }, @@ -250,7 +245,7 @@ func TestACLProxyAddrProduceAllowed(t *testing.T) { if err != nil { t.Fatalf("handleProduce: %v", err) } - resp := decodeProduceResponse(t, payload, 0) + resp := decodeKmsgResponse(t, 0, payload, kmsg.NewPtrProduceResponse) if resp.Topics[0].Partitions[0].ErrorCode != protocol.NONE { t.Fatalf("expected produce allowed, got %d", resp.Topics[0].Partitions[0].ErrorCode) } @@ -264,12 +259,12 @@ func TestACLSyncGroupDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.SyncGroupRequest{GroupID: "group-a", GenerationID: 1, MemberID: "member-a"} + req := &kmsg.SyncGroupRequest{Group: "group-a", Generation: 1, MemberID: "member-a"} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 9, APIVersion: 4, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle SyncGroup: %v", err) } - resp := decodeSyncGroupResponse(t, payload, 4) + resp := decodeKmsgResponse(t, 4, payload, kmsg.NewPtrSyncGroupResponse) if resp.ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { t.Fatalf("expected group auth failed, got %d", resp.ErrorCode) } @@ -283,12 +278,12 @@ func TestACLHeartbeatDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.HeartbeatRequest{GroupID: "group-a", GenerationID: 1, MemberID: "member-a"} + req := &kmsg.HeartbeatRequest{Group: "group-a", Generation: 1, MemberID: "member-a"} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 10, APIVersion: 4, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle Heartbeat: %v", err) } - resp := decodeHeartbeatResponse(t, payload, 4) + resp := decodeKmsgResponse(t, 4, payload, kmsg.NewPtrHeartbeatResponse) if resp.ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { t.Fatalf("expected group auth failed, got %d", resp.ErrorCode) } @@ -302,12 +297,12 @@ func TestACLLeaveGroupDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.LeaveGroupRequest{GroupID: "group-a", MemberID: "member-a"} + req := &kmsg.LeaveGroupRequest{Group: "group-a", MemberID: "member-a"} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 11, APIVersion: 0, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle LeaveGroup: %v", err) } - resp := decodeLeaveGroupResponse(t, payload) + resp := decodeKmsgResponse(t, 0, payload, kmsg.NewPtrLeaveGroupResponse) if resp.ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { t.Fatalf("expected group auth failed, got %d", resp.ErrorCode) } @@ -321,13 +316,13 @@ func TestACLOffsetCommitDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.OffsetCommitRequest{ - GroupID: "group-a", - Topics: []protocol.OffsetCommitTopic{ + req := &kmsg.OffsetCommitRequest{ + Group: "group-a", + Topics: []kmsg.OffsetCommitRequestTopic{ { - Name: "orders", - Partitions: []protocol.OffsetCommitPartition{ - {Partition: 0, Offset: 1, Metadata: ""}, + Topic: "orders", + Partitions: []kmsg.OffsetCommitRequestTopicPartition{ + {Partition: 0, Offset: 1, Metadata: kmsg.StringPtr("")}, }, }, }, @@ -336,7 +331,7 @@ func TestACLOffsetCommitDenied(t *testing.T) { if err != nil { t.Fatalf("Handle OffsetCommit: %v", err) } - resp := decodeOffsetCommitResponse(t, payload, 3) + resp := decodeKmsgResponse(t, 3, payload, kmsg.NewPtrOffsetCommitResponse) if len(resp.Topics) == 0 || len(resp.Topics[0].Partitions) == 0 { t.Fatalf("expected offset commit response") } @@ -353,14 +348,14 @@ func TestACLCreateTopicsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.CreateTopicsRequest{ - Topics: []protocol.CreateTopicConfig{{Name: "orders", NumPartitions: 1, ReplicationFactor: 1}}, + req := &kmsg.CreateTopicsRequest{ + Topics: []kmsg.CreateTopicsRequestTopic{{Topic: "orders", NumPartitions: 1, ReplicationFactor: 1}}, } payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 13, APIVersion: 0, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle CreateTopics: %v", err) } - resp := decodeCreateTopicsResponse(t, payload, 0) + resp := decodeKmsgResponse(t, 0, payload, kmsg.NewPtrCreateTopicsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != protocol.TOPIC_AUTHORIZATION_FAILED { t.Fatalf("expected topic auth failed, got %+v", resp.Topics) } @@ -374,12 +369,12 @@ func TestACLDeleteTopicsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.DeleteTopicsRequest{TopicNames: []string{"orders"}} + req := &kmsg.DeleteTopicsRequest{TopicNames: []string{"orders"}} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 14, APIVersion: 0, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle DeleteTopics: %v", err) } - resp := decodeDeleteTopicsResponse(t, payload, 0) + resp := decodeKmsgResponse(t, 0, payload, kmsg.NewPtrDeleteTopicsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != protocol.TOPIC_AUTHORIZATION_FAILED { t.Fatalf("expected topic auth failed, got %+v", resp.Topics) } @@ -393,10 +388,10 @@ func TestACLAlterConfigsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.AlterConfigsRequest{ - Resources: []protocol.AlterConfigsResource{ + req := &kmsg.AlterConfigsRequest{ + Resources: []kmsg.AlterConfigsRequestResource{ { - ResourceType: protocol.ConfigResourceTopic, + ResourceType: kmsg.ConfigResourceTypeTopic, ResourceName: "orders", }, }, @@ -405,7 +400,7 @@ func TestACLAlterConfigsDenied(t *testing.T) { if err != nil { t.Fatalf("Handle AlterConfigs: %v", err) } - resp := decodeAlterConfigsResponse(t, payload, 1) + resp := decodeKmsgResponse(t, 1, payload, kmsg.NewPtrAlterConfigsResponse) if len(resp.Resources) != 1 || resp.Resources[0].ErrorCode != protocol.TOPIC_AUTHORIZATION_FAILED { t.Fatalf("expected topic auth failed, got %+v", resp.Resources) } @@ -419,14 +414,14 @@ func TestACLCreatePartitionsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{{Name: "orders", Count: 2}}, + req := &kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{{Topic: "orders", Count: 2}}, } payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 16, APIVersion: 3, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle CreatePartitions: %v", err) } - resp := decodeCreatePartitionsResponse(t, payload, 3) + resp := decodeKmsgResponse(t, 3, payload, kmsg.NewPtrCreatePartitionsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != protocol.TOPIC_AUTHORIZATION_FAILED { t.Fatalf("expected topic auth failed, got %+v", resp.Topics) } @@ -440,202 +435,13 @@ func TestACLDeleteGroupsDenied(t *testing.T) { handler := newTestHandler(store) clientID := "client-a" - req := &protocol.DeleteGroupsRequest{Groups: []string{"group-a"}} + req := &kmsg.DeleteGroupsRequest{Groups: []string{"group-a"}} payload, err := handler.Handle(context.Background(), &protocol.RequestHeader{CorrelationID: 17, APIVersion: 2, ClientID: &clientID}, req) if err != nil { t.Fatalf("Handle DeleteGroups: %v", err) } - resp := decodeDeleteGroupsResponse(t, payload, 2) + resp := decodeKmsgResponse(t, 2, payload, kmsg.NewPtrDeleteGroupsResponse) if len(resp.Groups) != 1 || resp.Groups[0].ErrorCode != protocol.GROUP_AUTHORIZATION_FAILED { t.Fatalf("expected group auth failed, got %+v", resp.Groups) } } - -func decodeOffsetFetchResponse(t *testing.T, payload []byte, version int16) *protocol.OffsetFetchResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.OffsetFetchResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 3 { - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle: %v", err) - } - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.OffsetFetchTopicResponse, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - name := readKafkaString(t, reader) - var partCount int32 - if err := binary.Read(reader, binary.BigEndian, &partCount); err != nil { - t.Fatalf("read partition count: %v", err) - } - topic := protocol.OffsetFetchTopicResponse{Name: name} - topic.Partitions = make([]protocol.OffsetFetchPartitionResponse, 0, partCount) - for j := 0; j < int(partCount); j++ { - var part protocol.OffsetFetchPartitionResponse - if err := binary.Read(reader, binary.BigEndian, &part.Partition); err != nil { - t.Fatalf("read partition id: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.Offset); err != nil { - t.Fatalf("read offset: %v", err) - } - if version >= 5 { - if err := binary.Read(reader, binary.BigEndian, &part.LeaderEpoch); err != nil { - t.Fatalf("read leader epoch: %v", err) - } - } - part.Metadata = readKafkaNullableString(t, reader) - if err := binary.Read(reader, binary.BigEndian, &part.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - topic.Partitions = append(topic.Partitions, part) - } - resp.Topics = append(resp.Topics, topic) - } - if version >= 2 { - if err := binary.Read(reader, binary.BigEndian, &resp.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - } - return resp -} - -func decodeDescribeGroupsResponse(t *testing.T, payload []byte, version int16) *kmsg.DescribeGroupsResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 5 { - skipTaggedFields(t, reader) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrDescribeGroupsResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode describe groups response: %v", err) - } - return resp -} - -func decodeSyncGroupResponse(t *testing.T, payload []byte, version int16) *kmsg.SyncGroupResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 4 { - skipTaggedFields(t, reader) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrSyncGroupResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode sync group response: %v", err) - } - return resp -} - -func decodeHeartbeatResponse(t *testing.T, payload []byte, version int16) *kmsg.HeartbeatResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 4 { - skipTaggedFields(t, reader) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrHeartbeatResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode heartbeat response: %v", err) - } - return resp -} - -func decodeLeaveGroupResponse(t *testing.T, payload []byte) *protocol.LeaveGroupResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.LeaveGroupResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &resp.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - return resp -} - -func decodeAlterConfigsResponse(t *testing.T, payload []byte, version int16) *protocol.AlterConfigsResponse { - t.Helper() - if version != 1 { - t.Fatalf("alter configs decode only supports version 1") - } - reader := bytes.NewReader(payload) - resp := &protocol.AlterConfigsResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle ms: %v", err) - } - var count int32 - if err := binary.Read(reader, binary.BigEndian, &count); err != nil { - t.Fatalf("read resource count: %v", err) - } - resp.Resources = make([]protocol.AlterConfigsResponseResource, 0, count) - for i := 0; i < int(count); i++ { - var code int16 - if err := binary.Read(reader, binary.BigEndian, &code); err != nil { - t.Fatalf("read error code: %v", err) - } - msg := readKafkaNullableString(t, reader) - var rtype int8 - if err := binary.Read(reader, binary.BigEndian, &rtype); err != nil { - t.Fatalf("read resource type: %v", err) - } - name := readKafkaString(t, reader) - resp.Resources = append(resp.Resources, protocol.AlterConfigsResponseResource{ - ErrorCode: code, - ErrorMessage: msg, - ResourceType: rtype, - ResourceName: name, - }) - } - return resp -} - -func readKafkaNullableString(t *testing.T, reader *bytes.Reader) *string { - t.Helper() - var length int16 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - t.Fatalf("read nullable string length: %v", err) - } - if length < 0 { - return nil - } - buf := make([]byte, length) - if _, err := reader.Read(buf); err != nil { - t.Fatalf("read nullable string: %v", err) - } - value := string(buf) - return &value -} diff --git a/cmd/broker/main.go b/cmd/broker/main.go index 1b491b48..971295e8 100644 --- a/cmd/broker/main.go +++ b/cmd/broker/main.go @@ -31,6 +31,8 @@ import ( "syscall" "time" + "github.com/twmb/franz-go/pkg/kmsg" + "github.com/KafScale/platform/pkg/acl" "github.com/KafScale/platform/pkg/broker" "github.com/KafScale/platform/pkg/cache" @@ -48,16 +50,16 @@ import ( ) const ( - defaultKafkaAddr = ":19092" - defaultKafkaPort = 19092 - defaultMetricsAddr = ":19093" - defaultControlAddr = ":19094" + defaultKafkaAddr = ":19092" + defaultKafkaPort = 19092 + defaultMetricsAddr = ":19093" + defaultControlAddr = ":19094" defaultS3Concurrency = 64 - brokerVersion = "dev" + brokerVersion = "dev" ) type handler struct { - apiVersions []protocol.ApiVersion + apiVersions []kmsg.ApiVersionsResponseApiKey store metadata.Store s3 storage.S3Client cache *cache.SegmentCache @@ -99,35 +101,40 @@ type etcdAvailability interface { Available() bool } -func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, req protocol.Request) ([]byte, error) { +func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, req kmsg.Request) ([]byte, error) { if h.traceKafka { h.logger.Debug("received request", "api_key", header.APIKey, "api_version", header.APIVersion, "correlation", header.CorrelationID, "client_id", header.ClientID) } principal := principalFromContext(ctx, header) switch req.(type) { - case *protocol.ApiVersionsRequest: - errorCode := protocol.NONE + case *kmsg.ApiVersionsRequest: + errorCode := int16(protocol.NONE) responseVersion := header.APIVersion if responseVersion > 4 { - errorCode = protocol.UNSUPPORTED_VERSION + errorCode = int16(protocol.UNSUPPORTED_VERSION) responseVersion = 0 } - resp := &protocol.ApiVersionsResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: errorCode, - Versions: h.apiVersions, - } + resp := kmsg.NewPtrApiVersionsResponse() + resp.ErrorCode = errorCode + resp.ApiKeys = h.apiVersions if h.traceKafka { - h.logger.Debug("api versions response", "versions", resp.Versions) + h.logger.Debug("api versions response", "versions", resp.ApiKeys) } - return protocol.EncodeApiVersionsResponse(resp, responseVersion) - case *protocol.MetadataRequest: - metaReq := req.(*protocol.MetadataRequest) + return protocol.EncodeResponse(header.CorrelationID, responseVersion, resp), nil + case *kmsg.MetadataRequest: + metaReq := req.(*kmsg.MetadataRequest) if h.traceKafka { - h.logger.Debug("metadata request", "topics", metaReq.Topics, "topic_ids", len(metaReq.TopicIDs)) + h.logger.Debug("metadata request", "topics", len(metaReq.Topics)) + } + // Extract topic names from request for auto-create and store lookup. + var topicNames []string + for _, t := range metaReq.Topics { + if t.Topic != nil { + topicNames = append(topicNames, *t.Topic) + } } - if h.autoCreateTopics && len(metaReq.Topics) > 0 { - for _, name := range metaReq.Topics { + if h.autoCreateTopics && len(topicNames) > 0 { + for _, name := range topicNames { if strings.TrimSpace(name) == "" { continue } @@ -139,14 +146,14 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re meta, err := func() (*metadata.ClusterMetadata, error) { zeroID := [16]byte{} useIDs := false - for _, id := range metaReq.TopicIDs { - if id != zeroID { + for _, t := range metaReq.Topics { + if t.TopicID != zeroID { useIDs = true break } } if !useIDs { - return h.store.Metadata(ctx, metaReq.Topics) + return h.store.Metadata(ctx, topicNames) } all, err := h.store.Metadata(ctx, nil) if err != nil { @@ -156,17 +163,17 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re for _, topic := range all.Topics { index[topic.TopicID] = topic } - filtered := make([]protocol.MetadataTopic, 0, len(metaReq.TopicIDs)) - for _, id := range metaReq.TopicIDs { - if id == zeroID { + filtered := make([]protocol.MetadataTopic, 0, len(metaReq.Topics)) + for _, t := range metaReq.Topics { + if t.TopicID == zeroID { continue } - if topic, ok := index[id]; ok { + if topic, ok := index[t.TopicID]; ok { filtered = append(filtered, topic) } else { filtered = append(filtered, protocol.MetadataTopic{ ErrorCode: protocol.UNKNOWN_TOPIC_ID, - TopicID: id, + TopicID: t.TopicID, }) } } @@ -180,97 +187,83 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re if err != nil { return nil, fmt.Errorf("load metadata: %w", err) } - resp := &protocol.MetadataResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Brokers: meta.Brokers, - ClusterID: meta.ClusterID, - ControllerID: meta.ControllerID, - Topics: meta.Topics, - } + resp := kmsg.NewPtrMetadataResponse() + resp.Brokers = meta.Brokers + resp.ClusterID = meta.ClusterID + resp.ControllerID = meta.ControllerID + resp.Topics = meta.Topics if h.traceKafka { topicSummaries := make([]string, 0, len(meta.Topics)) for _, topic := range meta.Topics { - topicSummaries = append(topicSummaries, fmt.Sprintf("%s(error=%d partitions=%d)", topic.Name, topic.ErrorCode, len(topic.Partitions))) + topicSummaries = append(topicSummaries, fmt.Sprintf("%s(error=%d partitions=%d)", *topic.Topic, topic.ErrorCode, len(topic.Partitions))) } brokerAddrs := make([]string, 0, len(meta.Brokers)) - for _, broker := range meta.Brokers { - brokerAddrs = append(brokerAddrs, fmt.Sprintf("%s:%d", broker.Host, broker.Port)) + for _, b := range meta.Brokers { + brokerAddrs = append(brokerAddrs, fmt.Sprintf("%s:%d", b.Host, b.Port)) } h.logger.Debug("metadata response", "topics", topicSummaries, "brokers", brokerAddrs) } - return protocol.EncodeMetadataResponse(resp, header.APIVersion) - case *protocol.ProduceRequest: - return h.handleProduce(ctx, header, req.(*protocol.ProduceRequest)) - case *protocol.FetchRequest: - return h.handleFetch(ctx, header, req.(*protocol.FetchRequest)) - case *protocol.FindCoordinatorRequest: + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.ProduceRequest: + return h.handleProduce(ctx, header, req.(*kmsg.ProduceRequest)) + case *kmsg.FetchRequest: + return h.handleFetch(ctx, header, req.(*kmsg.FetchRequest)) + case *kmsg.FindCoordinatorRequest: coord := h.coordinatorBroker(ctx) - resp := &protocol.FindCoordinatorResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.NONE, - NodeID: coord.NodeID, - Host: coord.Host, - Port: coord.Port, - } - return protocol.EncodeFindCoordinatorResponse(resp, header.APIVersion) - case *protocol.JoinGroupRequest: - req := req.(*protocol.JoinGroupRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupWrite) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.GroupID) - return protocol.EncodeJoinGroupResponse(&protocol.JoinGroupResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }, header.APIVersion) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - return protocol.EncodeJoinGroupResponse(&protocol.JoinGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: errCode, - }, header.APIVersion) + resp := kmsg.NewPtrFindCoordinatorResponse() + resp.ErrorCode = protocol.NONE + resp.NodeID = coord.NodeID + resp.Host = coord.Host + resp.Port = coord.Port + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.JoinGroupRequest: + req := req.(*kmsg.JoinGroupRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupWrite) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = errCode + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - return protocol.EncodeJoinGroupResponse(&protocol.JoinGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }, header.APIVersion) + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } - resp, err := h.coordinator.JoinGroup(ctx, req, header.CorrelationID) + resp, err := h.coordinator.JoinGroup(ctx, req) if err != nil { return nil, err } - return protocol.EncodeJoinGroupResponse(resp, header.APIVersion) - case *protocol.SyncGroupRequest: - req := req.(*protocol.SyncGroupRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupWrite) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.GroupID) - return protocol.EncodeSyncGroupResponse(&protocol.SyncGroupResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }, header.APIVersion) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - return protocol.EncodeSyncGroupResponse(&protocol.SyncGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: errCode, - }, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.SyncGroupRequest: + req := req.(*kmsg.SyncGroupRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupWrite) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = errCode + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - return protocol.EncodeSyncGroupResponse(&protocol.SyncGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }, header.APIVersion) + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } - resp, err := h.coordinator.SyncGroup(ctx, req, header.CorrelationID) + resp, err := h.coordinator.SyncGroup(ctx, req) if err != nil { return nil, err } - return protocol.EncodeSyncGroupResponse(resp, header.APIVersion) - case *protocol.DescribeGroupsRequest: - req := req.(*protocol.DescribeGroupsRequest) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.DescribeGroupsRequest: + req := req.(*kmsg.DescribeGroupsRequest) return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { allowed := make([]string, 0, len(req.Groups)) denied := make(map[string]struct{}) @@ -288,312 +281,197 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re allowed = append(allowed, groupID) } - responseByGroup := make(map[string]protocol.DescribeGroupsResponseGroup, len(req.Groups)) + responseByGroup := make(map[string]kmsg.DescribeGroupsResponseGroup, len(req.Groups)) if len(allowed) > 0 { if !h.etcdAvailable() { for _, groupID := range allowed { - responseByGroup[groupID] = protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - GroupID: groupID, - } + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.REQUEST_TIMED_OUT + g.Group = groupID + responseByGroup[groupID] = g } } else { allowedReq := *req allowedReq.Groups = allowed - resp, err := h.coordinator.DescribeGroups(ctx, &allowedReq, header.CorrelationID) + resp, err := h.coordinator.DescribeGroups(ctx, &allowedReq) if err != nil { return nil, err } for _, group := range resp.Groups { - responseByGroup[group.GroupID] = group + responseByGroup[group.Group] = group } } } - results := make([]protocol.DescribeGroupsResponseGroup, 0, len(req.Groups)) + results := make([]kmsg.DescribeGroupsResponseGroup, 0, len(req.Groups)) for _, groupID := range req.Groups { if _, ok := denied[groupID]; ok { - results = append(results, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - GroupID: groupID, - }) + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + g.Group = groupID + results = append(results, g) continue } if errCode, ok := leaseErrors[groupID]; ok { - results = append(results, protocol.DescribeGroupsResponseGroup{ - ErrorCode: errCode, - GroupID: groupID, - }) + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = errCode + g.Group = groupID + results = append(results, g) continue } if group, ok := responseByGroup[groupID]; ok { results = append(results, group) } else { - results = append(results, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - GroupID: groupID, - }) + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + g.Group = groupID + results = append(results, g) } } - return protocol.EncodeDescribeGroupsResponse(&protocol.DescribeGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: results, - }, header.APIVersion) + resp := kmsg.NewPtrDescribeGroupsResponse() + resp.Groups = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil }) - case *protocol.ListGroupsRequest: + case *kmsg.ListGroupsRequest: if !h.allowGroup(principal, "*", acl.ActionGroupRead) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupRead, acl.ResourceGroup, "*") - return protocol.EncodeListGroupsResponse(&protocol.ListGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - Groups: nil, - }, header.APIVersion) + resp := kmsg.NewPtrListGroupsResponse() + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { if !h.etcdAvailable() { - return protocol.EncodeListGroupsResponse(&protocol.ListGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - Groups: nil, - }, header.APIVersion) - } - resp, err := h.coordinator.ListGroups(ctx, req.(*protocol.ListGroupsRequest), header.CorrelationID) + resp := kmsg.NewPtrListGroupsResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + resp, err := h.coordinator.ListGroups(ctx, req.(*kmsg.ListGroupsRequest)) if err != nil { return nil, err } - return protocol.EncodeListGroupsResponse(resp, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil }) - case *protocol.HeartbeatRequest: - req := req.(*protocol.HeartbeatRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupWrite) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.GroupID) - return protocol.EncodeHeartbeatResponse(&protocol.HeartbeatResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }, header.APIVersion) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - return protocol.EncodeHeartbeatResponse(&protocol.HeartbeatResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: errCode, - }, header.APIVersion) + case *kmsg.HeartbeatRequest: + req := req.(*kmsg.HeartbeatRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupWrite) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = errCode + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - return protocol.EncodeHeartbeatResponse(&protocol.HeartbeatResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }, header.APIVersion) - } - resp := h.coordinator.Heartbeat(ctx, req, header.CorrelationID) - return protocol.EncodeHeartbeatResponse(resp, header.APIVersion) - case *protocol.LeaveGroupRequest: - req := req.(*protocol.LeaveGroupRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupWrite) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.GroupID) - return protocol.EncodeLeaveGroupResponse(&protocol.LeaveGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - return protocol.EncodeLeaveGroupResponse(&protocol.LeaveGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: errCode, - }) + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + resp := h.coordinator.Heartbeat(ctx, req) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.LeaveGroupRequest: + req := req.(*kmsg.LeaveGroupRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupWrite) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrLeaveGroupResponse() + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrLeaveGroupResponse() + resp.ErrorCode = errCode + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - return protocol.EncodeLeaveGroupResponse(&protocol.LeaveGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) - } - resp := h.coordinator.LeaveGroup(ctx, req, header.CorrelationID) - return protocol.EncodeLeaveGroupResponse(resp) - case *protocol.OffsetCommitRequest: - req := req.(*protocol.OffsetCommitRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupWrite) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.GroupID) - topics := make([]protocol.OffsetCommitTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetCommitPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetCommitPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }) - } - topics = append(topics, protocol.OffsetCommitTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetCommitResponse(&protocol.OffsetCommitResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - }) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - topics := make([]protocol.OffsetCommitTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetCommitPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetCommitPartitionResponse{ - Partition: part.Partition, - ErrorCode: errCode, - }) - } - topics = append(topics, protocol.OffsetCommitTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetCommitResponse(&protocol.OffsetCommitResponse{ - CorrelationID: header.CorrelationID, - Topics: topics, - }) + resp := kmsg.NewPtrLeaveGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + resp := h.coordinator.LeaveGroup(ctx, req) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.OffsetCommitRequest: + req := req.(*kmsg.OffsetCommitRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupWrite) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupWrite, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrOffsetCommitResponse() + resp.Topics = offsetCommitErrorTopics(req, protocol.GROUP_AUTHORIZATION_FAILED) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrOffsetCommitResponse() + resp.Topics = offsetCommitErrorTopics(req, errCode) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - topics := make([]protocol.OffsetCommitTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetCommitPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetCommitPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) - } - topics = append(topics, protocol.OffsetCommitTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetCommitResponse(&protocol.OffsetCommitResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - }) + resp := kmsg.NewPtrOffsetCommitResponse() + resp.Topics = offsetCommitErrorTopics(req, protocol.REQUEST_TIMED_OUT) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } - resp, err := h.coordinator.OffsetCommit(ctx, req, header.CorrelationID) + resp, err := h.coordinator.OffsetCommit(ctx, req) if err != nil { return nil, err } - return protocol.EncodeOffsetCommitResponse(resp) - case *protocol.OffsetFetchRequest: - req := req.(*protocol.OffsetFetchRequest) - if !h.allowGroup(principal, req.GroupID, acl.ActionGroupRead) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupRead, acl.ResourceGroup, req.GroupID) - topics := make([]protocol.OffsetFetchTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetFetchPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetFetchPartitionResponse{ - Partition: part.Partition, - Offset: -1, - LeaderEpoch: -1, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }) - } - topics = append(topics, protocol.OffsetFetchTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetFetchResponse(&protocol.OffsetFetchResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }, header.APIVersion) - } - if errCode := h.acquireGroupLease(ctx, req.GroupID); errCode != 0 { - topics := make([]protocol.OffsetFetchTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetFetchPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetFetchPartitionResponse{ - Partition: part.Partition, - Offset: -1, - LeaderEpoch: -1, - ErrorCode: errCode, - }) - } - topics = append(topics, protocol.OffsetFetchTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetFetchResponse(&protocol.OffsetFetchResponse{ - CorrelationID: header.CorrelationID, - Topics: topics, - ErrorCode: errCode, - }, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.OffsetFetchRequest: + req := req.(*kmsg.OffsetFetchRequest) + if !h.allowGroup(principal, req.Group, acl.ActionGroupRead) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupRead, acl.ResourceGroup, req.Group) + resp := kmsg.NewPtrOffsetFetchResponse() + resp.Topics = offsetFetchErrorTopics(req, protocol.GROUP_AUTHORIZATION_FAILED) + resp.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + if errCode := h.acquireGroupLease(ctx, req.Group); errCode != 0 { + resp := kmsg.NewPtrOffsetFetchResponse() + resp.Topics = offsetFetchErrorTopics(req, errCode) + resp.ErrorCode = errCode + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { - topics := make([]protocol.OffsetFetchTopicResponse, 0, len(req.Topics)) - for _, topic := range req.Topics { - partitions := make([]protocol.OffsetFetchPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetFetchPartitionResponse{ - Partition: part.Partition, - Offset: -1, - LeaderEpoch: -1, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) - } - topics = append(topics, protocol.OffsetFetchTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return protocol.EncodeOffsetFetchResponse(&protocol.OffsetFetchResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }, header.APIVersion) - } - resp, err := h.coordinator.OffsetFetch(ctx, req, header.CorrelationID) + resp := kmsg.NewPtrOffsetFetchResponse() + resp.Topics = offsetFetchErrorTopics(req, protocol.REQUEST_TIMED_OUT) + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + } + resp, err := h.coordinator.OffsetFetch(ctx, req) if err != nil { return nil, err } - return protocol.EncodeOffsetFetchResponse(resp, header.APIVersion) - case *protocol.OffsetForLeaderEpochRequest: + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.OffsetForLeaderEpochRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - offsetReq := req.(*protocol.OffsetForLeaderEpochRequest) + offsetReq := req.(*kmsg.OffsetForLeaderEpochRequest) if !h.allowTopics(principal, topicsFromOffsetForLeaderEpoch(offsetReq), acl.ActionFetch) { return h.unauthorizedOffsetForLeaderEpoch(principal, header, offsetReq) } return h.handleOffsetForLeaderEpoch(ctx, header, offsetReq) }) - case *protocol.DescribeConfigsRequest: + case *kmsg.DescribeConfigsRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - return h.handleDescribeConfigs(ctx, header, req.(*protocol.DescribeConfigsRequest)) + return h.handleDescribeConfigs(ctx, header, req.(*kmsg.DescribeConfigsRequest)) }) - case *protocol.AlterConfigsRequest: + case *kmsg.AlterConfigsRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - alterReq := req.(*protocol.AlterConfigsRequest) + alterReq := req.(*kmsg.AlterConfigsRequest) if !h.allowAdmin(principal) { return h.unauthorizedAlterConfigs(principal, header, alterReq) } return h.handleAlterConfigs(ctx, header, alterReq) }) - case *protocol.CreatePartitionsRequest: + case *kmsg.CreatePartitionsRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - createReq := req.(*protocol.CreatePartitionsRequest) + createReq := req.(*kmsg.CreatePartitionsRequest) if !h.allowAdmin(principal) { return h.unauthorizedCreatePartitions(principal, header, createReq) } return h.handleCreatePartitions(ctx, header, createReq) }) - case *protocol.DeleteGroupsRequest: + case *kmsg.DeleteGroupsRequest: return h.withAdminMetrics(header.APIKey, func() ([]byte, error) { - deleteReq := req.(*protocol.DeleteGroupsRequest) + deleteReq := req.(*kmsg.DeleteGroupsRequest) allowed := make([]string, 0, len(deleteReq.Groups)) denied := make(map[string]struct{}) for _, groupID := range deleteReq.Groups { @@ -605,19 +483,19 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re } } - responseByGroup := make(map[string]protocol.DeleteGroupsResponseGroup, len(deleteReq.Groups)) + responseByGroup := make(map[string]kmsg.DeleteGroupsResponseGroup, len(deleteReq.Groups)) if len(allowed) > 0 { if !h.etcdAvailable() { for _, groupID := range allowed { - responseByGroup[groupID] = protocol.DeleteGroupsResponseGroup{ - Group: groupID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - } + g := kmsg.NewDeleteGroupsResponseGroup() + g.Group = groupID + g.ErrorCode = protocol.REQUEST_TIMED_OUT + responseByGroup[groupID] = g } } else { allowedReq := *deleteReq allowedReq.Groups = allowed - resp, err := h.coordinator.DeleteGroups(ctx, &allowedReq, header.CorrelationID) + resp, err := h.coordinator.DeleteGroups(ctx, &allowedReq) if err != nil { return nil, err } @@ -627,45 +505,43 @@ func (h *handler) Handle(ctx context.Context, header *protocol.RequestHeader, re } } - results := make([]protocol.DeleteGroupsResponseGroup, 0, len(deleteReq.Groups)) + results := make([]kmsg.DeleteGroupsResponseGroup, 0, len(deleteReq.Groups)) for _, groupID := range deleteReq.Groups { if _, denied := denied[groupID]; denied { - results = append(results, protocol.DeleteGroupsResponseGroup{ - Group: groupID, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }) + g := kmsg.NewDeleteGroupsResponseGroup() + g.Group = groupID + g.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + results = append(results, g) continue } if group, ok := responseByGroup[groupID]; ok { results = append(results, group) } else { - results = append(results, protocol.DeleteGroupsResponseGroup{ - Group: groupID, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - }) + g := kmsg.NewDeleteGroupsResponseGroup() + g.Group = groupID + g.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + results = append(results, g) } } - return protocol.EncodeDeleteGroupsResponse(&protocol.DeleteGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteGroupsResponse() + resp.Groups = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil }) - case *protocol.CreateTopicsRequest: - createReq := req.(*protocol.CreateTopicsRequest) + case *kmsg.CreateTopicsRequest: + createReq := req.(*kmsg.CreateTopicsRequest) if !h.allowAdmin(principal) { return h.unauthorizedCreateTopics(principal, header, createReq) } return h.handleCreateTopics(ctx, header, createReq) - case *protocol.DeleteTopicsRequest: - deleteReq := req.(*protocol.DeleteTopicsRequest) + case *kmsg.DeleteTopicsRequest: + deleteReq := req.(*kmsg.DeleteTopicsRequest) if !h.allowAdmin(principal) { return h.unauthorizedDeleteTopics(principal, header, deleteReq) } return h.handleDeleteTopics(ctx, header, deleteReq) - case *protocol.ListOffsetsRequest: - listReq := req.(*protocol.ListOffsetsRequest) + case *kmsg.ListOffsetsRequest: + listReq := req.(*kmsg.ListOffsetsRequest) if !h.allowTopics(principal, topicsFromListOffsets(listReq), acl.ActionFetch) { return h.unauthorizedListOffsets(principal, header, listReq) } @@ -861,171 +737,155 @@ func (h *handler) allowAdmin(principal string) bool { return h.allowCluster(principal, acl.ActionAdmin) } -func topicsFromListOffsets(req *protocol.ListOffsetsRequest) []string { +func topicsFromListOffsets(req *kmsg.ListOffsetsRequest) []string { topics := make([]string, 0, len(req.Topics)) for _, topic := range req.Topics { - topics = append(topics, topic.Name) + topics = append(topics, topic.Topic) } return topics } -func topicsFromOffsetForLeaderEpoch(req *protocol.OffsetForLeaderEpochRequest) []string { +func topicsFromOffsetForLeaderEpoch(req *kmsg.OffsetForLeaderEpochRequest) []string { topics := make([]string, 0, len(req.Topics)) for _, topic := range req.Topics { - topics = append(topics, topic.Name) + topics = append(topics, topic.Topic) } return topics } -func topicsFromCreateTopics(req *protocol.CreateTopicsRequest) []string { +func topicsFromCreateTopics(req *kmsg.CreateTopicsRequest) []string { topics := make([]string, 0, len(req.Topics)) for _, topic := range req.Topics { - topics = append(topics, topic.Name) + topics = append(topics, topic.Topic) } return topics } -func topicsFromCreatePartitions(req *protocol.CreatePartitionsRequest) []string { +func topicsFromCreatePartitions(req *kmsg.CreatePartitionsRequest) []string { topics := make([]string, 0, len(req.Topics)) for _, topic := range req.Topics { - topics = append(topics, topic.Name) + topics = append(topics, topic.Topic) } return topics } -func (h *handler) unauthorizedOffsetForLeaderEpoch(principal string, header *protocol.RequestHeader, req *protocol.OffsetForLeaderEpochRequest) ([]byte, error) { +func (h *handler) unauthorizedOffsetForLeaderEpoch(principal string, header *protocol.RequestHeader, req *kmsg.OffsetForLeaderEpochRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionFetch, acl.ResourceTopic, strings.Join(topicsFromOffsetForLeaderEpoch(req), ",")) - respTopics := make([]protocol.OffsetForLeaderEpochTopicResponse, 0, len(req.Topics)) + respTopics := make([]kmsg.OffsetForLeaderEpochResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - partitions := make([]protocol.OffsetForLeaderEpochPartitionResponse, 0, len(topic.Partitions)) + partitions := make([]kmsg.OffsetForLeaderEpochResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - LeaderEpoch: -1, - EndOffset: -1, - }) + p := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + partitions = append(partitions, p) } - respTopics = append(respTopics, protocol.OffsetForLeaderEpochTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + t := kmsg.NewOffsetForLeaderEpochResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + respTopics = append(respTopics, t) } - return protocol.EncodeOffsetForLeaderEpochResponse(&protocol.OffsetForLeaderEpochResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: respTopics, - }, header.APIVersion) + resp := kmsg.NewPtrOffsetForLeaderEpochResponse() + resp.Topics = respTopics + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedAlterConfigs(principal string, header *protocol.RequestHeader, req *protocol.AlterConfigsRequest) ([]byte, error) { +func (h *handler) unauthorizedAlterConfigs(principal string, header *protocol.RequestHeader, req *kmsg.AlterConfigsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceCluster, "cluster") - resources := make([]protocol.AlterConfigsResponseResource, 0, len(req.Resources)) + resources := make([]kmsg.AlterConfigsResponseResource, 0, len(req.Resources)) for _, resource := range req.Resources { - errorCode := protocol.CLUSTER_AUTHORIZATION_FAILED - if resource.ResourceType == protocol.ConfigResourceTopic { + errorCode := int16(protocol.CLUSTER_AUTHORIZATION_FAILED) + if resource.ResourceType == kmsg.ConfigResourceTypeTopic { errorCode = protocol.TOPIC_AUTHORIZATION_FAILED } - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: errorCode, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewAlterConfigsResponseResource() + r.ErrorCode = errorCode + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) } - return protocol.EncodeAlterConfigsResponse(&protocol.AlterConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, - }, header.APIVersion) + resp := kmsg.NewPtrAlterConfigsResponse() + resp.Resources = resources + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedCreatePartitions(principal string, header *protocol.RequestHeader, req *protocol.CreatePartitionsRequest) ([]byte, error) { +func (h *handler) unauthorizedCreatePartitions(principal string, header *protocol.RequestHeader, req *kmsg.CreatePartitionsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceTopic, strings.Join(topicsFromCreatePartitions(req), ",")) - results := make([]protocol.CreatePartitionsResponseTopic, 0, len(req.Topics)) + results := make([]kmsg.CreatePartitionsResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - results = append(results, protocol.CreatePartitionsResponseTopic{ - Name: topic.Name, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - }) + t := kmsg.NewCreatePartitionsResponseTopic() + t.Topic = topic.Topic + t.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + results = append(results, t) } - return protocol.EncodeCreatePartitionsResponse(&protocol.CreatePartitionsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreatePartitionsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedDeleteGroups(principal string, header *protocol.RequestHeader, req *protocol.DeleteGroupsRequest) ([]byte, error) { +func (h *handler) unauthorizedDeleteGroups(principal string, header *protocol.RequestHeader, req *kmsg.DeleteGroupsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionGroupAdmin, acl.ResourceGroup, strings.Join(req.Groups, ",")) - results := make([]protocol.DeleteGroupsResponseGroup, 0, len(req.Groups)) + results := make([]kmsg.DeleteGroupsResponseGroup, 0, len(req.Groups)) for _, groupID := range req.Groups { - results = append(results, protocol.DeleteGroupsResponseGroup{ - Group: groupID, - ErrorCode: protocol.GROUP_AUTHORIZATION_FAILED, - }) + g := kmsg.NewDeleteGroupsResponseGroup() + g.Group = groupID + g.ErrorCode = protocol.GROUP_AUTHORIZATION_FAILED + results = append(results, g) } - return protocol.EncodeDeleteGroupsResponse(&protocol.DeleteGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteGroupsResponse() + resp.Groups = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedCreateTopics(principal string, header *protocol.RequestHeader, req *protocol.CreateTopicsRequest) ([]byte, error) { +func (h *handler) unauthorizedCreateTopics(principal string, header *protocol.RequestHeader, req *kmsg.CreateTopicsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceTopic, strings.Join(topicsFromCreateTopics(req), ",")) - results := make([]protocol.CreateTopicResult, 0, len(req.Topics)) + results := make([]kmsg.CreateTopicsResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - results = append(results, protocol.CreateTopicResult{ - Name: topic.Name, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - ErrorMessage: "unauthorized", - }) + t := kmsg.NewCreateTopicsResponseTopic() + t.Topic = topic.Topic + t.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + t.ErrorMessage = kmsg.StringPtr("unauthorized") + results = append(results, t) } - return protocol.EncodeCreateTopicsResponse(&protocol.CreateTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreateTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedDeleteTopics(principal string, header *protocol.RequestHeader, req *protocol.DeleteTopicsRequest) ([]byte, error) { +func (h *handler) unauthorizedDeleteTopics(principal string, header *protocol.RequestHeader, req *kmsg.DeleteTopicsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceTopic, strings.Join(req.TopicNames, ",")) - results := make([]protocol.DeleteTopicResult, 0, len(req.TopicNames)) + results := make([]kmsg.DeleteTopicsResponseTopic, 0, len(req.TopicNames)) for _, name := range req.TopicNames { - results = append(results, protocol.DeleteTopicResult{ - Name: name, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - ErrorMessage: "unauthorized", - }) + t := kmsg.NewDeleteTopicsResponseTopic() + t.Topic = kmsg.StringPtr(name) + t.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + t.ErrorMessage = kmsg.StringPtr("unauthorized") + results = append(results, t) } - return protocol.EncodeDeleteTopicsResponse(&protocol.DeleteTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) unauthorizedListOffsets(principal string, header *protocol.RequestHeader, req *protocol.ListOffsetsRequest) ([]byte, error) { +func (h *handler) unauthorizedListOffsets(principal string, header *protocol.RequestHeader, req *kmsg.ListOffsetsRequest) ([]byte, error) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionFetch, acl.ResourceTopic, strings.Join(topicsFromListOffsets(req), ",")) - topicResponses := make([]protocol.ListOffsetsTopicResponse, 0, len(req.Topics)) + topicResponses := make([]kmsg.ListOffsetsResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - partitions := make([]protocol.ListOffsetsPartitionResponse, 0, len(topic.Partitions)) + partitions := make([]kmsg.ListOffsetsResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { - partitions = append(partitions, protocol.ListOffsetsPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - LeaderEpoch: -1, - }) + p := kmsg.NewListOffsetsResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + partitions = append(partitions, p) } - topicResponses = append(topicResponses, protocol.ListOffsetsTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + t := kmsg.NewListOffsetsResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + topicResponses = append(topicResponses, t) } - return protocol.EncodeListOffsetsResponse(header.APIVersion, &protocol.ListOffsetsResponse{ - CorrelationID: header.CorrelationID, - Topics: topicResponses, - }) + resp := kmsg.NewPtrListOffsetsResponse() + resp.Topics = topicResponses + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } // acquireGroupLease attempts to acquire the coordination lease for the given @@ -1052,7 +912,7 @@ func (h *handler) acquireGroupLease(ctx context.Context, groupID string) int16 { // acquirePartitionLeases acquires leases for all partitions in the request // concurrently. Returns a map of partition -> error for partitions that failed. // Partitions already owned by this broker complete instantly (map lookup). -func (h *handler) acquirePartitionLeases(ctx context.Context, req *protocol.ProduceRequest) map[metadata.PartitionID]error { +func (h *handler) acquirePartitionLeases(ctx context.Context, req *kmsg.ProduceRequest) map[metadata.PartitionID]error { if h.leaseManager == nil { return nil } @@ -1060,7 +920,7 @@ func (h *handler) acquirePartitionLeases(ctx context.Context, req *protocol.Prod for _, topic := range req.Topics { for _, part := range topic.Partitions { partitions = append(partitions, metadata.PartitionID{ - Topic: topic.Name, + Topic: topic.Topic, Partition: part.Partition, }) } @@ -1081,12 +941,50 @@ func (h *handler) acquirePartitionLeases(ctx context.Context, req *protocol.Prod return errs } -func (h *handler) handleProduce(ctx context.Context, header *protocol.RequestHeader, req *protocol.ProduceRequest) ([]byte, error) { +func offsetCommitErrorTopics(req *kmsg.OffsetCommitRequest, errorCode int16) []kmsg.OffsetCommitResponseTopic { + topics := make([]kmsg.OffsetCommitResponseTopic, 0, len(req.Topics)) + for _, topic := range req.Topics { + partitions := make([]kmsg.OffsetCommitResponseTopicPartition, 0, len(topic.Partitions)) + for _, part := range topic.Partitions { + p := kmsg.NewOffsetCommitResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = errorCode + partitions = append(partitions, p) + } + t := kmsg.NewOffsetCommitResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + topics = append(topics, t) + } + return topics +} + +func offsetFetchErrorTopics(req *kmsg.OffsetFetchRequest, errorCode int16) []kmsg.OffsetFetchResponseTopic { + topics := make([]kmsg.OffsetFetchResponseTopic, 0, len(req.Topics)) + for _, topic := range req.Topics { + partitions := make([]kmsg.OffsetFetchResponseTopicPartition, 0, len(topic.Partitions)) + for _, partID := range topic.Partitions { + p := kmsg.NewOffsetFetchResponseTopicPartition() + p.Partition = partID + p.Offset = -1 + p.LeaderEpoch = -1 + p.ErrorCode = errorCode + partitions = append(partitions, p) + } + t := kmsg.NewOffsetFetchResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + topics = append(topics, t) + } + return topics +} + +func (h *handler) handleProduce(ctx context.Context, header *protocol.RequestHeader, req *kmsg.ProduceRequest) ([]byte, error) { start := time.Now() defer func() { h.recordProduceLatency(time.Since(start)) }() - topicResponses := make([]protocol.ProduceTopicResponse, 0, len(req.Topics)) + topicResponses := make([]kmsg.ProduceResponseTopic, 0, len(req.Topics)) now := time.Now().UnixMilli() var producedMessages int64 principal := principalFromContext(ctx, header) @@ -1095,119 +993,119 @@ func (h *handler) handleProduce(ctx context.Context, header *protocol.RequestHea for _, topic := range req.Topics { if h.traceKafka { - h.logger.Debug("produce request received", "topic", topic.Name, "partitions", len(topic.Partitions), "acks", req.Acks, "timeout_ms", req.TimeoutMs) + h.logger.Debug("produce request received", "topic", topic.Topic, "partitions", len(topic.Partitions), "acks", req.Acks, "timeout_ms", req.TimeoutMillis) } - partitionResponses := make([]protocol.ProducePartitionResponse, 0, len(topic.Partitions)) - if !h.allowTopic(principal, topic.Name, acl.ActionProduce) { - h.recordAuthzDeniedWithPrincipal(principal, acl.ActionProduce, acl.ResourceTopic, topic.Name) + partitionResponses := make([]kmsg.ProduceResponseTopicPartition, 0, len(topic.Partitions)) + if !h.allowTopic(principal, topic.Topic, acl.ActionProduce) { + h.recordAuthzDeniedWithPrincipal(principal, acl.ActionProduce, acl.ResourceTopic, topic.Topic) for _, part := range topic.Partitions { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - }) - } - topicResponses = append(topicResponses, protocol.ProduceTopicResponse{ - Name: topic.Name, - Partitions: partitionResponses, - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + partitionResponses = append(partitionResponses, p) + } + t := kmsg.NewProduceResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitionResponses + topicResponses = append(topicResponses, t) continue } for _, part := range topic.Partitions { if !h.etcdAvailable() { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.REQUEST_TIMED_OUT + partitionResponses = append(partitionResponses, p) if h.traceKafka { - h.logger.Debug("produce rejected due to etcd availability", "topic", topic.Name, "partition", part.Partition) + h.logger.Debug("produce rejected due to etcd availability", "topic", topic.Topic, "partition", part.Partition) } continue } - if leaseErr, hasErr := leaseErrors[metadata.PartitionID{Topic: topic.Name, Partition: part.Partition}]; hasErr { + if leaseErr, hasErr := leaseErrors[metadata.PartitionID{Topic: topic.Topic, Partition: part.Partition}]; hasErr { if errors.Is(leaseErr, metadata.ErrNotOwner) || errors.Is(leaseErr, metadata.ErrShuttingDown) { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.NOT_LEADER_OR_FOLLOWER, - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.NOT_LEADER_OR_FOLLOWER + partitionResponses = append(partitionResponses, p) if h.traceKafka { - h.logger.Debug("produce rejected: not partition owner", "topic", topic.Name, "partition", part.Partition) + h.logger.Debug("produce rejected: not partition owner", "topic", topic.Topic, "partition", part.Partition) } continue } - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) - h.logger.Warn("partition lease acquire failed", "topic", topic.Name, "partition", part.Partition, "error", leaseErr) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.REQUEST_TIMED_OUT + partitionResponses = append(partitionResponses, p) + h.logger.Warn("partition lease acquire failed", "topic", topic.Topic, "partition", part.Partition, "error", leaseErr) continue } if h.s3Health.State() != broker.S3StateHealthy { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: h.backpressureErrorCode(), - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = h.backpressureErrorCode() + partitionResponses = append(partitionResponses, p) if h.traceKafka { - h.logger.Debug("produce rejected due to S3 health", "topic", topic.Name, "partition", part.Partition, "s3_state", h.s3Health.State()) + h.logger.Debug("produce rejected due to S3 health", "topic", topic.Topic, "partition", part.Partition, "s3_state", h.s3Health.State()) } continue } - plog, err := h.getPartitionLog(ctx, topic.Name, part.Partition) + plog, err := h.getPartitionLog(ctx, topic.Topic, part.Partition) if err != nil { - h.logger.Error("partition log init failed", "error", err, "topic", topic.Name, "partition", part.Partition) - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - }) + h.logger.Error("partition log init failed", "error", err, "topic", topic.Topic, "partition", part.Partition) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + partitionResponses = append(partitionResponses, p) continue } batch, err := storage.NewRecordBatchFromBytes(part.Records) if err != nil { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + partitionResponses = append(partitionResponses, p) if h.traceKafka { - h.logger.Debug("produce record batch decode failed", "topic", topic.Name, "partition", part.Partition, "error", err) + h.logger.Debug("produce record batch decode failed", "topic", topic.Topic, "partition", part.Partition, "error", err) } continue } result, err := plog.AppendBatch(ctx, batch) if err != nil { - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: h.backpressureErrorCode(), - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = h.backpressureErrorCode() + partitionResponses = append(partitionResponses, p) if h.traceKafka { - h.logger.Debug("produce append failed", "topic", topic.Name, "partition", part.Partition, "error", err) + h.logger.Debug("produce append failed", "topic", topic.Topic, "partition", part.Partition, "error", err) } continue } if req.Acks != 0 && h.flushOnAck { if err := plog.Flush(ctx); err != nil { - h.logger.Error("flush failed", "error", err, "topic", topic.Name, "partition", part.Partition) - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: h.backpressureErrorCode(), - }) + h.logger.Error("flush failed", "error", err, "topic", topic.Topic, "partition", part.Partition) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = h.backpressureErrorCode() + partitionResponses = append(partitionResponses, p) continue } } - partitionResponses = append(partitionResponses, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: 0, - BaseOffset: result.BaseOffset, - LogAppendTimeMs: now, - LogStartOffset: 0, - }) + p := kmsg.NewProduceResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = 0 + p.BaseOffset = result.BaseOffset + p.LogAppendTime = now + p.LogStartOffset = 0 + partitionResponses = append(partitionResponses, p) producedMessages += int64(batch.MessageCount) if h.traceKafka { - h.logger.Debug("produce append success", "topic", topic.Name, "partition", part.Partition, "base_offset", result.BaseOffset, "last_offset", result.LastOffset) + h.logger.Debug("produce append success", "topic", topic.Topic, "partition", part.Partition, "base_offset", result.BaseOffset, "last_offset", result.LastOffset) } } - topicResponses = append(topicResponses, protocol.ProduceTopicResponse{ - Name: topic.Name, - Partitions: partitionResponses, - }) + t := kmsg.NewProduceResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitionResponses + topicResponses = append(topicResponses, t) } if producedMessages > 0 { @@ -1218,145 +1116,134 @@ func (h *handler) handleProduce(ctx context.Context, header *protocol.RequestHea return nil, nil } - return protocol.EncodeProduceResponse(&protocol.ProduceResponse{ - CorrelationID: header.CorrelationID, - Topics: topicResponses, - ThrottleMs: 0, - }, header.APIVersion) + resp := kmsg.NewPtrProduceResponse() + resp.Topics = topicResponses + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleCreateTopics(ctx context.Context, header *protocol.RequestHeader, req *protocol.CreateTopicsRequest) ([]byte, error) { +func (h *handler) handleCreateTopics(ctx context.Context, header *protocol.RequestHeader, req *kmsg.CreateTopicsRequest) ([]byte, error) { if header.APIVersion < 0 || header.APIVersion > 2 { return nil, fmt.Errorf("create topics version %d not supported", header.APIVersion) } - results := make([]protocol.CreateTopicResult, 0, len(req.Topics)) + results := make([]kmsg.CreateTopicsResponseTopic, 0, len(req.Topics)) if !h.allowAdminAPIs { for _, topic := range req.Topics { - results = append(results, protocol.CreateTopicResult{ - Name: topic.Name, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - ErrorMessage: "admin APIs disabled", - }) + t := kmsg.NewCreateTopicsResponseTopic() + t.Topic = topic.Topic + t.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + t.ErrorMessage = kmsg.StringPtr("admin APIs disabled") + results = append(results, t) } - return protocol.EncodeCreateTopicsResponse(&protocol.CreateTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreateTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { for _, topic := range req.Topics { - results = append(results, protocol.CreateTopicResult{ - Name: topic.Name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: "etcd unavailable", - }) + t := kmsg.NewCreateTopicsResponseTopic() + t.Topic = topic.Topic + t.ErrorCode = protocol.REQUEST_TIMED_OUT + t.ErrorMessage = kmsg.StringPtr("etcd unavailable") + results = append(results, t) } - return protocol.EncodeCreateTopicsResponse(&protocol.CreateTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreateTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } for _, topic := range req.Topics { if req.ValidateOnly { err := h.validateCreateTopic(ctx, topic) - result := protocol.CreateTopicResult{Name: topic.Name} + t := kmsg.NewCreateTopicsResponseTopic() + t.Topic = topic.Topic if err != nil { switch { case errors.Is(err, metadata.ErrTopicExists): - result.ErrorCode = protocol.TOPIC_ALREADY_EXISTS + t.ErrorCode = protocol.TOPIC_ALREADY_EXISTS case errors.Is(err, metadata.ErrInvalidTopic): - result.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION + t.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION default: - result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + t.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } - result.ErrorMessage = err.Error() + t.ErrorMessage = kmsg.StringPtr(err.Error()) } - results = append(results, result) + results = append(results, t) continue } _, err := h.store.CreateTopic(ctx, metadata.TopicSpec{ - Name: topic.Name, + Name: topic.Topic, NumPartitions: topic.NumPartitions, ReplicationFactor: topic.ReplicationFactor, }) - result := protocol.CreateTopicResult{Name: topic.Name} + t := kmsg.NewCreateTopicsResponseTopic() + t.Topic = topic.Topic if err != nil { switch { case errors.Is(err, metadata.ErrTopicExists): - result.ErrorCode = protocol.TOPIC_ALREADY_EXISTS + t.ErrorCode = protocol.TOPIC_ALREADY_EXISTS case errors.Is(err, metadata.ErrInvalidTopic): - result.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION + t.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION default: - result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + t.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } - result.ErrorMessage = err.Error() + t.ErrorMessage = kmsg.StringPtr(err.Error()) } - results = append(results, result) + results = append(results, t) } - return protocol.EncodeCreateTopicsResponse(&protocol.CreateTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreateTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleDeleteTopics(ctx context.Context, header *protocol.RequestHeader, req *protocol.DeleteTopicsRequest) ([]byte, error) { +func (h *handler) handleDeleteTopics(ctx context.Context, header *protocol.RequestHeader, req *kmsg.DeleteTopicsRequest) ([]byte, error) { if header.APIVersion < 0 || header.APIVersion > 2 { return nil, fmt.Errorf("delete topics version %d not supported", header.APIVersion) } - results := make([]protocol.DeleteTopicResult, 0, len(req.TopicNames)) + results := make([]kmsg.DeleteTopicsResponseTopic, 0, len(req.TopicNames)) if !h.allowAdminAPIs { for _, name := range req.TopicNames { - results = append(results, protocol.DeleteTopicResult{ - Name: name, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - ErrorMessage: "admin APIs disabled", - }) + t := kmsg.NewDeleteTopicsResponseTopic() + t.Topic = kmsg.StringPtr(name) + t.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + t.ErrorMessage = kmsg.StringPtr("admin APIs disabled") + results = append(results, t) } - return protocol.EncodeDeleteTopicsResponse(&protocol.DeleteTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } if !h.etcdAvailable() { for _, name := range req.TopicNames { - results = append(results, protocol.DeleteTopicResult{ - Name: name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: "etcd unavailable", - }) + t := kmsg.NewDeleteTopicsResponseTopic() + t.Topic = kmsg.StringPtr(name) + t.ErrorCode = protocol.REQUEST_TIMED_OUT + t.ErrorMessage = kmsg.StringPtr("etcd unavailable") + results = append(results, t) } - return protocol.EncodeDeleteTopicsResponse(&protocol.DeleteTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } for _, name := range req.TopicNames { - result := protocol.DeleteTopicResult{Name: name} + t := kmsg.NewDeleteTopicsResponseTopic() + t.Topic = kmsg.StringPtr(name) if err := h.store.DeleteTopic(ctx, name); err != nil { switch { case errors.Is(err, metadata.ErrUnknownTopic): - result.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + t.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION default: - result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + t.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } - result.ErrorMessage = err.Error() + t.ErrorMessage = kmsg.StringPtr(err.Error()) } - results = append(results, result) + results = append(results, t) } - return protocol.EncodeDeleteTopicsResponse(&protocol.DeleteTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrDeleteTopicsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) validateCreateTopic(ctx context.Context, topic protocol.CreateTopicConfig) error { - if topic.Name == "" || topic.NumPartitions <= 0 { +func (h *handler) validateCreateTopic(ctx context.Context, topic kmsg.CreateTopicsRequestTopic) error { + if topic.Topic == "" || topic.NumPartitions <= 0 { return metadata.ErrInvalidTopic } replicationFactor := topic.ReplicationFactor @@ -1368,7 +1255,7 @@ func (h *handler) validateCreateTopic(ctx context.Context, topic protocol.Create return err } for _, existing := range meta.Topics { - if existing.Name == topic.Name { + if *existing.Topic == topic.Topic { return metadata.ErrTopicExists } } @@ -1400,10 +1287,10 @@ const ( configFlushInterval = "kafscale.flush.interval.ms" ) -func (h *handler) handleOffsetForLeaderEpoch(ctx context.Context, header *protocol.RequestHeader, req *protocol.OffsetForLeaderEpochRequest) ([]byte, error) { +func (h *handler) handleOffsetForLeaderEpoch(ctx context.Context, header *protocol.RequestHeader, req *kmsg.OffsetForLeaderEpochRequest) ([]byte, error) { topics := make([]string, 0, len(req.Topics)) for _, topic := range req.Topics { - topics = append(topics, topic.Name) + topics = append(topics, topic.Topic) } meta, err := h.store.Metadata(ctx, topics) if err != nil { @@ -1411,178 +1298,166 @@ func (h *handler) handleOffsetForLeaderEpoch(ctx context.Context, header *protoc } metaIndex := make(map[string]protocol.MetadataTopic, len(meta.Topics)) for _, topic := range meta.Topics { - metaIndex[topic.Name] = topic + metaIndex[*topic.Topic] = topic } - respTopics := make([]protocol.OffsetForLeaderEpochTopicResponse, 0, len(req.Topics)) + respTopics := make([]kmsg.OffsetForLeaderEpochResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - metaTopic, ok := metaIndex[topic.Name] + metaTopic, ok := metaIndex[topic.Topic] partIndex := make(map[int32]protocol.MetadataPartition, len(metaTopic.Partitions)) if ok { for _, part := range metaTopic.Partitions { - partIndex[part.PartitionIndex] = part + partIndex[part.Partition] = part } } - partitions := make([]protocol.OffsetForLeaderEpochPartitionResponse, 0, len(topic.Partitions)) + partitions := make([]kmsg.OffsetForLeaderEpochResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { if !ok { - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, - LeaderEpoch: -1, - EndOffset: -1, - }) + p := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + partitions = append(partitions, p) continue } metaPart, exists := partIndex[part.Partition] if !exists { - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, - LeaderEpoch: -1, - EndOffset: -1, - }) + p := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + partitions = append(partitions, p) continue } - nextOffset, err := h.store.NextOffset(ctx, topic.Name, part.Partition) + nextOffset, err := h.store.NextOffset(ctx, topic.Topic, part.Partition) if err != nil { - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - LeaderEpoch: -1, - EndOffset: -1, - }) + p := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + partitions = append(partitions, p) continue } - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.NONE, - LeaderEpoch: metaPart.LeaderEpoch, - EndOffset: nextOffset, - }) + p := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.NONE + p.LeaderEpoch = metaPart.LeaderEpoch + p.EndOffset = nextOffset + partitions = append(partitions, p) } - respTopics = append(respTopics, protocol.OffsetForLeaderEpochTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + t := kmsg.NewOffsetForLeaderEpochResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + respTopics = append(respTopics, t) } - return protocol.EncodeOffsetForLeaderEpochResponse(&protocol.OffsetForLeaderEpochResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: respTopics, - }, header.APIVersion) + resp := kmsg.NewPtrOffsetForLeaderEpochResponse() + resp.Topics = respTopics + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleDescribeConfigs(ctx context.Context, header *protocol.RequestHeader, req *protocol.DescribeConfigsRequest) ([]byte, error) { - resources := make([]protocol.DescribeConfigsResponseResource, 0, len(req.Resources)) +func (h *handler) handleDescribeConfigs(ctx context.Context, header *protocol.RequestHeader, req *kmsg.DescribeConfigsRequest) ([]byte, error) { + resources := make([]kmsg.DescribeConfigsResponseResource, 0, len(req.Resources)) principal := principalFromContext(ctx, header) for _, resource := range req.Resources { switch resource.ResourceType { - case protocol.ConfigResourceTopic: + case kmsg.ConfigResourceTypeTopic: if !h.allowTopic(principal, resource.ResourceName, acl.ActionFetch) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionFetch, acl.ResourceTopic, resource.ResourceName) - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } cfg, err := h.store.FetchTopicConfig(ctx, resource.ResourceName) if err != nil { - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } configs := h.topicConfigEntries(cfg, resource.ConfigNames) - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.NONE, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - Configs: configs, - }) - case protocol.ConfigResourceBroker: + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.NONE + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + r.Configs = configs + resources = append(resources, r) + case kmsg.ConfigResourceTypeBroker: if !h.allowAdmin(principal) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceCluster, "cluster") - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.CLUSTER_AUTHORIZATION_FAILED, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.CLUSTER_AUTHORIZATION_FAILED + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } configs := h.brokerConfigEntries(resource.ConfigNames) - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.NONE, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - Configs: configs, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.NONE + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + r.Configs = configs + resources = append(resources, r) default: if !h.allowAdmin(principal) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionAdmin, acl.ResourceCluster, "cluster") - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.CLUSTER_AUTHORIZATION_FAILED, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.CLUSTER_AUTHORIZATION_FAILED + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.INVALID_REQUEST, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewDescribeConfigsResponseResource() + r.ErrorCode = protocol.INVALID_REQUEST + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) } } - return protocol.EncodeDescribeConfigsResponse(&protocol.DescribeConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, - }, header.APIVersion) + resp := kmsg.NewPtrDescribeConfigsResponse() + resp.Resources = resources + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleAlterConfigs(ctx context.Context, header *protocol.RequestHeader, req *protocol.AlterConfigsRequest) ([]byte, error) { - resources := make([]protocol.AlterConfigsResponseResource, 0, len(req.Resources)) +func (h *handler) handleAlterConfigs(ctx context.Context, header *protocol.RequestHeader, req *kmsg.AlterConfigsRequest) ([]byte, error) { + resources := make([]kmsg.AlterConfigsResponseResource, 0, len(req.Resources)) if !h.etcdAvailable() { for _, resource := range req.Resources { - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewAlterConfigsResponseResource() + r.ErrorCode = protocol.REQUEST_TIMED_OUT + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) } - return protocol.EncodeAlterConfigsResponse(&protocol.AlterConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, - }, header.APIVersion) + resp := kmsg.NewPtrAlterConfigsResponse() + resp.Resources = resources + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } for _, resource := range req.Resources { - if resource.ResourceType != protocol.ConfigResourceTopic || resource.ResourceName == "" { - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: protocol.INVALID_REQUEST, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + if resource.ResourceType != kmsg.ConfigResourceTypeTopic || resource.ResourceName == "" { + r := kmsg.NewAlterConfigsResponseResource() + r.ErrorCode = protocol.INVALID_REQUEST + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } cfg, err := h.store.FetchTopicConfig(ctx, resource.ResourceName) if err != nil { - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewAlterConfigsResponseResource() + r.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) continue } updated := proto.Clone(cfg).(*metadatapb.TopicConfig) if updated.Config == nil { updated.Config = make(map[string]string) } - errorCode := protocol.NONE + errorCode := int16(protocol.NONE) for _, entry := range resource.Configs { if entry.Value == nil { errorCode = protocol.INVALID_CONFIG @@ -1622,93 +1497,82 @@ func (h *handler) handleAlterConfigs(ctx context.Context, header *protocol.Reque errorCode = protocol.UNKNOWN_SERVER_ERROR } } - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: errorCode, - ResourceType: resource.ResourceType, - ResourceName: resource.ResourceName, - }) + r := kmsg.NewAlterConfigsResponseResource() + r.ErrorCode = errorCode + r.ResourceType = resource.ResourceType + r.ResourceName = resource.ResourceName + resources = append(resources, r) } - return protocol.EncodeAlterConfigsResponse(&protocol.AlterConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, - }, header.APIVersion) + resp := kmsg.NewPtrAlterConfigsResponse() + resp.Resources = resources + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleCreatePartitions(ctx context.Context, header *protocol.RequestHeader, req *protocol.CreatePartitionsRequest) ([]byte, error) { - results := make([]protocol.CreatePartitionsResponseTopic, 0, len(req.Topics)) +func (h *handler) handleCreatePartitions(ctx context.Context, header *protocol.RequestHeader, req *kmsg.CreatePartitionsRequest) ([]byte, error) { + results := make([]kmsg.CreatePartitionsResponseTopic, 0, len(req.Topics)) seen := make(map[string]struct{}, len(req.Topics)) if !h.etcdAvailable() { for _, topic := range req.Topics { - msg := "etcd unavailable" - results = append(results, protocol.CreatePartitionsResponseTopic{ - Name: topic.Name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: &msg, - }) + t := kmsg.NewCreatePartitionsResponseTopic() + t.Topic = topic.Topic + t.ErrorCode = protocol.REQUEST_TIMED_OUT + t.ErrorMessage = kmsg.StringPtr("etcd unavailable") + results = append(results, t) } - return protocol.EncodeCreatePartitionsResponse(&protocol.CreatePartitionsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreatePartitionsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } for _, topic := range req.Topics { - result := protocol.CreatePartitionsResponseTopic{Name: topic.Name} - if strings.TrimSpace(topic.Name) == "" { - result.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION - msg := "invalid topic name" - result.ErrorMessage = &msg - results = append(results, result) + t := kmsg.NewCreatePartitionsResponseTopic() + t.Topic = topic.Topic + if strings.TrimSpace(topic.Topic) == "" { + t.ErrorCode = protocol.INVALID_TOPIC_EXCEPTION + t.ErrorMessage = kmsg.StringPtr("invalid topic name") + results = append(results, t) continue } - if _, ok := seen[topic.Name]; ok { - result.ErrorCode = protocol.INVALID_REQUEST - msg := "duplicate topic in request" - result.ErrorMessage = &msg - results = append(results, result) + if _, ok := seen[topic.Topic]; ok { + t.ErrorCode = protocol.INVALID_REQUEST + t.ErrorMessage = kmsg.StringPtr("duplicate topic in request") + results = append(results, t) continue } - seen[topic.Name] = struct{}{} + seen[topic.Topic] = struct{}{} if topic.Count <= 0 { - result.ErrorCode = protocol.INVALID_PARTITIONS - msg := "invalid partition count" - result.ErrorMessage = &msg - results = append(results, result) + t.ErrorCode = protocol.INVALID_PARTITIONS + t.ErrorMessage = kmsg.StringPtr("invalid partition count") + results = append(results, t) continue } - if len(topic.Assignments) > 0 { - result.ErrorCode = protocol.INVALID_REQUEST - msg := "replica assignment not supported" - result.ErrorMessage = &msg - results = append(results, result) + if len(topic.Assignment) > 0 { + t.ErrorCode = protocol.INVALID_REQUEST + t.ErrorMessage = kmsg.StringPtr("replica assignment not supported") + results = append(results, t) continue } var err error if req.ValidateOnly { - err = h.validateCreatePartitions(ctx, topic.Name, topic.Count) + err = h.validateCreatePartitions(ctx, topic.Topic, topic.Count) } else { - err = h.store.CreatePartitions(ctx, topic.Name, topic.Count) + err = h.store.CreatePartitions(ctx, topic.Topic, topic.Count) } if err != nil { switch { case errors.Is(err, metadata.ErrUnknownTopic): - result.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + t.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION case errors.Is(err, metadata.ErrInvalidTopic): - result.ErrorCode = protocol.INVALID_PARTITIONS + t.ErrorCode = protocol.INVALID_PARTITIONS default: - result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + t.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } - msg := err.Error() - result.ErrorMessage = &msg + t.ErrorMessage = kmsg.StringPtr(err.Error()) } - results = append(results, result) + results = append(results, t) } - return protocol.EncodeCreatePartitionsResponse(&protocol.CreatePartitionsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: results, - }, header.APIVersion) + resp := kmsg.NewPtrCreatePartitionsResponse() + resp.Topics = results + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } func (h *handler) validateCreatePartitions(ctx context.Context, topic string, count int32) error { @@ -1726,31 +1590,31 @@ func (h *handler) validateCreatePartitions(ctx context.Context, topic string, co return nil } -func (h *handler) topicConfigEntries(cfg *metadatapb.TopicConfig, requested []string) []protocol.DescribeConfigsResponseConfig { +func (h *handler) topicConfigEntries(cfg *metadatapb.TopicConfig, requested []string) []kmsg.DescribeConfigsResponseResourceConfig { allow := configNameSet(requested) - entries := make([]protocol.DescribeConfigsResponseConfig, 0, 3) + entries := make([]kmsg.DescribeConfigsResponseResourceConfig, 0, 3) retentionMs, retentionMsDefault := normalizeRetention(cfg.RetentionMs) retentionBytes, retentionBytesDefault := normalizeRetention(cfg.RetentionBytes) segmentBytes, segmentDefault := normalizeSegmentBytes(cfg.SegmentBytes, int64(h.segmentBytes)) - entries = appendConfigEntry(entries, allow, configRetentionMs, retentionMs, retentionMsDefault, protocol.ConfigTypeLong, false) - entries = appendConfigEntry(entries, allow, configRetentionBytes, retentionBytes, retentionBytesDefault, protocol.ConfigTypeLong, false) - entries = appendConfigEntry(entries, allow, configSegmentBytes, segmentBytes, segmentDefault, protocol.ConfigTypeInt, false) + entries = appendConfigEntry(entries, allow, configRetentionMs, retentionMs, retentionMsDefault, kmsg.ConfigTypeLong, false) + entries = appendConfigEntry(entries, allow, configRetentionBytes, retentionBytes, retentionBytesDefault, kmsg.ConfigTypeLong, false) + entries = appendConfigEntry(entries, allow, configSegmentBytes, segmentBytes, segmentDefault, kmsg.ConfigTypeInt, false) return entries } -func (h *handler) brokerConfigEntries(requested []string) []protocol.DescribeConfigsResponseConfig { +func (h *handler) brokerConfigEntries(requested []string) []kmsg.DescribeConfigsResponseResourceConfig { allow := configNameSet(requested) - entries := make([]protocol.DescribeConfigsResponseConfig, 0, 8) - entries = appendConfigEntry(entries, allow, configBrokerID, fmt.Sprintf("%d", h.brokerInfo.NodeID), true, protocol.ConfigTypeInt, true) - entries = appendConfigEntry(entries, allow, configAdvertised, fmt.Sprintf("%s:%d", h.brokerInfo.Host, h.brokerInfo.Port), true, protocol.ConfigTypeString, true) - entries = appendConfigEntry(entries, allow, configS3Bucket, os.Getenv("KAFSCALE_S3_BUCKET"), true, protocol.ConfigTypeString, true) - entries = appendConfigEntry(entries, allow, configS3Region, os.Getenv("KAFSCALE_S3_REGION"), true, protocol.ConfigTypeString, true) - entries = appendConfigEntry(entries, allow, configS3Endpoint, os.Getenv("KAFSCALE_S3_ENDPOINT"), true, protocol.ConfigTypeString, true) - entries = appendConfigEntry(entries, allow, configCacheBytes, fmt.Sprintf("%d", h.cacheSize), true, protocol.ConfigTypeLong, true) - entries = appendConfigEntry(entries, allow, configReadAhead, fmt.Sprintf("%d", h.readAhead), true, protocol.ConfigTypeInt, true) - entries = appendConfigEntry(entries, allow, configSegmentBytesB, fmt.Sprintf("%d", h.segmentBytes), true, protocol.ConfigTypeInt, true) - entries = appendConfigEntry(entries, allow, configFlushInterval, fmt.Sprintf("%d", int64(h.flushInterval/time.Millisecond)), true, protocol.ConfigTypeLong, true) + entries := make([]kmsg.DescribeConfigsResponseResourceConfig, 0, 8) + entries = appendConfigEntry(entries, allow, configBrokerID, fmt.Sprintf("%d", h.brokerInfo.NodeID), true, kmsg.ConfigTypeInt, true) + entries = appendConfigEntry(entries, allow, configAdvertised, fmt.Sprintf("%s:%d", h.brokerInfo.Host, h.brokerInfo.Port), true, kmsg.ConfigTypeString, true) + entries = appendConfigEntry(entries, allow, configS3Bucket, os.Getenv("KAFSCALE_S3_BUCKET"), true, kmsg.ConfigTypeString, true) + entries = appendConfigEntry(entries, allow, configS3Region, os.Getenv("KAFSCALE_S3_REGION"), true, kmsg.ConfigTypeString, true) + entries = appendConfigEntry(entries, allow, configS3Endpoint, os.Getenv("KAFSCALE_S3_ENDPOINT"), true, kmsg.ConfigTypeString, true) + entries = appendConfigEntry(entries, allow, configCacheBytes, fmt.Sprintf("%d", h.cacheSize), true, kmsg.ConfigTypeLong, true) + entries = appendConfigEntry(entries, allow, configReadAhead, fmt.Sprintf("%d", h.readAhead), true, kmsg.ConfigTypeInt, true) + entries = appendConfigEntry(entries, allow, configSegmentBytesB, fmt.Sprintf("%d", h.segmentBytes), true, kmsg.ConfigTypeInt, true) + entries = appendConfigEntry(entries, allow, configFlushInterval, fmt.Sprintf("%d", int64(h.flushInterval/time.Millisecond)), true, kmsg.ConfigTypeLong, true) return entries } @@ -1765,34 +1629,32 @@ func configNameSet(names []string) map[string]struct{} { return set } -func appendConfigEntry(entries []protocol.DescribeConfigsResponseConfig, allow map[string]struct{}, name string, value string, isDefault bool, configType int8, readOnly bool) []protocol.DescribeConfigsResponseConfig { +func appendConfigEntry(entries []kmsg.DescribeConfigsResponseResourceConfig, allow map[string]struct{}, name string, value string, isDefault bool, configType kmsg.ConfigType, readOnly bool) []kmsg.DescribeConfigsResponseResourceConfig { if allow != nil { if _, ok := allow[name]; !ok { return entries } } - val := value - entries = append(entries, protocol.DescribeConfigsResponseConfig{ - Name: name, - Value: &val, - ReadOnly: readOnly, - IsDefault: isDefault, - Source: chooseConfigSource(isDefault, readOnly), - IsSensitive: false, - Synonyms: nil, - ConfigType: configType, - }) + c := kmsg.NewDescribeConfigsResponseResourceConfig() + c.Name = name + c.Value = kmsg.StringPtr(value) + c.ReadOnly = readOnly + c.IsDefault = isDefault + c.Source = chooseConfigSource(isDefault, readOnly) + c.IsSensitive = false + c.ConfigType = configType + entries = append(entries, c) return entries } -func chooseConfigSource(isDefault bool, readOnly bool) int8 { +func chooseConfigSource(isDefault bool, readOnly bool) kmsg.ConfigSource { if isDefault { - return protocol.ConfigSourceDefaultConfig + return kmsg.ConfigSourceDefaultConfig } if readOnly { - return protocol.ConfigSourceStaticBroker + return kmsg.ConfigSourceStaticBrokerConfig } - return protocol.ConfigSourceDynamicTopic + return kmsg.ConfigSourceDynamicTopicConfig } func normalizeRetention(value int64) (string, bool) { @@ -1814,67 +1676,66 @@ func parseConfigInt64(value string) (int64, error) { return strconv.ParseInt(trimmed, 10, 64) } -func (h *handler) handleListOffsets(ctx context.Context, header *protocol.RequestHeader, req *protocol.ListOffsetsRequest) ([]byte, error) { +func (h *handler) handleListOffsets(ctx context.Context, header *protocol.RequestHeader, req *kmsg.ListOffsetsRequest) ([]byte, error) { if header.APIVersion < 0 || header.APIVersion > 4 { return nil, fmt.Errorf("list offsets version %d not supported", header.APIVersion) } if h.traceKafka { h.logger.Debug("list offsets request", "api_version", header.APIVersion, "topics", len(req.Topics)) } - topicResponses := make([]protocol.ListOffsetsTopicResponse, 0, len(req.Topics)) + topicResponses := make([]kmsg.ListOffsetsResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - partitions := make([]protocol.ListOffsetsPartitionResponse, 0, len(topic.Partitions)) + partitions := make([]kmsg.ListOffsetsResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { - h.logger.Warn("list offsets partition", "topic", topic.Name, "partition", part.Partition, "timestamp", part.Timestamp, "max_offsets", part.MaxNumOffsets, "leader_epoch", part.CurrentLeaderEpoch) - resp := protocol.ListOffsetsPartitionResponse{ - Partition: part.Partition, - LeaderEpoch: -1, + if h.traceKafka { + h.logger.Debug("list offsets partition", "topic", topic.Topic, "partition", part.Partition, "timestamp", part.Timestamp, "max_offsets", part.MaxNumOffsets, "leader_epoch", part.CurrentLeaderEpoch) } + p := kmsg.NewListOffsetsResponseTopicPartition() + p.Partition = part.Partition offset, err := func() (int64, error) { switch part.Timestamp { case -2: - plog, err := h.getPartitionLog(ctx, topic.Name, part.Partition) + plog, err := h.getPartitionLog(ctx, topic.Topic, part.Partition) if err != nil { return 0, err } return plog.EarliestOffset(), nil default: - return h.store.NextOffset(ctx, topic.Name, part.Partition) + return h.store.NextOffset(ctx, topic.Topic, part.Partition) } }() if err != nil { - resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + p.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } else { - resp.Timestamp = part.Timestamp - resp.Offset = offset + p.Timestamp = part.Timestamp + p.Offset = offset if header.APIVersion == 0 { max := part.MaxNumOffsets if max <= 0 { max = 1 } - resp.OldStyleOffsets = make([]int64, 0, max) - resp.OldStyleOffsets = append(resp.OldStyleOffsets, offset) + p.OldStyleOffsets = make([]int64, 0, max) + p.OldStyleOffsets = append(p.OldStyleOffsets, offset) } } - partitions = append(partitions, resp) + partitions = append(partitions, p) } - topicResponses = append(topicResponses, protocol.ListOffsetsTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + t := kmsg.NewListOffsetsResponseTopic() + t.Topic = topic.Topic + t.Partitions = partitions + topicResponses = append(topicResponses, t) } - return protocol.EncodeListOffsetsResponse(header.APIVersion, &protocol.ListOffsetsResponse{ - CorrelationID: header.CorrelationID, - Topics: topicResponses, - }) + resp := kmsg.NewPtrListOffsetsResponse() + resp.Topics = topicResponses + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (h *handler) handleFetch(ctx context.Context, header *protocol.RequestHeader, req *protocol.FetchRequest) ([]byte, error) { +func (h *handler) handleFetch(ctx context.Context, header *protocol.RequestHeader, req *kmsg.FetchRequest) ([]byte, error) { if header.APIVersion < 11 || header.APIVersion > 13 { return nil, fmt.Errorf("fetch version %d not supported", header.APIVersion) } - topicResponses := make([]protocol.FetchTopicResponse, 0, len(req.Topics)) - maxWait := time.Duration(req.MaxWaitMs) * time.Millisecond + topicResponses := make([]kmsg.FetchResponseTopic, 0, len(req.Topics)) + maxWait := time.Duration(req.MaxWaitMillis) * time.Millisecond if maxWait < 0 { maxWait = 0 } @@ -1889,86 +1750,85 @@ func (h *handler) handleFetch(ctx context.Context, header *protocol.RequestHeade return nil, fmt.Errorf("load metadata: %w", err) } for _, t := range meta.Topics { - idToName[t.TopicID] = t.Name + idToName[t.TopicID] = *t.Topic } break } } for _, topic := range req.Topics { - topicName := topic.Name + topicName := topic.Topic if topicName == "" && topic.TopicID != zeroID { if resolved, ok := idToName[topic.TopicID]; ok { topicName = resolved } else { - partitionResponses := make([]protocol.FetchPartitionResponse, 0, len(topic.Partitions)) + partitionResponses := make([]kmsg.FetchResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_TOPIC_ID, - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_TOPIC_ID + partitionResponses = append(partitionResponses, p) } - topicResponses = append(topicResponses, protocol.FetchTopicResponse{ - Name: topicName, - TopicID: topic.TopicID, - Partitions: partitionResponses, - }) + t := kmsg.NewFetchResponseTopic() + t.Topic = topicName + t.TopicID = topic.TopicID + t.Partitions = partitionResponses + topicResponses = append(topicResponses, t) continue } } if !h.allowTopic(principal, topicName, acl.ActionFetch) { h.recordAuthzDeniedWithPrincipal(principal, acl.ActionFetch, acl.ResourceTopic, topicName) - partitionResponses := make([]protocol.FetchPartitionResponse, 0, len(topic.Partitions)) + partitionResponses := make([]kmsg.FetchResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.TOPIC_AUTHORIZATION_FAILED, - }) - } - topicResponses = append(topicResponses, protocol.FetchTopicResponse{ - Name: topicName, - TopicID: topic.TopicID, - Partitions: partitionResponses, - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.TOPIC_AUTHORIZATION_FAILED + partitionResponses = append(partitionResponses, p) + } + t := kmsg.NewFetchResponseTopic() + t.Topic = topicName + t.TopicID = topic.TopicID + t.Partitions = partitionResponses + topicResponses = append(topicResponses, t) continue } - partitionResponses := make([]protocol.FetchPartitionResponse, 0, len(topic.Partitions)) + partitionResponses := make([]kmsg.FetchResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { if h.traceKafka { - h.logger.Debug("fetch partition request", "topic", topicName, "partition", part.Partition, "fetch_offset", part.FetchOffset, "max_bytes", part.MaxBytes) + h.logger.Debug("fetch partition request", "topic", topicName, "partition", part.Partition, "fetch_offset", part.FetchOffset, "max_bytes", part.PartitionMaxBytes) } switch h.s3Health.State() { case broker.S3StateDegraded, broker.S3StateUnavailable: - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: h.backpressureErrorCode(), - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = h.backpressureErrorCode() + partitionResponses = append(partitionResponses, p) continue } plog, err := h.getPartitionLog(ctx, topicName, part.Partition) if err != nil { h.logger.Error("fetch partition log failed", "topic", topicName, "partition", part.Partition, "error", err, "etcd_available", h.etcdAvailable(), "s3_state", h.s3Health.State()) - errorCode := protocol.UNKNOWN_SERVER_ERROR + errorCode := int16(protocol.UNKNOWN_SERVER_ERROR) if !h.etcdAvailable() { errorCode = protocol.REQUEST_TIMED_OUT } - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: errorCode, - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = errorCode + partitionResponses = append(partitionResponses, p) continue } nextOffset, offsetErr := h.waitForFetchData(ctx, topicName, part.Partition, part.FetchOffset, maxWait) if offsetErr != nil { if errors.Is(offsetErr, metadata.ErrUnknownTopic) { - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = protocol.UNKNOWN_TOPIC_OR_PARTITION + partitionResponses = append(partitionResponses, p) continue } if errors.Is(offsetErr, context.Canceled) || errors.Is(offsetErr, context.DeadlineExceeded) { - // Treat request timeouts as empty fetches, not server errors. offsetErr = nil } else if !h.etcdAvailable() { nextOffset = part.FetchOffset @@ -1984,10 +1844,9 @@ func (h *handler) handleFetch(ctx context.Context, header *protocol.RequestHeade case part.FetchOffset > nextOffset: errorCode = protocol.OFFSET_OUT_OF_RANGE case part.FetchOffset == nextOffset: - // At the high watermark; Kafka returns an empty set rather than an error. recordSet = nil default: - recordSet, err = plog.Read(ctx, part.FetchOffset, part.MaxBytes) + recordSet, err = plog.Read(ctx, part.FetchOffset, part.PartitionMaxBytes) if err != nil { if errors.Is(err, storage.ErrOffsetOutOfRange) { errorCode = protocol.OFFSET_OUT_OF_RANGE @@ -2011,34 +1870,30 @@ func (h *handler) handleFetch(ctx context.Context, header *protocol.RequestHeade } else if h.traceKafka { h.logger.Debug("fetch partition error", "topic", topicName, "partition", part.Partition, "error_code", errorCode) } - partitionResponses = append(partitionResponses, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: errorCode, - HighWatermark: highWatermark, - LastStableOffset: highWatermark, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: recordSet, - }) + p := kmsg.NewFetchResponseTopicPartition() + p.Partition = part.Partition + p.ErrorCode = errorCode + p.HighWatermark = highWatermark + p.LastStableOffset = highWatermark + p.LogStartOffset = 0 + p.PreferredReadReplica = -1 + p.RecordBatches = recordSet + partitionResponses = append(partitionResponses, p) } - topicResponses = append(topicResponses, protocol.FetchTopicResponse{ - Name: topicName, - TopicID: topic.TopicID, - Partitions: partitionResponses, - }) + t := kmsg.NewFetchResponseTopic() + t.Topic = topicName + t.TopicID = topic.TopicID + t.Partitions = partitionResponses + topicResponses = append(topicResponses, t) } if fetchedMessages > 0 { h.fetchRate.add(fetchedMessages) } - return protocol.EncodeFetchResponse(&protocol.FetchResponse{ - CorrelationID: header.CorrelationID, - Topics: topicResponses, - ThrottleMs: 0, - ErrorCode: 0, - SessionID: 0, - }, header.APIVersion) + resp := kmsg.NewPtrFetchResponse() + resp.Topics = topicResponses + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } const fetchPollInterval = 10 * time.Millisecond @@ -2567,27 +2422,27 @@ func startupTimeoutFromEnv() time.Duration { return time.Duration(parseEnvInt("KAFSCALE_STARTUP_TIMEOUT_SEC", 30)) * time.Second } -func metadataForBroker(broker protocol.MetadataBroker) metadata.ClusterMetadata { +func metadataForBroker(b protocol.MetadataBroker) metadata.ClusterMetadata { clusterID := "kafscale-cluster" return metadata.ClusterMetadata{ - ControllerID: broker.NodeID, + ControllerID: b.NodeID, ClusterID: &clusterID, Brokers: []protocol.MetadataBroker{ - broker, + b, }, Topics: []protocol.MetadataTopic{ { ErrorCode: 0, - Name: "orders", + Topic: kmsg.StringPtr("orders"), TopicID: metadata.TopicIDForName("orders"), IsInternal: false, Partitions: []protocol.MetadataPartition{ { - ErrorCode: 0, - PartitionIndex: 0, - LeaderID: broker.NodeID, - ReplicaNodes: []int32{broker.NodeID}, - ISRNodes: []int32{broker.NodeID}, + ErrorCode: 0, + Partition: 0, + Leader: b.NodeID, + Replicas: []int32{b.NodeID}, + ISR: []int32{b.NodeID}, }, }, }, @@ -2841,7 +2696,7 @@ type apiVersionSupport struct { minVersion, maxVersion int16 } -func generateApiVersions() []protocol.ApiVersion { +func generateApiVersions() []kmsg.ApiVersionsResponseApiKey { supported := []apiVersionSupport{ {key: protocol.APIKeyApiVersion, minVersion: 0, maxVersion: 4}, {key: protocol.APIKeyMetadata, minVersion: 0, maxVersion: 12}, @@ -2871,17 +2726,17 @@ func generateApiVersions() []protocol.ApiVersion { 24, 25, 26, } - entries := make([]protocol.ApiVersion, 0, len(supported)+len(unsupported)) + entries := make([]kmsg.ApiVersionsResponseApiKey, 0, len(supported)+len(unsupported)) for _, entry := range supported { - entries = append(entries, protocol.ApiVersion{ - APIKey: entry.key, + entries = append(entries, kmsg.ApiVersionsResponseApiKey{ + ApiKey: entry.key, MinVersion: entry.minVersion, MaxVersion: entry.maxVersion, }) } for _, key := range unsupported { - entries = append(entries, protocol.ApiVersion{ - APIKey: key, + entries = append(entries, kmsg.ApiVersionsResponseApiKey{ + ApiKey: key, MinVersion: -1, MaxVersion: -1, }) diff --git a/cmd/broker/main_test.go b/cmd/broker/main_test.go index 224fce80..fd87ee84 100644 --- a/cmd/broker/main_test.go +++ b/cmd/broker/main_test.go @@ -50,13 +50,13 @@ func TestHandleProduceAckAll(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -87,13 +87,13 @@ func TestHandleProduceEtcdUnavailable(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(unavailableMetadataStore{Store: store}) - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -124,13 +124,13 @@ func TestHandleProduceAckZero(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - req := &protocol.ProduceRequest{ - Acks: 0, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: 0, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -158,7 +158,7 @@ func TestHandlerApiVersionsUnsupported(t *testing.T) { APIVersion: 5, CorrelationID: 42, } - payload, err := handler.Handle(context.Background(), header, &protocol.ApiVersionsRequest{}) + payload, err := handler.Handle(context.Background(), header, kmsg.NewPtrApiVersionsRequest()) if err != nil { t.Fatalf("handler returned error: %v", err) } @@ -186,13 +186,13 @@ func TestHandleFetch(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - produceReq := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + produceReq := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -205,15 +205,15 @@ func TestHandleFetch(t *testing.T) { t.Fatalf("handleProduce: %v", err) } - fetchReq := &protocol.FetchRequest{ - Topics: []protocol.FetchTopicRequest{ + fetchReq := &kmsg.FetchRequest{ + Topics: []kmsg.FetchRequestTopic{ { - Name: "orders", - Partitions: []protocol.FetchPartitionRequest{ + Topic: "orders", + Partitions: []kmsg.FetchRequestTopicPartition{ { - Partition: 0, - FetchOffset: 0, - MaxBytes: 1024, + Partition: 0, + FetchOffset: 0, + PartitionMaxBytes: 1024, }, }, }, @@ -233,13 +233,13 @@ func TestHandleFetchByTopicID(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - produceReq := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + produceReq := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -252,15 +252,15 @@ func TestHandleFetchByTopicID(t *testing.T) { t.Fatalf("handleProduce: %v", err) } - fetchReq := &protocol.FetchRequest{ - Topics: []protocol.FetchTopicRequest{ + fetchReq := &kmsg.FetchRequest{ + Topics: []kmsg.FetchRequestTopic{ { TopicID: metadata.TopicIDForName("orders"), - Partitions: []protocol.FetchPartitionRequest{ + Partitions: []kmsg.FetchRequestTopicPartition{ { - Partition: 0, - FetchOffset: 0, - MaxBytes: 1024, + Partition: 0, + FetchOffset: 0, + PartitionMaxBytes: 1024, }, }, }, @@ -271,8 +271,11 @@ func TestHandleFetchByTopicID(t *testing.T) { if err != nil { t.Fatalf("handleFetch: %v", err) } - recordSet := decodeFetchResponseV13RecordSet(t, resp) - if len(recordSet) == 0 { + fetchResp := decodeKmsgResponse(t, 13, resp, kmsg.NewPtrFetchResponse) + if len(fetchResp.Topics) == 0 || len(fetchResp.Topics[0].Partitions) == 0 { + t.Fatalf("expected topic partition in fetch response") + } + if len(fetchResp.Topics[0].Partitions[0].RecordBatches) == 0 { t.Fatalf("expected records for topic id fetch") } } @@ -286,13 +289,13 @@ func TestAutoCreateTopicOnProduce(t *testing.T) { }) handler := newTestHandler(store) - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 1000, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 1000, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "auto-created", - Partitions: []protocol.ProducePartition{ + Topic: "auto-created", + Partitions: []kmsg.ProduceRequestTopicPartition{ { Partition: 0, Records: testBatchBytes(0, 0, 1), @@ -323,30 +326,30 @@ func TestAutoCreateTopicOnProduce(t *testing.T) { func TestHandleCreateDeleteTopics(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - createReq := &protocol.CreateTopicsRequest{ - Topics: []protocol.CreateTopicConfig{ - {Name: "payments", NumPartitions: 1, ReplicationFactor: 1}, + createReq := &kmsg.CreateTopicsRequest{ + Topics: []kmsg.CreateTopicsRequestTopic{ + {Topic: "payments", NumPartitions: 1, ReplicationFactor: 1}, }, } respBytes, err := handler.handleCreateTopics(context.Background(), &protocol.RequestHeader{CorrelationID: 42}, createReq) if err != nil { t.Fatalf("handleCreateTopics: %v", err) } - resp := decodeCreateTopicsResponse(t, respBytes, 0) + resp := decodeKmsgResponse(t, 0, respBytes, kmsg.NewPtrCreateTopicsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != protocol.NONE { t.Fatalf("expected topic creation success: %#v", resp) } dupRespBytes, _ := handler.handleCreateTopics(context.Background(), &protocol.RequestHeader{CorrelationID: 43}, createReq) - dupResp := decodeCreateTopicsResponse(t, dupRespBytes, 0) + dupResp := decodeKmsgResponse(t, 0, dupRespBytes, kmsg.NewPtrCreateTopicsResponse) if dupResp.Topics[0].ErrorCode != protocol.TOPIC_ALREADY_EXISTS { t.Fatalf("expected duplicate error got %d", dupResp.Topics[0].ErrorCode) } - deleteReq := &protocol.DeleteTopicsRequest{TopicNames: []string{"payments", "missing"}} + deleteReq := &kmsg.DeleteTopicsRequest{TopicNames: []string{"payments", "missing"}} delBytes, err := handler.handleDeleteTopics(context.Background(), &protocol.RequestHeader{CorrelationID: 44}, deleteReq) if err != nil { t.Fatalf("handleDeleteTopics: %v", err) } - delResp := decodeDeleteTopicsResponse(t, delBytes, 0) + delResp := decodeKmsgResponse(t, 0, delBytes, kmsg.NewPtrDeleteTopicsResponse) if len(delResp.Topics) != 2 || delResp.Topics[0].ErrorCode != protocol.NONE { t.Fatalf("expected delete success, got %#v", delResp) } @@ -358,16 +361,16 @@ func TestHandleCreateDeleteTopics(t *testing.T) { func TestHandleCreateTopicsEtcdUnavailable(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(unavailableMetadataStore{Store: store}) - createReq := &protocol.CreateTopicsRequest{ - Topics: []protocol.CreateTopicConfig{ - {Name: "payments", NumPartitions: 1, ReplicationFactor: 1}, + createReq := &kmsg.CreateTopicsRequest{ + Topics: []kmsg.CreateTopicsRequestTopic{ + {Topic: "payments", NumPartitions: 1, ReplicationFactor: 1}, }, } respBytes, err := handler.handleCreateTopics(context.Background(), &protocol.RequestHeader{CorrelationID: 45}, createReq) if err != nil { t.Fatalf("handleCreateTopics: %v", err) } - resp := decodeCreateTopicsResponse(t, respBytes, 0) + resp := decodeKmsgResponse(t, 0, respBytes, kmsg.NewPtrCreateTopicsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != protocol.REQUEST_TIMED_OUT { t.Fatalf("expected etcd unavailable error, got %#v", resp) } @@ -384,12 +387,12 @@ func TestHandleCreatePartitions(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - req := &protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{ - {Name: "orders", Count: 2}, + req := &kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{ + {Topic: "orders", Count: 2}, }, - TimeoutMs: 1000, - ValidateOnly: false, + TimeoutMillis: 1000, + ValidateOnly: false, } payload, err := handler.handleCreatePartitions(context.Background(), &protocol.RequestHeader{ CorrelationID: 51, @@ -398,7 +401,7 @@ func TestHandleCreatePartitions(t *testing.T) { if err != nil { t.Fatalf("handleCreatePartitions: %v", err) } - resp := decodeCreatePartitionsResponse(t, payload, 3) + resp := decodeKmsgResponse(t, 3, payload, kmsg.NewPtrCreatePartitionsResponse) if len(resp.Topics) != 1 || resp.Topics[0].ErrorCode != 0 { t.Fatalf("unexpected create partitions response: %+v", resp.Topics) } @@ -415,7 +418,7 @@ func TestHandleCreatePartitionsErrors(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - run := func(req *protocol.CreatePartitionsRequest) *kmsg.CreatePartitionsResponse { + run := func(req *kmsg.CreatePartitionsRequest) *kmsg.CreatePartitionsResponse { t.Helper() payload, err := handler.handleCreatePartitions(context.Background(), &protocol.RequestHeader{ CorrelationID: 52, @@ -424,21 +427,21 @@ func TestHandleCreatePartitionsErrors(t *testing.T) { if err != nil { t.Fatalf("handleCreatePartitions: %v", err) } - return decodeCreatePartitionsResponse(t, payload, 3) + return decodeKmsgResponse(t, 3, payload, kmsg.NewPtrCreatePartitionsResponse) } - resp := run(&protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{{Name: "", Count: 2}}, + resp := run(&kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{{Topic: "", Count: 2}}, ValidateOnly: true, }) if resp.Topics[0].ErrorCode != protocol.INVALID_TOPIC_EXCEPTION { t.Fatalf("expected invalid topic error, got %d", resp.Topics[0].ErrorCode) } - resp = run(&protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{ - {Name: "orders", Count: 2}, - {Name: "orders", Count: 3}, + resp = run(&kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{ + {Topic: "orders", Count: 2}, + {Topic: "orders", Count: 3}, }, ValidateOnly: true, }) @@ -449,9 +452,9 @@ func TestHandleCreatePartitionsErrors(t *testing.T) { t.Fatalf("expected duplicate topic error, got %d", resp.Topics[1].ErrorCode) } - resp = run(&protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{ - {Name: "assignments", Count: 2, Assignments: []protocol.CreatePartitionsAssignment{{Replicas: []int32{1}}}}, + resp = run(&kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{ + {Topic: "assignments", Count: 2, Assignment: []kmsg.CreatePartitionsRequestTopicAssignment{{Replicas: []int32{1}}}}, }, ValidateOnly: true, }) @@ -459,16 +462,16 @@ func TestHandleCreatePartitionsErrors(t *testing.T) { t.Fatalf("expected assignment error, got %d", resp.Topics[0].ErrorCode) } - resp = run(&protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{{Name: "orders", Count: 1}}, + resp = run(&kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{{Topic: "orders", Count: 1}}, ValidateOnly: true, }) if resp.Topics[0].ErrorCode != protocol.INVALID_PARTITIONS { t.Fatalf("expected invalid partitions error, got %d", resp.Topics[0].ErrorCode) } - resp = run(&protocol.CreatePartitionsRequest{ - Topics: []protocol.CreatePartitionsTopic{{Name: "missing", Count: 2}}, + resp = run(&kmsg.CreatePartitionsRequest{ + Topics: []kmsg.CreatePartitionsRequestTopic{{Topic: "missing", Count: 2}}, ValidateOnly: true, }) if resp.Topics[0].ErrorCode != protocol.UNKNOWN_TOPIC_OR_PARTITION { @@ -484,7 +487,7 @@ func TestHandleDeleteGroups(t *testing.T) { } handler := newTestHandler(store) - req := &protocol.DeleteGroupsRequest{Groups: []string{"group-1", "missing", ""}} + req := &kmsg.DeleteGroupsRequest{Groups: []string{"group-1", "missing", ""}} header := &protocol.RequestHeader{ APIKey: protocol.APIKeyDeleteGroups, APIVersion: 2, @@ -494,7 +497,7 @@ func TestHandleDeleteGroups(t *testing.T) { if err != nil { t.Fatalf("Handle DeleteGroups: %v", err) } - resp := decodeDeleteGroupsResponse(t, payload, 2) + resp := decodeKmsgResponse(t, 2, payload, kmsg.NewPtrDeleteGroupsResponse) if len(resp.Groups) != 3 { t.Fatalf("expected 3 delete group results, got %d", len(resp.Groups)) } @@ -513,13 +516,13 @@ func TestHandleJoinGroupEtcdUnavailable(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(unavailableMetadataStore{Store: store}) - req := &protocol.JoinGroupRequest{ - GroupID: "group-1", - MemberID: "member-1", - ProtocolType: "consumer", - SessionTimeoutMs: 1000, - RebalanceTimeoutMs: 1000, - Protocols: []protocol.JoinGroupProtocol{ + req := &kmsg.JoinGroupRequest{ + Group: "group-1", + MemberID: "member-1", + ProtocolType: "consumer", + SessionTimeoutMillis: 1000, + RebalanceTimeoutMillis: 1000, + Protocols: []kmsg.JoinGroupRequestProtocol{ {Name: "range", Metadata: []byte("meta")}, }, } @@ -532,7 +535,7 @@ func TestHandleJoinGroupEtcdUnavailable(t *testing.T) { if err != nil { t.Fatalf("Handle JoinGroup: %v", err) } - resp := decodeJoinGroupResponse(t, payload, 4) + resp := decodeKmsgResponse(t, 4, payload, kmsg.NewPtrJoinGroupResponse) if resp.ErrorCode != protocol.REQUEST_TIMED_OUT { t.Fatalf("expected etcd unavailable error, got %d", resp.ErrorCode) } @@ -542,13 +545,13 @@ func TestHandleOffsetCommitEtcdUnavailable(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(unavailableMetadataStore{Store: store}) - req := &protocol.OffsetCommitRequest{ - GroupID: "group-1", + req := &kmsg.OffsetCommitRequest{ + Group: "group-1", MemberID: "member-1", - Topics: []protocol.OffsetCommitTopic{ + Topics: []kmsg.OffsetCommitRequestTopic{ { - Name: "orders", - Partitions: []protocol.OffsetCommitPartition{ + Topic: "orders", + Partitions: []kmsg.OffsetCommitRequestTopicPartition{ {Partition: 0, Offset: 5}, }, }, @@ -563,7 +566,7 @@ func TestHandleOffsetCommitEtcdUnavailable(t *testing.T) { if err != nil { t.Fatalf("Handle OffsetCommit: %v", err) } - resp := decodeOffsetCommitResponse(t, payload, 3) + resp := decodeKmsgResponse(t, 3, payload, kmsg.NewPtrOffsetCommitResponse) if len(resp.Topics) != 1 || len(resp.Topics[0].Partitions) != 1 { t.Fatalf("expected commit response entries, got %+v", resp.Topics) } @@ -598,11 +601,11 @@ func TestHandleListOffsets(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - req := &protocol.ListOffsetsRequest{ - Topics: []protocol.ListOffsetsTopic{ + req := &kmsg.ListOffsetsRequest{ + Topics: []kmsg.ListOffsetsRequestTopic{ { - Name: "orders", - Partitions: []protocol.ListOffsetsPartition{{Partition: 0, Timestamp: tc.timestamp}}, + Topic: "orders", + Partitions: []kmsg.ListOffsetsRequestTopicPartition{{Partition: 0, Timestamp: tc.timestamp}}, }, }, } @@ -614,7 +617,7 @@ func TestHandleListOffsets(t *testing.T) { if err != nil { t.Fatalf("handleListOffsets: %v", err) } - resp := decodeListOffsetsResponse(t, tc.version, respBytes) + resp := decodeKmsgResponse(t, tc.version, respBytes, kmsg.NewPtrListOffsetsResponse) if len(resp.Topics) != 1 || len(resp.Topics[0].Partitions) != 1 { t.Fatalf("unexpected list offsets response: %#v", resp) } @@ -638,11 +641,11 @@ func TestHandleListOffsets(t *testing.T) { func TestHandleListOffsetsRejectsUnsupportedVersion(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - req := &protocol.ListOffsetsRequest{ - Topics: []protocol.ListOffsetsTopic{ + req := &kmsg.ListOffsetsRequest{ + Topics: []kmsg.ListOffsetsRequestTopic{ { - Name: "orders", - Partitions: []protocol.ListOffsetsPartition{{Partition: 0, Timestamp: -1}}, + Topic: "orders", + Partitions: []kmsg.ListOffsetsRequestTopicPartition{{Partition: 0, Timestamp: -1}}, }, }, } @@ -656,17 +659,17 @@ func TestConsumerGroupLifecycle(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - joinReq := &protocol.JoinGroupRequest{ - GroupID: "group-1", - SessionTimeoutMs: 10000, - RebalanceTimeoutMs: 10000, - MemberID: "", - ProtocolType: "consumer", - Protocols: []protocol.JoinGroupProtocol{ + joinReq := &kmsg.JoinGroupRequest{ + Group: "group-1", + SessionTimeoutMillis: 10000, + RebalanceTimeoutMillis: 10000, + MemberID: "", + ProtocolType: "consumer", + Protocols: []kmsg.JoinGroupRequestProtocol{ {Name: "range", Metadata: encodeJoinMetadata([]string{"orders"})}, }, } - joinResp, err := handler.coordinator.JoinGroup(context.Background(), joinReq, 1) + joinResp, err := handler.coordinator.JoinGroup(context.Background(), joinReq) if err != nil { t.Fatalf("JoinGroup: %v", err) } @@ -674,56 +677,54 @@ func TestConsumerGroupLifecycle(t *testing.T) { t.Fatalf("expected member id") } - syncReq := &protocol.SyncGroupRequest{ - GroupID: "group-1", - GenerationID: joinResp.GenerationID, - MemberID: joinResp.MemberID, + syncReq := &kmsg.SyncGroupRequest{ + Group: "group-1", + Generation: joinResp.Generation, + MemberID: joinResp.MemberID, } - syncResp, err := handler.coordinator.SyncGroup(context.Background(), syncReq, 2) + syncResp, err := handler.coordinator.SyncGroup(context.Background(), syncReq) if err != nil { t.Fatalf("SyncGroup: %v", err) } - if len(syncResp.Assignment) == 0 { + if len(syncResp.MemberAssignment) == 0 { t.Fatalf("expected assignment bytes") } - hbResp := handler.coordinator.Heartbeat(context.Background(), &protocol.HeartbeatRequest{ - GroupID: "group-1", - GenerationID: joinResp.GenerationID, - MemberID: joinResp.MemberID, - }, 3) + hbResp := handler.coordinator.Heartbeat(context.Background(), &kmsg.HeartbeatRequest{ + Group: "group-1", + Generation: joinResp.Generation, + MemberID: joinResp.MemberID, + }) if hbResp.ErrorCode != protocol.NONE { t.Fatalf("heartbeat error: %d", hbResp.ErrorCode) } - commitResp, err := handler.coordinator.OffsetCommit(context.Background(), &protocol.OffsetCommitRequest{ - GroupID: "group-1", - GenerationID: joinResp.GenerationID, - MemberID: joinResp.MemberID, - Topics: []protocol.OffsetCommitTopic{ + commitResp, err := handler.coordinator.OffsetCommit(context.Background(), &kmsg.OffsetCommitRequest{ + Group: "group-1", + Generation: joinResp.Generation, + MemberID: joinResp.MemberID, + Topics: []kmsg.OffsetCommitRequestTopic{ { - Name: "orders", - Partitions: []protocol.OffsetCommitPartition{ - {Partition: 0, Offset: 5, Metadata: ""}, + Topic: "orders", + Partitions: []kmsg.OffsetCommitRequestTopicPartition{ + {Partition: 0, Offset: 5, Metadata: kmsg.StringPtr("")}, }, }, }, - }, 4) + }) if err != nil || len(commitResp.Topics) == 0 { t.Fatalf("offset commit failed: %v", err) } - fetchResp, err := handler.coordinator.OffsetFetch(context.Background(), &protocol.OffsetFetchRequest{ - GroupID: "group-1", - Topics: []protocol.OffsetFetchTopic{ + fetchResp, err := handler.coordinator.OffsetFetch(context.Background(), &kmsg.OffsetFetchRequest{ + Group: "group-1", + Topics: []kmsg.OffsetFetchRequestTopic{ { - Name: "orders", - Partitions: []protocol.OffsetFetchPartition{ - {Partition: 0}, - }, + Topic: "orders", + Partitions: []int32{0}, }, }, - }, 5) + }) if err != nil { t.Fatalf("OffsetFetch: %v", err) } @@ -744,13 +745,13 @@ func TestProduceBackpressureDegraded(t *testing.T) { handler := newHandler(store, &failingS3Client{}, brokerInfo, testLogger()) handler.s3Health.RecordUpload(2*time.Millisecond, nil) - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 100, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 100, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ {Partition: 0, Records: testBatchBytes(0, 0, 1)}, }, }, @@ -760,7 +761,7 @@ func TestProduceBackpressureDegraded(t *testing.T) { if err != nil { t.Fatalf("handleProduce: %v", err) } - produceResp := decodeProduceResponse(t, resp, 0) + produceResp := decodeKmsgResponse(t, 0, resp, kmsg.NewPtrProduceResponse) code := produceResp.Topics[0].Partitions[0].ErrorCode if code != protocol.REQUEST_TIMED_OUT { t.Fatalf("expected request timed out error, got %d", code) @@ -773,13 +774,13 @@ func TestProduceBackpressureUnavailable(t *testing.T) { brokerInfo := protocol.MetadataBroker{NodeID: 1, Host: "localhost", Port: 19092} handler := newHandler(store, &failingS3Client{}, brokerInfo, testLogger()) - req := &protocol.ProduceRequest{ - Acks: -1, - TimeoutMs: 100, - Topics: []protocol.ProduceTopic{ + req := &kmsg.ProduceRequest{ + Acks: -1, + TimeoutMillis: 100, + Topics: []kmsg.ProduceRequestTopic{ { - Name: "orders", - Partitions: []protocol.ProducePartition{ + Topic: "orders", + Partitions: []kmsg.ProduceRequestTopicPartition{ {Partition: 0, Records: testBatchBytes(0, 0, 1)}, }, }, @@ -789,7 +790,7 @@ func TestProduceBackpressureUnavailable(t *testing.T) { if err != nil { t.Fatalf("handleProduce: %v", err) } - produceResp := decodeProduceResponse(t, resp, 0) + produceResp := decodeKmsgResponse(t, 0, resp, kmsg.NewPtrProduceResponse) code := produceResp.Topics[0].Partitions[0].ErrorCode if code != protocol.UNKNOWN_SERVER_ERROR { t.Fatalf("expected unknown server error, got %d", code) @@ -803,11 +804,11 @@ func TestFetchBackpressureDegraded(t *testing.T) { handler := newTestHandler(store) handler.s3Health.RecordOperation("download", 2*time.Millisecond, nil) - req := &protocol.FetchRequest{ - Topics: []protocol.FetchTopicRequest{ + req := &kmsg.FetchRequest{ + Topics: []kmsg.FetchRequestTopic{ { - Name: "orders", - Partitions: []protocol.FetchPartitionRequest{{Partition: 0, FetchOffset: 0, MaxBytes: 1024}}, + Topic: "orders", + Partitions: []kmsg.FetchRequestTopicPartition{{Partition: 0, FetchOffset: 0, PartitionMaxBytes: 1024}}, }, }, } @@ -815,7 +816,7 @@ func TestFetchBackpressureDegraded(t *testing.T) { if err != nil { t.Fatalf("handleFetch: %v", err) } - fetchResp := decodeFetchResponse(t, resp) + fetchResp := decodeKmsgResponse(t, 11, resp, kmsg.NewPtrFetchResponse) code := fetchResp.Topics[0].Partitions[0].ErrorCode if code != protocol.REQUEST_TIMED_OUT { t.Fatalf("expected request timed out error, got %d", code) @@ -830,11 +831,11 @@ func TestFetchBackpressureUnavailable(t *testing.T) { handler.s3Health.RecordOperation("download", time.Millisecond, errors.New("boom")) } - req := &protocol.FetchRequest{ - Topics: []protocol.FetchTopicRequest{ + req := &kmsg.FetchRequest{ + Topics: []kmsg.FetchRequestTopic{ { - Name: "orders", - Partitions: []protocol.FetchPartitionRequest{{Partition: 0, FetchOffset: 0, MaxBytes: 1024}}, + Topic: "orders", + Partitions: []kmsg.FetchRequestTopicPartition{{Partition: 0, FetchOffset: 0, PartitionMaxBytes: 1024}}, }, }, } @@ -842,7 +843,7 @@ func TestFetchBackpressureUnavailable(t *testing.T) { if err != nil { t.Fatalf("handleFetch: %v", err) } - fetchResp := decodeFetchResponse(t, resp) + fetchResp := decodeKmsgResponse(t, 11, resp, kmsg.NewPtrFetchResponse) code := fetchResp.Topics[0].Partitions[0].ErrorCode if code != protocol.UNKNOWN_SERVER_ERROR { t.Fatalf("expected unknown server error, got %d", code) @@ -977,14 +978,14 @@ func TestFranzGoProduceConsumeLocal(t *testing.T) { Topics: []protocol.MetadataTopic{ { ErrorCode: 0, - Name: "orders", + Topic: kmsg.StringPtr("orders"), Partitions: []protocol.MetadataPartition{ { - ErrorCode: 0, - PartitionIndex: 0, - LeaderID: 1, - ReplicaNodes: []int32{1}, - ISRNodes: []int32{1}, + ErrorCode: 0, + Partition: 0, + Leader: 1, + Replicas: []int32{1}, + ISR: []int32{1}, }, }, }, @@ -1053,493 +1054,22 @@ func TestFranzGoProduceConsumeLocal(t *testing.T) { } } -func decodeProduceResponse(t *testing.T, payload []byte, version int16) *protocol.ProduceResponse { +// decodeKmsgResponse is a generic test helper that decodes any kmsg.Response +// from a wire-encoded response payload (correlation ID + optional tagged fields + body). +func decodeKmsgResponse[T kmsg.Response](t *testing.T, version int16, payload []byte, newFn func() T) T { t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.ProduceResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.ProduceTopicResponse, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - name := readKafkaString(t, reader) - var partCount int32 - if err := binary.Read(reader, binary.BigEndian, &partCount); err != nil { - t.Fatalf("read partition count: %v", err) - } - topicResp := protocol.ProduceTopicResponse{ - Name: name, - Partitions: make([]protocol.ProducePartitionResponse, 0, partCount), - } - for j := 0; j < int(partCount); j++ { - var part protocol.ProducePartitionResponse - if err := binary.Read(reader, binary.BigEndian, &part.Partition); err != nil { - t.Fatalf("read partition id: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.BaseOffset); err != nil { - t.Fatalf("read base offset: %v", err) - } - if version >= 3 { - if err := binary.Read(reader, binary.BigEndian, &part.LogAppendTimeMs); err != nil { - t.Fatalf("read append time: %v", err) - } - } - if version >= 5 { - if err := binary.Read(reader, binary.BigEndian, &part.LogStartOffset); err != nil { - t.Fatalf("read start offset: %v", err) - } - } - if version >= 8 { - var skip int32 - if err := binary.Read(reader, binary.BigEndian, &skip); err != nil { - t.Fatalf("read delta: %v", err) - } - } - topicResp.Partitions = append(topicResp.Partitions, part) - } - resp.Topics = append(resp.Topics, topicResp) - } - if version >= 1 { - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle: %v", err) - } - } - return resp -} - -func readKafkaString(t *testing.T, reader *bytes.Reader) string { - t.Helper() - var length int16 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - t.Fatalf("read string length: %v", err) - } - if length < 0 { - return "" - } - buf := make([]byte, length) - if _, err := io.ReadFull(reader, buf); err != nil { - t.Fatalf("read string bytes: %v", err) - } - return string(buf) -} - -func decodeCreateTopicsResponse(t *testing.T, payload []byte, version int16) *protocol.CreateTopicsResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.CreateTopicsResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 2 { - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle ms: %v", err) - } - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.CreateTopicResult, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - name := readKafkaString(t, reader) - var code int16 - if err := binary.Read(reader, binary.BigEndian, &code); err != nil { - t.Fatalf("read error code: %v", err) - } - msg := "" - if version >= 1 { - msg = readKafkaString(t, reader) - } - resp.Topics = append(resp.Topics, protocol.CreateTopicResult{Name: name, ErrorCode: code, ErrorMessage: msg}) - } - return resp -} - -func decodeDeleteTopicsResponse(t *testing.T, payload []byte, version int16) *protocol.DeleteTopicsResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.DeleteTopicsResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 1 { - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle ms: %v", err) - } - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.DeleteTopicResult, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - name := readKafkaString(t, reader) - var code int16 - if err := binary.Read(reader, binary.BigEndian, &code); err != nil { - t.Fatalf("read error code: %v", err) - } - resp.Topics = append(resp.Topics, protocol.DeleteTopicResult{Name: name, ErrorCode: code}) - } - return resp -} - -func decodeCreatePartitionsResponse(t *testing.T, payload []byte, version int16) *kmsg.CreatePartitionsResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 2 { - skipTaggedFields(t, reader) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrCreatePartitionsResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode create partitions response: %v", err) - } - return resp -} - -func decodeDeleteGroupsResponse(t *testing.T, payload []byte, version int16) *kmsg.DeleteGroupsResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 2 { - skipTaggedFields(t, reader) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrDeleteGroupsResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode delete groups response: %v", err) - } - return resp -} - -func decodeJoinGroupResponse(t *testing.T, payload []byte, version int16) *kmsg.JoinGroupResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) - } - resp := kmsg.NewPtrJoinGroupResponse() - resp.Version = version - if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode join group response: %v", err) - } - return resp -} - -func decodeOffsetCommitResponse(t *testing.T, payload []byte, version int16) *kmsg.OffsetCommitResponse { - t.Helper() - reader := bytes.NewReader(payload) - var corr int32 - if err := binary.Read(reader, binary.BigEndian, &corr); err != nil { - t.Fatalf("read correlation id: %v", err) - } - body, err := io.ReadAll(reader) - if err != nil { - t.Fatalf("read response body: %v", err) + resp := newFn() + body, ok := protocol.SkipResponseHeader(resp.Key(), version, payload) + if !ok { + t.Fatalf("failed to skip response header for api key %d", resp.Key()) } - resp := kmsg.NewPtrOffsetCommitResponse() - resp.Version = version + resp.SetVersion(version) if err := resp.ReadFrom(body); err != nil { - t.Fatalf("decode offset commit response: %v", err) - } - return resp -} - -func decodeListOffsetsResponse(t *testing.T, version int16, payload []byte) *protocol.ListOffsetsResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.ListOffsetsResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if version >= 2 { - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle: %v", err) - } - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.ListOffsetsTopicResponse, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - topic := protocol.ListOffsetsTopicResponse{} - topic.Name = readKafkaString(t, reader) - var partCount int32 - if err := binary.Read(reader, binary.BigEndian, &partCount); err != nil { - t.Fatalf("read partition count: %v", err) - } - topic.Partitions = make([]protocol.ListOffsetsPartitionResponse, 0, partCount) - for j := 0; j < int(partCount); j++ { - var part protocol.ListOffsetsPartitionResponse - if err := binary.Read(reader, binary.BigEndian, &part.Partition); err != nil { - t.Fatalf("read partition: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - if version == 0 { - var count int32 - if err := binary.Read(reader, binary.BigEndian, &count); err != nil { - t.Fatalf("read offset count: %v", err) - } - part.OldStyleOffsets = make([]int64, count) - for k := 0; k < int(count); k++ { - if err := binary.Read(reader, binary.BigEndian, &part.OldStyleOffsets[k]); err != nil { - t.Fatalf("read offset[%d]: %v", k, err) - } - } - } else { - if err := binary.Read(reader, binary.BigEndian, &part.Timestamp); err != nil { - t.Fatalf("read timestamp: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.Offset); err != nil { - t.Fatalf("read offset: %v", err) - } - if version >= 4 { - if err := binary.Read(reader, binary.BigEndian, &part.LeaderEpoch); err != nil { - t.Fatalf("read leader epoch: %v", err) - } - } - } - topic.Partitions = append(topic.Partitions, part) - } - resp.Topics = append(resp.Topics, topic) + t.Fatalf("decode response for api key %d v%d: %v", resp.Key(), version, err) } return resp } -func decodeFetchResponse(t *testing.T, payload []byte) *protocol.FetchResponse { - t.Helper() - reader := bytes.NewReader(payload) - resp := &protocol.FetchResponse{} - if err := binary.Read(reader, binary.BigEndian, &resp.CorrelationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &resp.ThrottleMs); err != nil { - t.Fatalf("read throttle: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &resp.ErrorCode); err != nil { - t.Fatalf("read fetch error: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &resp.SessionID); err != nil { - t.Fatalf("read fetch session id: %v", err) - } - var topicCount int32 - if err := binary.Read(reader, binary.BigEndian, &topicCount); err != nil { - t.Fatalf("read topic count: %v", err) - } - resp.Topics = make([]protocol.FetchTopicResponse, 0, topicCount) - for i := 0; i < int(topicCount); i++ { - topic := protocol.FetchTopicResponse{} - topic.Name = readKafkaString(t, reader) - var partitionCount int32 - if err := binary.Read(reader, binary.BigEndian, &partitionCount); err != nil { - t.Fatalf("read partition count: %v", err) - } - topic.Partitions = make([]protocol.FetchPartitionResponse, 0, partitionCount) - for j := 0; j < int(partitionCount); j++ { - part := protocol.FetchPartitionResponse{} - if err := binary.Read(reader, binary.BigEndian, &part.Partition); err != nil { - t.Fatalf("read partition: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.ErrorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - if err := binary.Read(reader, binary.BigEndian, &part.HighWatermark); err != nil { - t.Fatalf("read watermark: %v", err) - } - var lastStable int64 - if err := binary.Read(reader, binary.BigEndian, &lastStable); err != nil { - t.Fatalf("read last stable offset: %v", err) - } - var logStart int32 - if err := binary.Read(reader, binary.BigEndian, &logStart); err != nil { - t.Fatalf("read log start offset: %v", err) - } - var recordLen int32 - if err := binary.Read(reader, binary.BigEndian, &recordLen); err != nil { - t.Fatalf("read record length: %v", err) - } - if recordLen > 0 { - buf := make([]byte, recordLen) - if _, err := io.ReadFull(reader, buf); err != nil { - t.Fatalf("read record bytes: %v", err) - } - part.RecordSet = buf - } - topic.Partitions = append(topic.Partitions, part) - } - resp.Topics = append(resp.Topics, topic) - } - return resp -} - -func decodeFetchResponseV13RecordSet(t *testing.T, payload []byte) []byte { - t.Helper() - reader := bytes.NewReader(payload) - var correlationID int32 - if err := binary.Read(reader, binary.BigEndian, &correlationID); err != nil { - t.Fatalf("read correlation id: %v", err) - } - skipTaggedFields(t, reader) - var throttleMs int32 - if err := binary.Read(reader, binary.BigEndian, &throttleMs); err != nil { - t.Fatalf("read throttle: %v", err) - } - var errorCode int16 - if err := binary.Read(reader, binary.BigEndian, &errorCode); err != nil { - t.Fatalf("read error code: %v", err) - } - var sessionID int32 - if err := binary.Read(reader, binary.BigEndian, &sessionID); err != nil { - t.Fatalf("read session id: %v", err) - } - topicCount := readCompactArrayLen(t, reader) - if topicCount <= 0 { - t.Fatalf("expected topic count, got %d", topicCount) - } - if _, err := readUUID(reader); err != nil { - t.Fatalf("read topic id: %v", err) - } - partitionCount := readCompactArrayLen(t, reader) - if partitionCount <= 0 { - t.Fatalf("expected partition count, got %d", partitionCount) - } - var partitionID int32 - if err := binary.Read(reader, binary.BigEndian, &partitionID); err != nil { - t.Fatalf("read partition id: %v", err) - } - var partError int16 - if err := binary.Read(reader, binary.BigEndian, &partError); err != nil { - t.Fatalf("read partition error: %v", err) - } - if partError != 0 { - t.Fatalf("expected partition error 0 got %d", partError) - } - var watermark int64 - if err := binary.Read(reader, binary.BigEndian, &watermark); err != nil { - t.Fatalf("read high watermark: %v", err) - } - var lastStable int64 - if err := binary.Read(reader, binary.BigEndian, &lastStable); err != nil { - t.Fatalf("read last stable: %v", err) - } - var logStart int64 - if err := binary.Read(reader, binary.BigEndian, &logStart); err != nil { - t.Fatalf("read log start: %v", err) - } - abortedCount := readCompactArrayLen(t, reader) - for i := int32(0); i < abortedCount; i++ { - var producerID int64 - if err := binary.Read(reader, binary.BigEndian, &producerID); err != nil { - t.Fatalf("read aborted producer id: %v", err) - } - var firstOffset int64 - if err := binary.Read(reader, binary.BigEndian, &firstOffset); err != nil { - t.Fatalf("read aborted first offset: %v", err) - } - } - var preferredReadReplica int32 - if err := binary.Read(reader, binary.BigEndian, &preferredReadReplica); err != nil { - t.Fatalf("read preferred replica: %v", err) - } - recordSet := readCompactBytes(t, reader) - skipTaggedFields(t, reader) - skipTaggedFields(t, reader) - skipTaggedFields(t, reader) - return recordSet -} - -func readCompactArrayLen(t *testing.T, reader io.ByteReader) int32 { - t.Helper() - val, err := binary.ReadUvarint(reader) - if err != nil { - t.Fatalf("read uvarint: %v", err) - } - if val == 0 { - return -1 - } - return int32(val - 1) -} - -func readCompactBytes(t *testing.T, reader io.Reader) []byte { - t.Helper() - br, ok := reader.(io.ByteReader) - if !ok { - t.Fatalf("reader does not support ReadByte") - } - val, err := binary.ReadUvarint(br) - if err != nil { - t.Fatalf("read compact bytes length: %v", err) - } - if val == 0 { - return nil - } - length := int(val - 1) - buf := make([]byte, length) - if _, err := io.ReadFull(reader, buf); err != nil { - t.Fatalf("read compact bytes: %v", err) - } - return buf -} - -func skipTaggedFields(t *testing.T, reader io.ByteReader) { - t.Helper() - count, err := binary.ReadUvarint(reader) - if err != nil { - t.Fatalf("read tagged field count: %v", err) - } - for i := uint64(0); i < count; i++ { - if _, err := binary.ReadUvarint(reader); err != nil { - t.Fatalf("read tag: %v", err) - } - size, err := binary.ReadUvarint(reader) - if err != nil { - t.Fatalf("read tag size: %v", err) - } - if size == 0 { - continue - } - if _, err := io.CopyN(io.Discard, reader.(io.Reader), int64(size)); err != nil { - t.Fatalf("skip tag bytes: %v", err) - } - } -} - -func readUUID(reader io.Reader) ([16]byte, error) { - var id [16]byte - _, err := io.ReadFull(reader, id[:]) - return id, err -} - func TestMetricsHandlerExposesS3Health(t *testing.T) { t.Setenv("KAFSCALE_S3_LATENCY_WARN_MS", "1") store := metadata.NewInMemoryStore(defaultMetadata()) @@ -1557,9 +1087,9 @@ func TestMetricsHandlerExposesS3Health(t *testing.T) { func TestMetricsHandlerExposesAdminMetrics(t *testing.T) { store := metadata.NewInMemoryStore(defaultMetadata()) handler := newTestHandler(store) - req := &protocol.DescribeConfigsRequest{ - Resources: []protocol.DescribeConfigsResource{ - {ResourceType: protocol.ConfigResourceTopic, ResourceName: "orders"}, + req := &kmsg.DescribeConfigsRequest{ + Resources: []kmsg.DescribeConfigsRequestResource{ + {ResourceType: kmsg.ConfigResourceTypeTopic, ResourceName: "orders"}, }, } header := &protocol.RequestHeader{ @@ -1774,7 +1304,7 @@ func TestCoordinatorBrokerPrefersMetadata(t *testing.T) { err: errors.New("metadata down"), }, storage.NewMemoryS3Client(), brokerInfo, testLogger()) got = handler.coordinatorBroker(context.Background()) - if got != brokerInfo { + if got.NodeID != brokerInfo.NodeID || got.Host != brokerInfo.Host || got.Port != brokerInfo.Port { t.Fatalf("expected broker info fallback, got %+v", got) } } diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index a8decb27..df4ecad3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -33,6 +33,7 @@ import ( "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" "golang.org/x/sync/singleflight" ) @@ -54,7 +55,7 @@ type proxy struct { cacheTTL time.Duration cacheMu sync.RWMutex cachedBackends []string - apiVersions []protocol.ApiVersion + apiVersions []kmsg.ApiVersionsResponseApiKey router *metadata.PartitionRouter groupRouter *metadata.GroupRouter brokerAddrMu sync.RWMutex @@ -315,6 +316,7 @@ func (p *proxy) cacheFresh() bool { // checkReady uses cached state when fresh, falling back to a live metadata // fetch only when the cache TTL has expired (e.g. no traffic for >60s). +// The fallback uses a short timeout to prevent health probes from blocking. func (p *proxy) checkReady(ctx context.Context) bool { if len(p.backends) > 0 { return true @@ -325,7 +327,9 @@ func (p *proxy) checkReady(ctx context.Context) bool { if p.store == nil { return false } - backends, err := p.currentBackends(ctx) + checkCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + backends, err := p.currentBackends(checkCtx) return err == nil && len(backends) > 0 } @@ -334,6 +338,19 @@ func (p *proxy) initMetadataCache(ctx context.Context) { return } p.refreshMetadataCache(ctx) + // Periodic refresh so topology changes are picked up without a cache miss. + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.refreshMetadataCache(ctx) + } + } + }() } func (p *proxy) startHealthServer(ctx context.Context, addr string) { @@ -380,7 +397,7 @@ func (p *proxy) handleConnection(ctx context.Context, conn net.Conn) { if err != nil { return } - header, _, err := protocol.ParseRequestHeader(frame.Payload) + header, body, err := protocol.ParseRequestHeader(frame.Payload) if err != nil { p.logger.Warn("parse request header failed", "error", err) return @@ -400,7 +417,7 @@ func (p *proxy) handleConnection(ctx context.Context, conn net.Conn) { } if !p.isReady() { - resp, ok, err := p.buildNotReadyResponse(header, frame.Payload) + resp, ok, err := p.buildNotReadyResponse(header, body) if err != nil { p.logger.Warn("not-ready response build failed", "error", err) return @@ -440,7 +457,7 @@ func (p *proxy) handleConnection(ctx context.Context, conn net.Conn) { resp, err := p.handleProduceRouting(ctx, header, frame.Payload, pool) if err != nil { p.logger.Warn("produce routing failed", "error", err) - p.respondBackendError(conn, header, frame.Payload) + p.respondBackendError(conn, header, body) return } if resp == nil { @@ -456,7 +473,7 @@ func (p *proxy) handleConnection(ctx context.Context, conn net.Conn) { resp, err := p.handleFetchRouting(ctx, header, frame.Payload, pool) if err != nil { p.logger.Warn("fetch routing failed", "error", err) - p.respondBackendError(conn, header, frame.Payload) + p.respondBackendError(conn, header, body) return } if err := protocol.WriteFrame(conn, resp); err != nil { @@ -571,7 +588,7 @@ func (p *proxy) handleProduceRouting(ctx context.Context, header *protocol.Reque if err != nil { return p.forwardProduceRaw(ctx, payload, pool) } - produceReq, ok := req.(*protocol.ProduceRequest) + produceReq, ok := req.(*kmsg.ProduceRequest) if !ok || len(produceReq.Topics) == 0 { return p.forwardProduceRaw(ctx, payload, pool) } @@ -604,7 +621,7 @@ func (p *proxy) forwardProduceRaw(ctx context.Context, payload []byte, pool *con // fireAndForgetProduce writes a produce request to backends without reading a // response. Used for acks=0 produces where the Kafka protocol specifies no // server response. -func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.RequestHeader, req *protocol.ProduceRequest, originalPayload []byte, pool *connPool) { +func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.RequestHeader, req *kmsg.ProduceRequest, originalPayload []byte, pool *connPool) { groups := p.groupPartitionsByBroker(ctx, req, nil) for addr, subReq := range groups { @@ -612,12 +629,7 @@ func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.Reque if len(groups) == 1 { payload = originalPayload } else { - encoded, err := protocol.EncodeProduceRequest(header, subReq, header.APIVersion) - if err != nil { - p.logger.Warn("fire-and-forget encode failed", "error", err) - continue - } - payload = encoded + payload = encodeProduceRequest(header, subReq) } conn, targetAddr, err := p.connectForAddr(ctx, addr, nil, pool) @@ -637,14 +649,14 @@ func (p *proxy) fireAndForgetProduce(ctx context.Context, header *protocol.Reque // groupPartitionsByBroker groups topic-partitions by the owning broker's address. // If include is non-nil, only partitions present in the include map are grouped. // Partitions with no known owner are grouped under "" for round-robin fallback. -func (p *proxy) groupPartitionsByBroker(ctx context.Context, req *protocol.ProduceRequest, include map[string]map[int32]bool) map[string]*protocol.ProduceRequest { - groups := make(map[string]*protocol.ProduceRequest) +func (p *proxy) groupPartitionsByBroker(ctx context.Context, req *kmsg.ProduceRequest, include map[string]map[int32]bool) map[string]*kmsg.ProduceRequest { + groups := make(map[string]*kmsg.ProduceRequest) topicIndices := make(map[string]map[string]int) // addr -> topic name -> index in subReq.Topics for _, topic := range req.Topics { var includeParts map[int32]bool if include != nil { - includeParts = include[topic.Name] + includeParts = include[topic.Topic] if len(includeParts) == 0 { continue } @@ -655,25 +667,26 @@ func (p *proxy) groupPartitionsByBroker(ctx context.Context, req *protocol.Produ } addr := "" if p.router != nil { - if ownerID := p.router.LookupOwner(topic.Name, part.Partition); ownerID != "" { + if ownerID := p.router.LookupOwner(topic.Topic, part.Partition); ownerID != "" { addr = p.brokerIDToAddr(ctx, ownerID) } } subReq, ok := groups[addr] if !ok { - subReq = &protocol.ProduceRequest{ - Acks: req.Acks, - TimeoutMs: req.TimeoutMs, - TransactionalID: req.TransactionalID, + subReq = &kmsg.ProduceRequest{ + Version: req.Version, + Acks: req.Acks, + TimeoutMillis: req.TimeoutMillis, + TransactionID: req.TransactionID, } groups[addr] = subReq topicIndices[addr] = make(map[string]int) } - idx, ok := topicIndices[addr][topic.Name] + idx, ok := topicIndices[addr][topic.Topic] if !ok { idx = len(subReq.Topics) - subReq.Topics = append(subReq.Topics, protocol.ProduceTopic{Name: topic.Name}) - topicIndices[addr][topic.Name] = idx + subReq.Topics = append(subReq.Topics, kmsg.ProduceRequestTopic{Topic: topic.Topic}) + topicIndices[addr][topic.Topic] = idx } subReq.Topics[idx].Partitions = append(subReq.Topics[idx].Partitions, part) } @@ -685,19 +698,14 @@ func (p *proxy) groupPartitionsByBroker(ctx context.Context, req *protocol.Produ // concurrently, and merges the responses. If any partitions are rejected with // NOT_LEADER_OR_FOLLOWER, those partitions are retried on a different broker // (up to maxRetries total attempts). -func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHeader, fullReq *protocol.ProduceRequest, originalPayload []byte, groups map[string]*protocol.ProduceRequest, pool *connPool) ([]byte, error) { +func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHeader, fullReq *kmsg.ProduceRequest, originalPayload []byte, groups map[string]*kmsg.ProduceRequest, pool *connPool) ([]byte, error) { const maxRetries = 3 - merged := &protocol.ProduceResponse{ - CorrelationID: header.CorrelationID, - } + merged := &kmsg.ProduceResponse{Version: header.APIVersion} var failedPartitions map[string]map[int32]bool for attempt := 0; attempt < maxRetries; attempt++ { failedPartitions = nil - // Scope triedBackends per attempt so that retries can revisit brokers - // from earlier attempts. Without this, with N brokers all N get excluded - // after the first attempt and subsequent retries always fail to connect. triedBackends := make(map[string]bool) subResults := p.fanOutProduce(ctx, header, groups, originalPayload, triedBackends, pool) @@ -716,30 +724,30 @@ func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHead if failedPartitions == nil { failedPartitions = make(map[string]map[int32]bool) } - if failedPartitions[topic.Name] == nil { - failedPartitions[topic.Name] = make(map[int32]bool) + if failedPartitions[topic.Topic] == nil { + failedPartitions[topic.Topic] = make(map[int32]bool) } - failedPartitions[topic.Name][part.Partition] = true + failedPartitions[topic.Topic][part.Partition] = true if p.router != nil { - p.router.Invalidate(topic.Name, part.Partition) + p.router.Invalidate(topic.Topic, part.Partition) } } else { - tr := findOrAddTopicResponse(merged, topic.Name) + tr := findOrAddTopicResponse(merged, topic.Topic) tr.Partitions = append(tr.Partitions, part) } } } - if r.subResp.ThrottleMs > merged.ThrottleMs { - merged.ThrottleMs = r.subResp.ThrottleMs + if r.subResp.ThrottleMillis > merged.ThrottleMillis { + merged.ThrottleMillis = r.subResp.ThrottleMillis } } if len(failedPartitions) == 0 { - return protocol.EncodeProduceResponse(merged, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, merged), nil } groups = p.groupPartitionsByBroker(ctx, fullReq, failedPartitions) - originalPayload = nil // force re-encoding on retry + originalPayload = nil if len(groups) == 0 { break } @@ -747,14 +755,14 @@ func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHead } for _, topic := range fullReq.Topics { - failedParts, ok := failedPartitions[topic.Name] + failedParts, ok := failedPartitions[topic.Topic] if !ok { continue } - tr := findOrAddTopicResponse(merged, topic.Name) + tr := findOrAddTopicResponse(merged, topic.Topic) for _, part := range topic.Partitions { if failedParts[part.Partition] { - tr.Partitions = append(tr.Partitions, protocol.ProducePartitionResponse{ + tr.Partitions = append(tr.Partitions, kmsg.ProduceResponseTopicPartition{ Partition: part.Partition, ErrorCode: protocol.NOT_LEADER_OR_FOLLOWER, BaseOffset: -1, @@ -762,12 +770,12 @@ func (p *proxy) forwardProduce(ctx context.Context, header *protocol.RequestHead } } } - return protocol.EncodeProduceResponse(merged, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, merged), nil } type fanOutResult struct { - subReq *protocol.ProduceRequest - subResp *protocol.ProduceResponse + subReq *kmsg.ProduceRequest + subResp *kmsg.ProduceResponse conn net.Conn // non-nil on success; caller must Return or Close target string err error @@ -776,9 +784,9 @@ type fanOutResult struct { // fanOutProduce borrows connections and forwards sub-requests concurrently. // When there's only one group and originalPayload is non-nil, the original // payload is forwarded as-is (avoiding re-encoding). -func (p *proxy) fanOutProduce(ctx context.Context, header *protocol.RequestHeader, groups map[string]*protocol.ProduceRequest, originalPayload []byte, triedBackends map[string]bool, pool *connPool) []fanOutResult { +func (p *proxy) fanOutProduce(ctx context.Context, header *protocol.RequestHeader, groups map[string]*kmsg.ProduceRequest, originalPayload []byte, triedBackends map[string]bool, pool *connPool) []fanOutResult { type workItem struct { - subReq *protocol.ProduceRequest + subReq *kmsg.ProduceRequest conn net.Conn target string payload []byte @@ -799,13 +807,7 @@ func (p *proxy) fanOutProduce(ctx context.Context, header *protocol.RequestHeade if canUseOriginal { payload = originalPayload } else { - encoded, encErr := protocol.EncodeProduceRequest(header, subReq, header.APIVersion) - if encErr != nil { - conn.Close() - connectErrors = append(connectErrors, fanOutResult{subReq: subReq, target: targetAddr, err: encErr}) - continue - } - payload = encoded + payload = encodeProduceRequest(header, subReq) } work = append(work, workItem{subReq: subReq, conn: conn, target: targetAddr, payload: payload}) } @@ -824,7 +826,7 @@ func (p *proxy) fanOutProduce(ctx context.Context, header *protocol.RequestHeade results[i] = fanOutResult{subReq: w.subReq, target: w.target, err: err} return } - subResp, parseErr := protocol.ParseProduceResponse(respBytes, header.APIVersion) + subResp, parseErr := parseProduceResponse(respBytes, header.APIVersion) if parseErr != nil { w.conn.Close() results[i] = fanOutResult{subReq: w.subReq, target: w.target, err: parseErr} @@ -850,23 +852,23 @@ func (p *proxy) connectForAddr(ctx context.Context, addr string, exclude map[str return p.connectBackendExcluding(ctx, exclude) } -func findOrAddTopicResponse(resp *protocol.ProduceResponse, name string) *protocol.ProduceTopicResponse { +func findOrAddTopicResponse(resp *kmsg.ProduceResponse, name string) *kmsg.ProduceResponseTopic { for i := range resp.Topics { - if resp.Topics[i].Name == name { + if resp.Topics[i].Topic == name { return &resp.Topics[i] } } - resp.Topics = append(resp.Topics, protocol.ProduceTopicResponse{Name: name}) + resp.Topics = append(resp.Topics, kmsg.ProduceResponseTopic{Topic: name}) return &resp.Topics[len(resp.Topics)-1] } // addErrorForAllPartitions fills the response with errorCode for every partition // in the sub-request (used when a broker is unreachable). -func addErrorForAllPartitions(resp *protocol.ProduceResponse, req *protocol.ProduceRequest, errorCode int16) { +func addErrorForAllPartitions(resp *kmsg.ProduceResponse, req *kmsg.ProduceRequest, errorCode int16) { for _, topic := range req.Topics { - topicResp := findOrAddTopicResponse(resp, topic.Name) + topicResp := findOrAddTopicResponse(resp, topic.Topic) for _, part := range topic.Partitions { - topicResp.Partitions = append(topicResp.Partitions, protocol.ProducePartitionResponse{ + topicResp.Partitions = append(topicResp.Partitions, kmsg.ProduceResponseTopicPartition{ Partition: part.Partition, ErrorCode: errorCode, BaseOffset: -1, @@ -929,17 +931,14 @@ func (p *proxy) connectBackendExcluding(ctx context.Context, exclude map[string] } func (p *proxy) handleApiVersions(header *protocol.RequestHeader) ([]byte, error) { - resp := &protocol.ApiVersionsResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.NONE, - ThrottleMs: 0, - Versions: p.apiVersions, - } - return protocol.EncodeApiVersionsResponse(resp, header.APIVersion) + resp := kmsg.NewPtrApiVersionsResponse() + resp.ErrorCode = protocol.NONE + resp.ApiKeys = p.apiVersions + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (p *proxy) respondBackendError(conn net.Conn, header *protocol.RequestHeader, payload []byte) { - resp, ok, err := p.buildNotReadyResponse(header, payload) +func (p *proxy) respondBackendError(conn net.Conn, header *protocol.RequestHeader, body []byte) { + resp, ok, err := p.buildNotReadyResponse(header, body) if err != nil || !ok { return } @@ -951,7 +950,7 @@ func (p *proxy) handleMetadata(ctx context.Context, header *protocol.RequestHead if err != nil { return nil, err } - metaReq, ok := req.(*protocol.MetadataRequest) + metaReq, ok := req.(*kmsg.MetadataRequest) if !ok { return nil, fmt.Errorf("unexpected metadata request type %T", req) } @@ -961,33 +960,35 @@ func (p *proxy) handleMetadata(ctx context.Context, header *protocol.RequestHead return nil, err } resp := buildProxyMetadataResponse(meta, header.CorrelationID, header.APIVersion, p.advertisedHost, p.advertisedPort) - return protocol.EncodeMetadataResponse(resp, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } func (p *proxy) handleFindCoordinator(header *protocol.RequestHeader) ([]byte, error) { - resp := &protocol.FindCoordinatorResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.NONE, - NodeID: 0, - Host: p.advertisedHost, - Port: p.advertisedPort, - ErrorMessage: nil, - } - return protocol.EncodeFindCoordinatorResponse(resp, header.APIVersion) + resp := kmsg.NewPtrFindCoordinatorResponse() + resp.ErrorCode = protocol.NONE + resp.NodeID = 0 + resp.Host = p.advertisedHost + resp.Port = p.advertisedPort + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil } -func (p *proxy) loadMetadata(ctx context.Context, req *protocol.MetadataRequest) (*metadata.ClusterMetadata, error) { +func (p *proxy) loadMetadata(ctx context.Context, req *kmsg.MetadataRequest) (*metadata.ClusterMetadata, error) { + var zeroID [16]byte useIDs := false - zeroID := [16]byte{} - for _, id := range req.TopicIDs { - if id != zeroID { - useIDs = true - break + var topicNames []string + if req.Topics != nil { + for _, t := range req.Topics { + if t.TopicID != zeroID { + useIDs = true + break + } + if t.Topic != nil { + topicNames = append(topicNames, *t.Topic) + } } } if !useIDs { - return p.store.Metadata(ctx, req.Topics) + return p.store.Metadata(ctx, topicNames) } all, err := p.store.Metadata(ctx, nil) if err != nil { @@ -997,17 +998,17 @@ func (p *proxy) loadMetadata(ctx context.Context, req *protocol.MetadataRequest) for _, topic := range all.Topics { index[topic.TopicID] = topic } - filtered := make([]protocol.MetadataTopic, 0, len(req.TopicIDs)) - for _, id := range req.TopicIDs { - if id == zeroID { + filtered := make([]protocol.MetadataTopic, 0, len(req.Topics)) + for _, t := range req.Topics { + if t.TopicID == zeroID { continue } - if topic, ok := index[id]; ok { + if topic, ok := index[t.TopicID]; ok { filtered = append(filtered, topic) } else { filtered = append(filtered, protocol.MetadataTopic{ ErrorCode: protocol.UNKNOWN_TOPIC_ID, - TopicID: id, + TopicID: t.TopicID, }) } } @@ -1019,366 +1020,236 @@ func (p *proxy) loadMetadata(ctx context.Context, req *protocol.MetadataRequest) }, nil } -func (p *proxy) buildNotReadyResponse(header *protocol.RequestHeader, payload []byte) ([]byte, bool, error) { - _, req, err := protocol.ParseRequest(payload) +func (p *proxy) buildNotReadyResponse(header *protocol.RequestHeader, body []byte) ([]byte, bool, error) { + _, req, err := protocol.ParseRequestBody(header, body) if err != nil { return nil, false, err } - wrapEncode := func(payload []byte, err error) ([]byte, bool, error) { - return payload, true, err + encode := func(resp kmsg.Response) ([]byte, bool, error) { + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), true, nil } switch header.APIKey { case protocol.APIKeyMetadata: - metaReq := req.(*protocol.MetadataRequest) - topics := make([]protocol.MetadataTopic, 0, len(metaReq.Topics)+len(metaReq.TopicIDs)) - for _, name := range metaReq.Topics { - topics = append(topics, protocol.MetadataTopic{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - Name: name, - }) - } - for _, id := range metaReq.TopicIDs { - topics = append(topics, protocol.MetadataTopic{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - TopicID: id, - }) - } - resp := &protocol.MetadataResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Brokers: nil, - ClusterID: nil, - ControllerID: -1, - Topics: topics, - } - return wrapEncode(protocol.EncodeMetadataResponse(resp, header.APIVersion)) + metaReq := req.(*kmsg.MetadataRequest) + resp := kmsg.NewPtrMetadataResponse() + resp.ControllerID = -1 + for _, t := range metaReq.Topics { + mt := kmsg.NewMetadataResponseTopic() + mt.ErrorCode = protocol.REQUEST_TIMED_OUT + mt.Topic = t.Topic + mt.TopicID = t.TopicID + resp.Topics = append(resp.Topics, mt) + } + return encode(resp) case protocol.APIKeyFindCoordinator: - resp := &protocol.FindCoordinatorResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: nil, - NodeID: -1, - Host: "", - Port: 0, - } - return wrapEncode(protocol.EncodeFindCoordinatorResponse(resp, header.APIVersion)) + resp := kmsg.NewPtrFindCoordinatorResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.NodeID = -1 + return encode(resp) case protocol.APIKeyProduce: - prodReq := req.(*protocol.ProduceRequest) - topics := make([]protocol.ProduceTopicResponse, 0, len(prodReq.Topics)) + prodReq := req.(*kmsg.ProduceRequest) + resp := kmsg.NewPtrProduceResponse() for _, topic := range prodReq.Topics { - partitions := make([]protocol.ProducePartitionResponse, 0, len(topic.Partitions)) + rt := kmsg.NewProduceResponseTopic() + rt.Topic = topic.Topic for _, part := range topic.Partitions { - partitions = append(partitions, protocol.ProducePartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - BaseOffset: -1, - LogAppendTimeMs: -1, - LogStartOffset: -1, - }) + rp := kmsg.NewProduceResponseTopicPartition() + rp.Partition = part.Partition + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rp.BaseOffset = -1 + rp.LogAppendTime = -1 + rp.LogStartOffset = -1 + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.ProduceTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - resp := &protocol.ProduceResponse{ - CorrelationID: header.CorrelationID, - Topics: topics, - ThrottleMs: 0, + resp.Topics = append(resp.Topics, rt) } - return wrapEncode(protocol.EncodeProduceResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyFetch: - fetchReq := req.(*protocol.FetchRequest) - topics := make([]protocol.FetchTopicResponse, 0, len(fetchReq.Topics)) + fetchReq := req.(*kmsg.FetchRequest) + resp := kmsg.NewPtrFetchResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.SessionID = fetchReq.SessionID for _, topic := range fetchReq.Topics { - partitions := make([]protocol.FetchPartitionResponse, 0, len(topic.Partitions)) + rt := kmsg.NewFetchResponseTopic() + rt.Topic = topic.Topic + rt.TopicID = topic.TopicID for _, part := range topic.Partitions { - partitions = append(partitions, protocol.FetchPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - HighWatermark: 0, - }) + rp := kmsg.NewFetchResponseTopicPartition() + rp.Partition = part.Partition + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.FetchTopicResponse{ - Name: topic.Name, - TopicID: topic.TopicID, - Partitions: partitions, - }) - } - resp := &protocol.FetchResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - SessionID: fetchReq.SessionID, - Topics: topics, + resp.Topics = append(resp.Topics, rt) } - return wrapEncode(protocol.EncodeFetchResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyListOffsets: - offsetReq := req.(*protocol.ListOffsetsRequest) - topics := make([]protocol.ListOffsetsTopicResponse, 0, len(offsetReq.Topics)) + offsetReq := req.(*kmsg.ListOffsetsRequest) + resp := kmsg.NewPtrListOffsetsResponse() for _, topic := range offsetReq.Topics { - partitions := make([]protocol.ListOffsetsPartitionResponse, 0, len(topic.Partitions)) + rt := kmsg.NewListOffsetsResponseTopic() + rt.Topic = topic.Topic for _, part := range topic.Partitions { - partitions = append(partitions, protocol.ListOffsetsPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - Timestamp: -1, - Offset: -1, - LeaderEpoch: -1, - OldStyleOffsets: nil, - }) + rp := kmsg.NewListOffsetsResponseTopicPartition() + rp.Partition = part.Partition + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rp.Timestamp = -1 + rp.Offset = -1 + rp.LeaderEpoch = -1 + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.ListOffsetsTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - resp := &protocol.ListOffsetsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, + resp.Topics = append(resp.Topics, rt) } - return wrapEncode(protocol.EncodeListOffsetsResponse(header.APIVersion, resp)) + return encode(resp) case protocol.APIKeyJoinGroup: - resp := &protocol.JoinGroupResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - GenerationID: -1, - ProtocolName: "", - LeaderID: "", - MemberID: "", - Members: nil, - } - return wrapEncode(protocol.EncodeJoinGroupResponse(resp, header.APIVersion)) + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.Generation = -1 + return encode(resp) case protocol.APIKeySyncGroup: - resp := &protocol.SyncGroupResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ProtocolType: nil, - ProtocolName: nil, - Assignment: nil, - } - return wrapEncode(protocol.EncodeSyncGroupResponse(resp, header.APIVersion)) + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return encode(resp) case protocol.APIKeyHeartbeat: - resp := &protocol.HeartbeatResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - } - return wrapEncode(protocol.EncodeHeartbeatResponse(resp, header.APIVersion)) + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return encode(resp) case protocol.APIKeyLeaveGroup: - resp := &protocol.LeaveGroupResponse{ - CorrelationID: header.CorrelationID, - ErrorCode: protocol.REQUEST_TIMED_OUT, - } - return wrapEncode(protocol.EncodeLeaveGroupResponse(resp)) + resp := kmsg.NewPtrLeaveGroupResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return encode(resp) case protocol.APIKeyOffsetCommit: - commitReq := req.(*protocol.OffsetCommitRequest) - topics := make([]protocol.OffsetCommitTopicResponse, 0, len(commitReq.Topics)) + commitReq := req.(*kmsg.OffsetCommitRequest) + resp := kmsg.NewPtrOffsetCommitResponse() for _, topic := range commitReq.Topics { - partitions := make([]protocol.OffsetCommitPartitionResponse, 0, len(topic.Partitions)) + rt := kmsg.NewOffsetCommitResponseTopic() + rt.Topic = topic.Topic for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetCommitPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) + rp := kmsg.NewOffsetCommitResponseTopicPartition() + rp.Partition = part.Partition + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.OffsetCommitTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - resp := &protocol.OffsetCommitResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, + resp.Topics = append(resp.Topics, rt) } - return wrapEncode(protocol.EncodeOffsetCommitResponse(resp)) + return encode(resp) case protocol.APIKeyOffsetFetch: - fetchReq := req.(*protocol.OffsetFetchRequest) - topics := make([]protocol.OffsetFetchTopicResponse, 0, len(fetchReq.Topics)) - for _, topic := range fetchReq.Topics { - partitions := make([]protocol.OffsetFetchPartitionResponse, 0, len(topic.Partitions)) + ofReq := req.(*kmsg.OffsetFetchRequest) + resp := kmsg.NewPtrOffsetFetchResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + for _, topic := range ofReq.Topics { + rt := kmsg.NewOffsetFetchResponseTopic() + rt.Topic = topic.Topic for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetFetchPartitionResponse{ - Partition: part.Partition, - Offset: -1, - LeaderEpoch: -1, - Metadata: nil, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) + rp := kmsg.NewOffsetFetchResponseTopicPartition() + rp.Partition = part + rp.Offset = -1 + rp.LeaderEpoch = -1 + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.OffsetFetchTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + resp.Topics = append(resp.Topics, rt) } - resp := &protocol.OffsetFetchResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - ErrorCode: protocol.REQUEST_TIMED_OUT, - } - return wrapEncode(protocol.EncodeOffsetFetchResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyOffsetForLeaderEpoch: - epochReq := req.(*protocol.OffsetForLeaderEpochRequest) - topics := make([]protocol.OffsetForLeaderEpochTopicResponse, 0, len(epochReq.Topics)) + epochReq := req.(*kmsg.OffsetForLeaderEpochRequest) + resp := kmsg.NewPtrOffsetForLeaderEpochResponse() for _, topic := range epochReq.Topics { - partitions := make([]protocol.OffsetForLeaderEpochPartitionResponse, 0, len(topic.Partitions)) + rt := kmsg.NewOffsetForLeaderEpochResponseTopic() + rt.Topic = topic.Topic for _, part := range topic.Partitions { - partitions = append(partitions, protocol.OffsetForLeaderEpochPartitionResponse{ - Partition: part.Partition, - ErrorCode: protocol.REQUEST_TIMED_OUT, - LeaderEpoch: -1, - EndOffset: -1, - }) + rp := kmsg.NewOffsetForLeaderEpochResponseTopicPartition() + rp.Partition = part.Partition + rp.ErrorCode = protocol.REQUEST_TIMED_OUT + rp.LeaderEpoch = -1 + rp.EndOffset = -1 + rt.Partitions = append(rt.Partitions, rp) } - topics = append(topics, protocol.OffsetForLeaderEpochTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) + resp.Topics = append(resp.Topics, rt) } - resp := &protocol.OffsetForLeaderEpochResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - } - return wrapEncode(protocol.EncodeOffsetForLeaderEpochResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyDescribeGroups: - descReq := req.(*protocol.DescribeGroupsRequest) - groups := make([]protocol.DescribeGroupsResponseGroup, 0, len(descReq.Groups)) + descReq := req.(*kmsg.DescribeGroupsRequest) + resp := kmsg.NewPtrDescribeGroupsResponse() for _, group := range descReq.Groups { - groups = append(groups, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - GroupID: group, - State: "", - ProtocolType: "", - Protocol: "", - Members: nil, - AuthorizedOperations: 0, - }) - } - resp := &protocol.DescribeGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: groups, + rg := kmsg.NewDescribeGroupsResponseGroup() + rg.ErrorCode = protocol.REQUEST_TIMED_OUT + rg.Group = group + resp.Groups = append(resp.Groups, rg) } - return wrapEncode(protocol.EncodeDescribeGroupsResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyListGroups: - resp := &protocol.ListGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - ErrorCode: protocol.REQUEST_TIMED_OUT, - Groups: nil, - } - return wrapEncode(protocol.EncodeListGroupsResponse(resp, header.APIVersion)) + resp := kmsg.NewPtrListGroupsResponse() + resp.ErrorCode = protocol.REQUEST_TIMED_OUT + return encode(resp) case protocol.APIKeyDescribeConfigs: - descReq := req.(*protocol.DescribeConfigsRequest) - resources := make([]protocol.DescribeConfigsResponseResource, 0, len(descReq.Resources)) + descReq := req.(*kmsg.DescribeConfigsRequest) + resp := kmsg.NewPtrDescribeConfigsResponse() for _, res := range descReq.Resources { - resources = append(resources, protocol.DescribeConfigsResponseResource{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: nil, - ResourceType: res.ResourceType, - ResourceName: res.ResourceName, - Configs: nil, - }) - } - resp := &protocol.DescribeConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, + rr := kmsg.NewDescribeConfigsResponseResource() + rr.ErrorCode = protocol.REQUEST_TIMED_OUT + rr.ResourceType = res.ResourceType + rr.ResourceName = res.ResourceName + resp.Resources = append(resp.Resources, rr) } - return wrapEncode(protocol.EncodeDescribeConfigsResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyAlterConfigs: - alterReq := req.(*protocol.AlterConfigsRequest) - resources := make([]protocol.AlterConfigsResponseResource, 0, len(alterReq.Resources)) + alterReq := req.(*kmsg.AlterConfigsRequest) + resp := kmsg.NewPtrAlterConfigsResponse() for _, res := range alterReq.Resources { - resources = append(resources, protocol.AlterConfigsResponseResource{ - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: nil, - ResourceType: res.ResourceType, - ResourceName: res.ResourceName, - }) - } - resp := &protocol.AlterConfigsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Resources: resources, + rr := kmsg.NewAlterConfigsResponseResource() + rr.ErrorCode = protocol.REQUEST_TIMED_OUT + rr.ResourceType = res.ResourceType + rr.ResourceName = res.ResourceName + resp.Resources = append(resp.Resources, rr) } - return wrapEncode(protocol.EncodeAlterConfigsResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyCreatePartitions: - createReq := req.(*protocol.CreatePartitionsRequest) - topics := make([]protocol.CreatePartitionsResponseTopic, 0, len(createReq.Topics)) + createReq := req.(*kmsg.CreatePartitionsRequest) + resp := kmsg.NewPtrCreatePartitionsResponse() for _, topic := range createReq.Topics { - topics = append(topics, protocol.CreatePartitionsResponseTopic{ - Name: topic.Name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: nil, - }) + rt := kmsg.NewCreatePartitionsResponseTopic() + rt.Topic = topic.Topic + rt.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.Topics = append(resp.Topics, rt) } - resp := &protocol.CreatePartitionsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - } - return wrapEncode(protocol.EncodeCreatePartitionsResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyCreateTopics: - createReq := req.(*protocol.CreateTopicsRequest) - topics := make([]protocol.CreateTopicResult, 0, len(createReq.Topics)) + createReq := req.(*kmsg.CreateTopicsRequest) + resp := kmsg.NewPtrCreateTopicsResponse() for _, topic := range createReq.Topics { - topics = append(topics, protocol.CreateTopicResult{ - Name: topic.Name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: "", - }) - } - resp := &protocol.CreateTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, + rt := kmsg.NewCreateTopicsResponseTopic() + rt.Topic = topic.Topic + rt.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.Topics = append(resp.Topics, rt) } - return wrapEncode(protocol.EncodeCreateTopicsResponse(resp, header.APIVersion)) + return encode(resp) case protocol.APIKeyDeleteTopics: - delReq := req.(*protocol.DeleteTopicsRequest) - topics := make([]protocol.DeleteTopicResult, 0, len(delReq.TopicNames)) - for _, name := range delReq.TopicNames { - topics = append(topics, protocol.DeleteTopicResult{ - Name: name, - ErrorCode: protocol.REQUEST_TIMED_OUT, - ErrorMessage: "", - }) - } - resp := &protocol.DeleteTopicsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Topics: topics, - } - return wrapEncode(protocol.EncodeDeleteTopicsResponse(resp, header.APIVersion)) + delReq := req.(*kmsg.DeleteTopicsRequest) + resp := kmsg.NewPtrDeleteTopicsResponse() + for _, t := range delReq.Topics { + rt := kmsg.NewDeleteTopicsResponseTopic() + rt.Topic = t.Topic + rt.TopicID = t.TopicID + rt.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.Topics = append(resp.Topics, rt) + } + return encode(resp) case protocol.APIKeyDeleteGroups: - delReq := req.(*protocol.DeleteGroupsRequest) - groups := make([]protocol.DeleteGroupsResponseGroup, 0, len(delReq.Groups)) + delReq := req.(*kmsg.DeleteGroupsRequest) + resp := kmsg.NewPtrDeleteGroupsResponse() for _, group := range delReq.Groups { - groups = append(groups, protocol.DeleteGroupsResponseGroup{ - Group: group, - ErrorCode: protocol.REQUEST_TIMED_OUT, - }) - } - resp := &protocol.DeleteGroupsResponse{ - CorrelationID: header.CorrelationID, - ThrottleMs: 0, - Groups: groups, + rg := kmsg.NewDeleteGroupsResponseGroup() + rg.Group = group + rg.ErrorCode = protocol.REQUEST_TIMED_OUT + resp.Groups = append(resp.Groups, rg) } - return wrapEncode(protocol.EncodeDeleteGroupsResponse(resp, header.APIVersion)) + return encode(resp) default: return nil, false, nil } } -func generateProxyApiVersions() []protocol.ApiVersion { +func generateProxyApiVersions() []kmsg.ApiVersionsResponseApiKey { supported := []struct { key int16 min, max int16 @@ -1406,17 +1277,17 @@ func generateProxyApiVersions() []protocol.ApiVersion { {key: protocol.APIKeyDeleteGroups, min: 0, max: 2}, } unsupported := []int16{4, 5, 6, 7, 21, 22, 24, 25, 26} - entries := make([]protocol.ApiVersion, 0, len(supported)+len(unsupported)) + entries := make([]kmsg.ApiVersionsResponseApiKey, 0, len(supported)+len(unsupported)) for _, entry := range supported { - entries = append(entries, protocol.ApiVersion{ - APIKey: entry.key, + entries = append(entries, kmsg.ApiVersionsResponseApiKey{ + ApiKey: entry.key, MinVersion: entry.min, MaxVersion: entry.max, }) } for _, key := range unsupported { - entries = append(entries, protocol.ApiVersion{ - APIKey: key, + entries = append(entries, kmsg.ApiVersionsResponseApiKey{ + ApiKey: key, MinVersion: -1, MaxVersion: -1, }) @@ -1424,7 +1295,7 @@ func generateProxyApiVersions() []protocol.ApiVersion { return entries } -func buildProxyMetadataResponse(meta *metadata.ClusterMetadata, correlationID int32, version int16, host string, port int32) *protocol.MetadataResponse { +func buildProxyMetadataResponse(meta *metadata.ClusterMetadata, correlationID int32, version int16, host string, port int32) *kmsg.MetadataResponse { brokers := []protocol.MetadataBroker{{ NodeID: 0, Host: host, @@ -1439,30 +1310,28 @@ func buildProxyMetadataResponse(meta *metadata.ClusterMetadata, correlationID in partitions := make([]protocol.MetadataPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { partitions = append(partitions, protocol.MetadataPartition{ - ErrorCode: part.ErrorCode, - PartitionIndex: part.PartitionIndex, - LeaderID: 0, - LeaderEpoch: part.LeaderEpoch, - ReplicaNodes: []int32{0}, - ISRNodes: []int32{0}, + ErrorCode: part.ErrorCode, + Partition: part.Partition, + Leader: 0, + LeaderEpoch: part.LeaderEpoch, + Replicas: []int32{0}, + ISR: []int32{0}, }) } topics = append(topics, protocol.MetadataTopic{ ErrorCode: topic.ErrorCode, - Name: topic.Name, + Topic: topic.Topic, TopicID: topic.TopicID, IsInternal: topic.IsInternal, Partitions: partitions, }) } - return &protocol.MetadataResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - Brokers: brokers, - ClusterID: meta.ClusterID, - ControllerID: 0, - Topics: topics, - } + resp := kmsg.NewPtrMetadataResponse() + resp.Brokers = brokers + resp.ClusterID = meta.ClusterID + resp.ControllerID = 0 + resp.Topics = topics + return resp } func (p *proxy) connectBackend(ctx context.Context) (net.Conn, string, error) { @@ -1509,13 +1378,16 @@ func (p *proxy) updateBrokerAddrs(brokers []protocol.MetadataBroker) { } // refreshMetadataCache updates broker address and topic name caches from -// metadata. Concurrent calls are coalesced via singleflight. +// metadata. Concurrent calls are coalesced via singleflight. Uses a detached +// context so that a single caller's cancellation does not abort the shared fetch. func (p *proxy) refreshMetadataCache(ctx context.Context) { if p.store == nil { return } - p.metaFlight.Do("refresh", func() (interface{}, error) { - meta, err := p.store.Metadata(ctx, nil) + fetchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + _, err, _ := p.metaFlight.Do("refresh", func() (interface{}, error) { + meta, err := p.store.Metadata(fetchCtx, nil) if err != nil { return nil, err } @@ -1524,14 +1396,18 @@ func (p *proxy) refreshMetadataCache(ctx context.Context) { p.touchHealthy() return nil, nil }) + if err != nil { + p.logger.Warn("metadata cache refresh failed", "error", err) + } } func (p *proxy) updateTopicNames(topics []protocol.MetadataTopic) { names := make(map[[16]byte]string, len(topics)) var zeroID [16]byte for _, topic := range topics { - if topic.TopicID != zeroID && topic.Name != "" { - names[topic.TopicID] = topic.Name + name := *topic.Topic + if topic.TopicID != zeroID && name != "" { + names[topic.TopicID] = name } } p.topicNamesMu.Lock() @@ -1621,19 +1497,19 @@ func (p *proxy) extractGroupID(apiKey int16, payload []byte) string { return "" } switch r := req.(type) { - case *protocol.JoinGroupRequest: - return r.GroupID - case *protocol.SyncGroupRequest: - return r.GroupID - case *protocol.HeartbeatRequest: - return r.GroupID - case *protocol.LeaveGroupRequest: - return r.GroupID - case *protocol.OffsetCommitRequest: - return r.GroupID - case *protocol.OffsetFetchRequest: - return r.GroupID - case *protocol.DescribeGroupsRequest: + case *kmsg.JoinGroupRequest: + return r.Group + case *kmsg.SyncGroupRequest: + return r.Group + case *kmsg.HeartbeatRequest: + return r.Group + case *kmsg.LeaveGroupRequest: + return r.Group + case *kmsg.OffsetCommitRequest: + return r.Group + case *kmsg.OffsetFetchRequest: + return r.Group + case *kmsg.DescribeGroupsRequest: if len(r.Groups) > 0 { return r.Groups[0] } @@ -1650,7 +1526,7 @@ func (p *proxy) handleFetchRouting(ctx context.Context, header *protocol.Request if err != nil { return p.forwardFetchRaw(ctx, payload, pool) } - fetchReq, ok := req.(*protocol.FetchRequest) + fetchReq, ok := req.(*kmsg.FetchRequest) if !ok || len(fetchReq.Topics) == 0 { return p.forwardFetchRaw(ctx, payload, pool) } @@ -1678,11 +1554,11 @@ func (p *proxy) forwardFetchRaw(ctx context.Context, payload []byte, pool *connP // resolveFetchTopicNames resolves topic IDs to names so the partition router // (which is keyed by name) can look up owners for v12+ requests. -func (p *proxy) resolveFetchTopicNames(ctx context.Context, req *protocol.FetchRequest) { +func (p *proxy) resolveFetchTopicNames(ctx context.Context, req *kmsg.FetchRequest) { var zeroID [16]byte for i := range req.Topics { - if req.Topics[i].Name == "" && req.Topics[i].TopicID != zeroID { - req.Topics[i].Name = p.resolveTopicID(ctx, req.Topics[i].TopicID) + if req.Topics[i].Topic == "" && req.Topics[i].TopicID != zeroID { + req.Topics[i].Topic = p.resolveTopicID(ctx, req.Topics[i].TopicID) } } } @@ -1699,12 +1575,12 @@ func fetchTopicKey(name string, id [16]byte) string { // groupFetchPartitionsByBroker groups partitions by owning broker. If include // is non-nil, only listed partitions are grouped. Unknown owners go under "" // for round-robin. -func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *protocol.FetchRequest, include map[string]map[int32]bool) map[string]*protocol.FetchRequest { - groups := make(map[string]*protocol.FetchRequest) +func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *kmsg.FetchRequest, include map[string]map[int32]bool) map[string]*kmsg.FetchRequest { + groups := make(map[string]*kmsg.FetchRequest) topicIndices := make(map[string]map[string]int) // addr -> topicKey -> index in subReq.Topics for _, topic := range req.Topics { - topicName := topic.Name + topicName := topic.Topic key := fetchTopicKey(topicName, topic.TopicID) var includeParts map[int32]bool if include != nil { @@ -1725,9 +1601,10 @@ func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *protocol. } subReq, ok := groups[addr] if !ok { - subReq = &protocol.FetchRequest{ + subReq = &kmsg.FetchRequest{ + Version: req.Version, ReplicaID: req.ReplicaID, - MaxWaitMs: req.MaxWaitMs, + MaxWaitMillis: req.MaxWaitMillis, MinBytes: req.MinBytes, MaxBytes: req.MaxBytes, IsolationLevel: req.IsolationLevel, @@ -1740,8 +1617,8 @@ func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *protocol. idx, ok := topicIndices[addr][key] if !ok { idx = len(subReq.Topics) - subReq.Topics = append(subReq.Topics, protocol.FetchTopicRequest{ - Name: topic.Name, + subReq.Topics = append(subReq.Topics, kmsg.FetchRequestTopic{ + Topic: topic.Topic, TopicID: topic.TopicID, }) topicIndices[addr][key] = idx @@ -1753,8 +1630,8 @@ func (p *proxy) groupFetchPartitionsByBroker(ctx context.Context, req *protocol. } type fetchFanOutResult struct { - subReq *protocol.FetchRequest - subResp *protocol.FetchResponse + subReq *kmsg.FetchRequest + subResp *kmsg.FetchResponse conn net.Conn target string err error @@ -1762,19 +1639,17 @@ type fetchFanOutResult struct { // forwardFetch fans out sub-requests, merges responses, and retries // NOT_LEADER_OR_FOLLOWER partitions on a different broker. -func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader, fullReq *protocol.FetchRequest, originalPayload []byte, groups map[string]*protocol.FetchRequest, pool *connPool) ([]byte, error) { +func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader, fullReq *kmsg.FetchRequest, originalPayload []byte, groups map[string]*kmsg.FetchRequest, pool *connPool) ([]byte, error) { const maxRetries = 3 - merged := &protocol.FetchResponse{ - CorrelationID: header.CorrelationID, - SessionID: fullReq.SessionID, + merged := &kmsg.FetchResponse{ + Version: header.APIVersion, + SessionID: fullReq.SessionID, } - // Keyed by fetchTopicKey to avoid collisions among unresolved v12+ topics. var failedPartitions map[string]map[int32]bool for attempt := 0; attempt < maxRetries; attempt++ { failedPartitions = nil - // Reset per attempt so retries can revisit brokers from earlier attempts. triedBackends := make(map[string]bool) subResults := p.fanOutFetch(ctx, header, groups, originalPayload, triedBackends, pool) @@ -1793,7 +1668,7 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader for _, topic := range r.subResp.Topics { for _, part := range topic.Partitions { if part.ErrorCode == protocol.NOT_LEADER_OR_FOLLOWER { - topicName := topic.Name + topicName := topic.Topic if topicName == "" { topicName = p.resolveTopicID(ctx, topic.TopicID) } @@ -1809,18 +1684,18 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader p.router.Invalidate(topicName, part.Partition) } } else { - tr := findOrAddFetchTopicResponse(merged, topic.Name, topic.TopicID) + tr := findOrAddFetchTopicResponse(merged, topic.Topic, topic.TopicID) tr.Partitions = append(tr.Partitions, part) } } } - if r.subResp.ThrottleMs > merged.ThrottleMs { - merged.ThrottleMs = r.subResp.ThrottleMs + if r.subResp.ThrottleMillis > merged.ThrottleMillis { + merged.ThrottleMillis = r.subResp.ThrottleMillis } } if len(failedPartitions) == 0 { - return protocol.EncodeFetchResponse(merged, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, merged), nil } groups = p.groupFetchPartitionsByBroker(ctx, fullReq, failedPartitions) @@ -1832,28 +1707,28 @@ func (p *proxy) forwardFetch(ctx context.Context, header *protocol.RequestHeader } for _, topic := range fullReq.Topics { - key := fetchTopicKey(topic.Name, topic.TopicID) + key := fetchTopicKey(topic.Topic, topic.TopicID) failedParts, ok := failedPartitions[key] if !ok { continue } - tr := findOrAddFetchTopicResponse(merged, topic.Name, topic.TopicID) + tr := findOrAddFetchTopicResponse(merged, topic.Topic, topic.TopicID) for _, part := range topic.Partitions { if failedParts[part.Partition] { - tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{ + tr.Partitions = append(tr.Partitions, kmsg.FetchResponseTopicPartition{ Partition: part.Partition, ErrorCode: protocol.NOT_LEADER_OR_FOLLOWER, }) } } } - return protocol.EncodeFetchResponse(merged, header.APIVersion) + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, merged), nil } // fanOutFetch borrows connections and forwards fetch sub-requests concurrently. -func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, groups map[string]*protocol.FetchRequest, originalPayload []byte, triedBackends map[string]bool, pool *connPool) []fetchFanOutResult { +func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, groups map[string]*kmsg.FetchRequest, originalPayload []byte, triedBackends map[string]bool, pool *connPool) []fetchFanOutResult { type workItem struct { - subReq *protocol.FetchRequest + subReq *kmsg.FetchRequest conn net.Conn target string payload []byte @@ -1874,13 +1749,7 @@ func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, if canUseOriginal { payload = originalPayload } else { - encoded, encErr := protocol.EncodeFetchRequest(header, subReq, header.APIVersion) - if encErr != nil { - conn.Close() - connectErrors = append(connectErrors, fetchFanOutResult{subReq: subReq, target: targetAddr, err: encErr}) - continue - } - payload = encoded + payload = encodeFetchRequest(header, subReq) } work = append(work, workItem{subReq: subReq, conn: conn, target: targetAddr, payload: payload}) } @@ -1899,7 +1768,7 @@ func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, results[i] = fetchFanOutResult{subReq: w.subReq, target: w.target, err: err} return } - subResp, parseErr := protocol.ParseFetchResponse(respBytes, header.APIVersion) + subResp, parseErr := parseFetchResponse(respBytes, header.APIVersion) if parseErr != nil { w.conn.Close() results[i] = fetchFanOutResult{subReq: w.subReq, target: w.target, err: parseErr} @@ -1913,7 +1782,7 @@ func (p *proxy) fanOutFetch(ctx context.Context, header *protocol.RequestHeader, return append(connectErrors, results...) } -func findOrAddFetchTopicResponse(resp *protocol.FetchResponse, name string, topicID [16]byte) *protocol.FetchTopicResponse { +func findOrAddFetchTopicResponse(resp *kmsg.FetchResponse, name string, topicID [16]byte) *kmsg.FetchResponseTopic { var zeroID [16]byte for i := range resp.Topics { if topicID != zeroID { @@ -1921,23 +1790,70 @@ func findOrAddFetchTopicResponse(resp *protocol.FetchResponse, name string, topi return &resp.Topics[i] } } else { - if resp.Topics[i].Name == name { + if resp.Topics[i].Topic == name { return &resp.Topics[i] } } } - resp.Topics = append(resp.Topics, protocol.FetchTopicResponse{Name: name, TopicID: topicID}) + resp.Topics = append(resp.Topics, kmsg.FetchResponseTopic{Topic: name, TopicID: topicID}) return &resp.Topics[len(resp.Topics)-1] } -func addFetchErrorForAllPartitions(resp *protocol.FetchResponse, req *protocol.FetchRequest, errorCode int16) { +func addFetchErrorForAllPartitions(resp *kmsg.FetchResponse, req *kmsg.FetchRequest, errorCode int16) { for _, topic := range req.Topics { - tr := findOrAddFetchTopicResponse(resp, topic.Name, topic.TopicID) + tr := findOrAddFetchTopicResponse(resp, topic.Topic, topic.TopicID) for _, part := range topic.Partitions { - tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{ + tr.Partitions = append(tr.Partitions, kmsg.FetchResponseTopicPartition{ Partition: part.Partition, ErrorCode: errorCode, }) } } } + +// encodeProduceRequest serializes a produce request with header into a wire frame. +func encodeProduceRequest(header *protocol.RequestHeader, req *kmsg.ProduceRequest) []byte { + formatter := kmsg.NewRequestFormatter(kmsg.FormatterClientID(clientIDStr(header.ClientID))) + return formatter.AppendRequest(nil, req, header.CorrelationID) +} + +// parseProduceResponse deserializes a produce response from wire bytes. +func parseProduceResponse(data []byte, version int16) (*kmsg.ProduceResponse, error) { + body, ok := protocol.SkipResponseHeader(protocol.APIKeyProduce, version, data) + if !ok { + return nil, fmt.Errorf("produce response too short or malformed header") + } + resp := kmsg.NewPtrProduceResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { + return nil, fmt.Errorf("decode produce response v%d: %w", version, err) + } + return resp, nil +} + +// encodeFetchRequest serializes a fetch request with header into a wire frame. +func encodeFetchRequest(header *protocol.RequestHeader, req *kmsg.FetchRequest) []byte { + formatter := kmsg.NewRequestFormatter(kmsg.FormatterClientID(clientIDStr(header.ClientID))) + return formatter.AppendRequest(nil, req, header.CorrelationID) +} + +// parseFetchResponse deserializes a fetch response from wire bytes. +func parseFetchResponse(data []byte, version int16) (*kmsg.FetchResponse, error) { + body, ok := protocol.SkipResponseHeader(protocol.APIKeyFetch, version, data) + if !ok { + return nil, fmt.Errorf("fetch response too short or malformed header") + } + resp := kmsg.NewPtrFetchResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { + return nil, fmt.Errorf("decode fetch response v%d: %w", version, err) + } + return resp, nil +} + +func clientIDStr(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/cmd/proxy/main_test.go b/cmd/proxy/main_test.go index 5039f1f2..14b5ccd0 100644 --- a/cmd/proxy/main_test.go +++ b/cmd/proxy/main_test.go @@ -25,6 +25,7 @@ import ( "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestBuildProxyMetadataResponseRewritesBrokers(t *testing.T) { @@ -34,14 +35,14 @@ func TestBuildProxyMetadataResponseRewritesBrokers(t *testing.T) { }, Topics: []protocol.MetadataTopic{ { - Name: "orders", + Topic: kmsg.StringPtr("orders"), TopicID: metadata.TopicIDForName("orders"), Partitions: []protocol.MetadataPartition{ { - PartitionIndex: 0, - LeaderID: 1, - ReplicaNodes: []int32{1, 2}, - ISRNodes: []int32{1}, + Partition: 0, + Leader: 1, + Replicas: []int32{1, 2}, + ISR: []int32{1}, }, }, }, @@ -59,14 +60,14 @@ func TestBuildProxyMetadataResponseRewritesBrokers(t *testing.T) { t.Fatalf("expected 1 topic, got %d", len(resp.Topics)) } part := resp.Topics[0].Partitions[0] - if part.LeaderID != 0 { - t.Fatalf("expected leader 0, got %d", part.LeaderID) + if part.Leader != 0 { + t.Fatalf("expected leader 0, got %d", part.Leader) } - if len(part.ReplicaNodes) != 1 || part.ReplicaNodes[0] != 0 { - t.Fatalf("expected replica nodes [0], got %+v", part.ReplicaNodes) + if len(part.Replicas) != 1 || part.Replicas[0] != 0 { + t.Fatalf("expected replica nodes [0], got %+v", part.Replicas) } - if len(part.ISRNodes) != 1 || part.ISRNodes[0] != 0 { - t.Fatalf("expected ISR nodes [0], got %+v", part.ISRNodes) + if len(part.ISR) != 1 || part.ISR[0] != 0 { + t.Fatalf("expected ISR nodes [0], got %+v", part.ISR) } } @@ -74,7 +75,7 @@ func TestBuildProxyMetadataResponsePreservesTopicErrors(t *testing.T) { meta := &metadata.ClusterMetadata{ Topics: []protocol.MetadataTopic{ { - Name: "missing", + Topic: kmsg.StringPtr("missing"), ErrorCode: protocol.UNKNOWN_TOPIC_OR_PARTITION, }, }, @@ -90,12 +91,12 @@ func TestBuildProxyMetadataResponsePreservesTopicErrors(t *testing.T) { func TestBuildNotReadyResponseProduce(t *testing.T) { payload := encodeProduceRequestV3("orders", 0) - header, _, err := protocol.ParseRequestHeader(payload) + header, body, err := protocol.ParseRequestHeader(payload) if err != nil { t.Fatalf("parse header: %v", err) } p := &proxy{} - respBytes, ok, err := p.buildNotReadyResponse(header, payload) + respBytes, ok, err := p.buildNotReadyResponse(header, body) if err != nil { t.Fatalf("build not-ready response: %v", err) } @@ -113,12 +114,12 @@ func TestBuildNotReadyResponseProduce(t *testing.T) { func TestBuildNotReadyResponseFetch(t *testing.T) { payload := encodeFetchRequestV7("orders", 0) - header, _, err := protocol.ParseRequestHeader(payload) + header, body, err := protocol.ParseRequestHeader(payload) if err != nil { t.Fatalf("parse header: %v", err) } p := &proxy{} - respBytes, ok, err := p.buildNotReadyResponse(header, payload) + respBytes, ok, err := p.buildNotReadyResponse(header, body) if err != nil { t.Fatalf("build not-ready response: %v", err) } @@ -370,12 +371,12 @@ func readString(r *bytes.Reader) (string, error) { return string(buf), nil } -func makeProduceRequest(topics map[string][]int32) *protocol.ProduceRequest { - req := &protocol.ProduceRequest{Acks: -1, TimeoutMs: 5000} +func makeProduceRequest(topics map[string][]int32) *kmsg.ProduceRequest { + req := &kmsg.ProduceRequest{Acks: -1, TimeoutMillis: 5000} for name, parts := range topics { - topic := protocol.ProduceTopic{Name: name} + topic := kmsg.ProduceRequestTopic{Topic: name} for _, p := range parts { - topic.Partitions = append(topic.Partitions, protocol.ProducePartition{ + topic.Partitions = append(topic.Partitions, kmsg.ProduceRequestTopicPartition{ Partition: p, Records: []byte{1, 2, 3}, }) @@ -406,8 +407,8 @@ func TestGroupPartitionsByBrokerNoRouter(t *testing.T) { if totalParts != 4 { t.Fatalf("expected 4 total partitions, got %d", totalParts) } - if rr.Acks != -1 || rr.TimeoutMs != 5000 { - t.Fatalf("sub-request should preserve acks/timeout: got acks=%d timeout=%d", rr.Acks, rr.TimeoutMs) + if rr.Acks != -1 || rr.TimeoutMillis != 5000 { + t.Fatalf("sub-request should preserve acks/timeout: got acks=%d timeout=%d", rr.Acks, rr.TimeoutMillis) } } @@ -430,7 +431,7 @@ func TestGroupPartitionsByBrokerNoRouterMultipleTopics(t *testing.T) { } topicNames := make(map[string]int) for _, topic := range rr.Topics { - topicNames[topic.Name] = len(topic.Partitions) + topicNames[topic.Topic] = len(topic.Partitions) } if topicNames["orders"] != 2 || topicNames["events"] != 3 { t.Fatalf("unexpected topic grouping: %v", topicNames) @@ -461,10 +462,10 @@ func TestGroupPartitionsByBrokerFiltersCorrectly(t *testing.T) { } func TestFindOrAddTopicResponse(t *testing.T) { - resp := &protocol.ProduceResponse{} + resp := &kmsg.ProduceResponse{} tr := findOrAddTopicResponse(resp, "orders") - tr.Partitions = append(tr.Partitions, protocol.ProducePartitionResponse{Partition: 0}) + tr.Partitions = append(tr.Partitions, kmsg.ProduceResponseTopicPartition{Partition: 0}) // Second call should return the same topic, not create a new one. tr2 := findOrAddTopicResponse(resp, "orders") @@ -483,7 +484,7 @@ func TestFindOrAddTopicResponse(t *testing.T) { } func TestAddErrorForAllPartitions(t *testing.T) { - resp := &protocol.ProduceResponse{} + resp := &kmsg.ProduceResponse{} req := makeProduceRequest(map[string][]int32{ "orders": {0, 1}, "events": {0}, @@ -563,7 +564,7 @@ func TestConnPoolBorrowReturn(t *testing.T) { // --- Test helpers --- -func countPartitions(req *protocol.ProduceRequest) int { +func countPartitions(req *kmsg.ProduceRequest) int { n := 0 for _, t := range req.Topics { n += len(t.Partitions) @@ -571,7 +572,7 @@ func countPartitions(req *protocol.ProduceRequest) int { return n } -func mapKeys(m map[string]*protocol.ProduceRequest) []string { +func mapKeys(m map[string]*kmsg.ProduceRequest) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) @@ -584,14 +585,14 @@ type fakeNetConn struct { closed bool } -func (c *fakeNetConn) Read(b []byte) (int, error) { return 0, nil } -func (c *fakeNetConn) Write(b []byte) (int, error) { return len(b), nil } -func (c *fakeNetConn) Close() error { c.closed = true; return nil } -func (c *fakeNetConn) LocalAddr() net.Addr { return nil } -func (c *fakeNetConn) RemoteAddr() net.Addr { return nil } -func (c *fakeNetConn) SetDeadline(time.Time) error { return nil } -func (c *fakeNetConn) SetReadDeadline(time.Time) error { return nil } -func (c *fakeNetConn) SetWriteDeadline(time.Time) error { return nil } +func (c *fakeNetConn) Read(b []byte) (int, error) { return 0, nil } +func (c *fakeNetConn) Write(b []byte) (int, error) { return len(b), nil } +func (c *fakeNetConn) Close() error { c.closed = true; return nil } +func (c *fakeNetConn) LocalAddr() net.Addr { return nil } +func (c *fakeNetConn) RemoteAddr() net.Addr { return nil } +func (c *fakeNetConn) SetDeadline(time.Time) error { return nil } +func (c *fakeNetConn) SetReadDeadline(time.Time) error { return nil } +func (c *fakeNetConn) SetWriteDeadline(time.Time) error { return nil } func TestExtractGroupID(t *testing.T) { p := &proxy{} @@ -617,14 +618,14 @@ func TestExtractGroupID(t *testing.T) { payload: func() []byte { w := &testWriter{} writeHeader(w, protocol.APIKeyJoinGroup, 2) - w.String("my-join-group") // group_id - w.Int32(30000) // session_timeout - w.Int32(10000) // rebalance_timeout - w.String("") // member_id - w.String("consumer") // protocol_type - w.Int32(1) // protocol count - w.String("range") // protocol name - w.Bytes([]byte{0x00, 0x01}) // protocol metadata + w.String("my-join-group") // group_id + w.Int32(30000) // session_timeout + w.Int32(10000) // rebalance_timeout + w.String("") // member_id + w.String("consumer") // protocol_type + w.Int32(1) // protocol count + w.String("range") // protocol name + w.Bytes([]byte{0x00, 0x01}) // protocol metadata return w.buf.Bytes() }, want: "my-join-group", @@ -753,22 +754,22 @@ func TestExtractGroupID(t *testing.T) { // --- Fetch routing tests --- -func makeFetchRequest(topics map[string][]int32) *protocol.FetchRequest { - req := &protocol.FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - SessionID: 0, - SessionEpoch: -1, +func makeFetchRequest(topics map[string][]int32) *kmsg.FetchRequest { + req := &kmsg.FetchRequest{ + ReplicaID: -1, + MaxWaitMillis: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionID: 0, + SessionEpoch: -1, } for name, parts := range topics { - topic := protocol.FetchTopicRequest{Name: name} + topic := kmsg.FetchRequestTopic{Topic: name} for _, p := range parts { - topic.Partitions = append(topic.Partitions, protocol.FetchPartitionRequest{ - Partition: p, - FetchOffset: 0, - MaxBytes: 1048576, + topic.Partitions = append(topic.Partitions, kmsg.FetchRequestTopicPartition{ + Partition: p, + FetchOffset: 0, + PartitionMaxBytes: 1048576, }) } req.Topics = append(req.Topics, topic) @@ -776,7 +777,7 @@ func makeFetchRequest(topics map[string][]int32) *protocol.FetchRequest { return req } -func countFetchPartitions(req *protocol.FetchRequest) int { +func countFetchPartitions(req *kmsg.FetchRequest) int { n := 0 for _, t := range req.Topics { n += len(t.Partitions) @@ -784,7 +785,7 @@ func countFetchPartitions(req *protocol.FetchRequest) int { return n } -func fetchMapKeys(m map[string]*protocol.FetchRequest) []string { +func fetchMapKeys(m map[string]*kmsg.FetchRequest) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) @@ -809,8 +810,8 @@ func TestGroupFetchPartitionsByBrokerNoRouter(t *testing.T) { if countFetchPartitions(rr) != 4 { t.Fatalf("expected 4 total partitions, got %d", countFetchPartitions(rr)) } - if rr.MaxWaitMs != 500 || rr.MaxBytes != 1048576 { - t.Fatalf("sub-request should preserve settings: got maxWait=%d maxBytes=%d", rr.MaxWaitMs, rr.MaxBytes) + if rr.MaxWaitMillis != 500 || rr.MaxBytes != 1048576 { + t.Fatalf("sub-request should preserve settings: got maxWait=%d maxBytes=%d", rr.MaxWaitMillis, rr.MaxBytes) } } @@ -833,7 +834,7 @@ func TestGroupFetchPartitionsByBrokerNoRouterMultipleTopics(t *testing.T) { } topicNames := make(map[string]int) for _, topic := range rr.Topics { - topicNames[topic.Name] = len(topic.Partitions) + topicNames[topic.Topic] = len(topic.Partitions) } if topicNames["orders"] != 2 || topicNames["events"] != 3 { t.Fatalf("unexpected topic grouping: %v", topicNames) @@ -864,11 +865,11 @@ func TestGroupFetchPartitionsByBrokerFiltersCorrectly(t *testing.T) { } func TestFindOrAddFetchTopicResponse(t *testing.T) { - resp := &protocol.FetchResponse{} + resp := &kmsg.FetchResponse{} topicID := [16]byte{1, 2, 3} tr := findOrAddFetchTopicResponse(resp, "orders", topicID) - tr.Partitions = append(tr.Partitions, protocol.FetchPartitionResponse{Partition: 0}) + tr.Partitions = append(tr.Partitions, kmsg.FetchResponseTopicPartition{Partition: 0}) // Same topic should return the existing entry. tr2 := findOrAddFetchTopicResponse(resp, "orders", topicID) @@ -887,7 +888,7 @@ func TestFindOrAddFetchTopicResponse(t *testing.T) { // v12+: same topicID but different name should match on topicID alone. tr4 := findOrAddFetchTopicResponse(resp, "", topicID) - tr4.Partitions = append(tr4.Partitions, protocol.FetchPartitionResponse{Partition: 1}) + tr4.Partitions = append(tr4.Partitions, kmsg.FetchResponseTopicPartition{Partition: 1}) if len(resp.Topics) != 2 { t.Fatalf("expected 2 topics after topicID-only match, got %d", len(resp.Topics)) } @@ -899,7 +900,7 @@ func TestFindOrAddFetchTopicResponse(t *testing.T) { // Name-only match (zero topicID) should work for pre-v12 topics. tr6 := findOrAddFetchTopicResponse(resp, "logs", [16]byte{}) - tr6.Partitions = append(tr6.Partitions, protocol.FetchPartitionResponse{Partition: 0}) + tr6.Partitions = append(tr6.Partitions, kmsg.FetchResponseTopicPartition{Partition: 0}) tr7 := findOrAddFetchTopicResponse(resp, "logs", [16]byte{}) if len(tr7.Partitions) != 1 { t.Fatalf("expected 1 partition for name-only topic, got %d", len(tr7.Partitions)) @@ -910,7 +911,7 @@ func TestFindOrAddFetchTopicResponse(t *testing.T) { } func TestAddFetchErrorForAllPartitions(t *testing.T) { - resp := &protocol.FetchResponse{} + resp := &kmsg.FetchResponse{} req := makeFetchRequest(map[string][]int32{ "orders": {0, 1}, "events": {0}, @@ -939,9 +940,9 @@ func TestUpdateTopicNames(t *testing.T) { topicID1 := [16]byte{1, 2, 3} topicID2 := [16]byte{4, 5, 6} topics := []protocol.MetadataTopic{ - {Name: "orders", TopicID: topicID1}, - {Name: "events", TopicID: topicID2}, - {Name: "", TopicID: [16]byte{}}, // should be skipped + {Topic: kmsg.StringPtr("orders"), TopicID: topicID1}, + {Topic: kmsg.StringPtr("events"), TopicID: topicID2}, + {Topic: kmsg.StringPtr(""), TopicID: [16]byte{}}, // should be skipped } p.updateTopicNames(topics) @@ -963,15 +964,15 @@ func TestGroupFetchPartitionsByBrokerUnresolvedTopicIDs(t *testing.T) { idA := [16]byte{1, 2, 3} idB := [16]byte{4, 5, 6} p := &proxy{} - req := &protocol.FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - SessionEpoch: -1, - Topics: []protocol.FetchTopicRequest{ - {TopicID: idA, Partitions: []protocol.FetchPartitionRequest{{Partition: 0, MaxBytes: 1048576}}}, - {TopicID: idB, Partitions: []protocol.FetchPartitionRequest{{Partition: 0, MaxBytes: 1048576}}}, + req := &kmsg.FetchRequest{ + ReplicaID: -1, + MaxWaitMillis: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionEpoch: -1, + Topics: []kmsg.FetchRequestTopic{ + {TopicID: idA, Partitions: []kmsg.FetchRequestTopicPartition{{Partition: 0, PartitionMaxBytes: 1048576}}}, + {TopicID: idB, Partitions: []kmsg.FetchRequestTopicPartition{{Partition: 0, PartitionMaxBytes: 1048576}}}, }, } groups := p.groupFetchPartitionsByBroker(context.Background(), req, nil) @@ -993,19 +994,19 @@ func TestGroupFetchPartitionsByBrokerUnresolvedFilter(t *testing.T) { idA := [16]byte{1, 2, 3} idB := [16]byte{4, 5, 6} p := &proxy{} - req := &protocol.FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - SessionEpoch: -1, - Topics: []protocol.FetchTopicRequest{ - {TopicID: idA, Partitions: []protocol.FetchPartitionRequest{ - {Partition: 0, MaxBytes: 1048576}, - {Partition: 1, MaxBytes: 1048576}, + req := &kmsg.FetchRequest{ + ReplicaID: -1, + MaxWaitMillis: 500, + MinBytes: 1, + MaxBytes: 1048576, + SessionEpoch: -1, + Topics: []kmsg.FetchRequestTopic{ + {TopicID: idA, Partitions: []kmsg.FetchRequestTopicPartition{ + {Partition: 0, PartitionMaxBytes: 1048576}, + {Partition: 1, PartitionMaxBytes: 1048576}, }}, - {TopicID: idB, Partitions: []protocol.FetchPartitionRequest{ - {Partition: 0, MaxBytes: 1048576}, + {TopicID: idB, Partitions: []kmsg.FetchRequestTopicPartition{ + {Partition: 0, PartitionMaxBytes: 1048576}, }}, }, } @@ -1034,18 +1035,18 @@ func TestResolveFetchTopicNames(t *testing.T) { p := &proxy{ topicNames: map[[16]byte]string{topicID: "orders"}, } - req := &protocol.FetchRequest{ - Topics: []protocol.FetchTopicRequest{ + req := &kmsg.FetchRequest{ + Topics: []kmsg.FetchRequestTopic{ {TopicID: topicID}, // name not set, should be resolved - {Name: "events"}, // already has name, should be left alone + {Topic: "events"}, // already has name, should be left alone }, } p.resolveFetchTopicNames(context.Background(), req) - if req.Topics[0].Name != "orders" { - t.Fatalf("topic[0] name: got %q, want %q", req.Topics[0].Name, "orders") + if req.Topics[0].Topic != "orders" { + t.Fatalf("topic[0] name: got %q, want %q", req.Topics[0].Topic, "orders") } - if req.Topics[1].Name != "events" { - t.Fatalf("topic[1] name: got %q, want %q", req.Topics[1].Name, "events") + if req.Topics[1].Topic != "events" { + t.Fatalf("topic[1] name: got %q, want %q", req.Topics[1].Topic, "events") } } diff --git a/internal/console/server.go b/internal/console/server.go index 6db02a24..3e3e6a3a 100644 --- a/internal/console/server.go +++ b/internal/console/server.go @@ -359,14 +359,14 @@ func statusFromMetadata(meta *metadata.ClusterMetadata, metrics *MetricsSnapshot partitions := make([]partitionDetails, 0, len(topic.Partitions)) for _, part := range topic.Partitions { partitions = append(partitions, partitionDetails{ - ID: part.PartitionIndex, - Leader: part.LeaderID, - Replicas: len(part.ReplicaNodes), - ISR: len(part.ISRNodes), + ID: part.Partition, + Leader: part.Leader, + Replicas: len(part.Replicas), + ISR: len(part.ISR), }) } resp.Topics = append(resp.Topics, topicInfo{ - Name: topic.Name, + Name: *topic.Topic, Partitions: len(topic.Partitions), State: state, PartitionsDetails: partitions, diff --git a/internal/console/server_test.go b/internal/console/server_test.go index 6527a476..f3697d3d 100644 --- a/internal/console/server_test.go +++ b/internal/console/server_test.go @@ -27,6 +27,7 @@ import ( "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestConsoleStatusEndpoint(t *testing.T) { @@ -66,7 +67,7 @@ func TestStatusFromMetadataInjectsBrokerRuntime(t *testing.T) { {NodeID: 1, Host: "broker-1"}, }, Topics: []protocol.MetadataTopic{ - {Name: "orders"}, + {Topic: kmsg.StringPtr("orders")}, }, } snap := &MetricsSnapshot{ diff --git a/internal/mcpserver/tools.go b/internal/mcpserver/tools.go index 75cb511e..8d1c77fa 100644 --- a/internal/mcpserver/tools.go +++ b/internal/mcpserver/tools.go @@ -394,13 +394,13 @@ func fetchOffsetsHandler(opts Options) mcp.ToolHandlerFor[FetchOffsetsInput, Fet out := FetchOffsetsOutput{GroupID: input.GroupID} for _, topic := range meta.Topics { for _, partition := range topic.Partitions { - offset, metaText, err := store.FetchConsumerOffset(ctx, input.GroupID, topic.Name, partition.PartitionIndex) + offset, metaText, err := store.FetchConsumerOffset(ctx, input.GroupID, *topic.Topic, partition.Partition) if err != nil { return nil, FetchOffsetsOutput{}, err } out.Offsets = append(out.Offsets, OffsetDetails{ - Topic: topic.Name, - Partition: partition.PartitionIndex, + Topic: *topic.Topic, + Partition: partition.Partition, Offset: offset, Metadata: metaText, }) @@ -430,7 +430,7 @@ func describeConfigsHandler(opts Options) mcp.ToolHandlerFor[TopicConfigInput, T } topics = make([]string, 0, len(meta.Topics)) for _, topic := range meta.Topics { - topics = append(topics, topic.Name) + topics = append(topics, *topic.Topic) } } out := make([]TopicConfigOutput, 0, len(topics)) @@ -457,7 +457,7 @@ func summarizeTopics(topics []protocol.MetadataTopic) []TopicSummary { out := make([]TopicSummary, 0, len(topics)) for _, topic := range topics { out = append(out, TopicSummary{ - Name: topic.Name, + Name: *topic.Topic, PartitionCount: len(topic.Partitions), ErrorCode: topic.ErrorCode, }) @@ -470,17 +470,17 @@ func toTopicDetail(topic protocol.MetadataTopic) TopicDetail { partitions := make([]PartitionDetails, 0, len(topic.Partitions)) for _, partition := range topic.Partitions { partitions = append(partitions, PartitionDetails{ - Partition: partition.PartitionIndex, - LeaderID: partition.LeaderID, + Partition: partition.Partition, + LeaderID: partition.Leader, LeaderEpoch: partition.LeaderEpoch, - ReplicaNodes: copyInt32Slice(partition.ReplicaNodes), - ISRNodes: copyInt32Slice(partition.ISRNodes), + ReplicaNodes: copyInt32Slice(partition.Replicas), + ISRNodes: copyInt32Slice(partition.ISR), OfflineReplicas: copyInt32Slice(partition.OfflineReplicas), ErrorCode: partition.ErrorCode, }) } return TopicDetail{ - Name: topic.Name, + Name: *topic.Topic, ErrorCode: topic.ErrorCode, Partitions: partitions, } diff --git a/internal/mcpserver/tools_test.go b/internal/mcpserver/tools_test.go index d32b9145..cd7230d9 100644 --- a/internal/mcpserver/tools_test.go +++ b/internal/mcpserver/tools_test.go @@ -21,6 +21,7 @@ import ( metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestRequireStore(t *testing.T) { @@ -35,8 +36,8 @@ func TestRequireStore(t *testing.T) { func TestSummarizeTopicsSorts(t *testing.T) { topics := []protocol.MetadataTopic{ - {Name: "b", Partitions: []protocol.MetadataPartition{{PartitionIndex: 0}}}, - {Name: "a", Partitions: []protocol.MetadataPartition{{PartitionIndex: 0}, {PartitionIndex: 1}}}, + {Topic: kmsg.StringPtr("b"), Partitions: []protocol.MetadataPartition{{Partition: 0}}}, + {Topic: kmsg.StringPtr("a"), Partitions: []protocol.MetadataPartition{{Partition: 0}, {Partition: 1}}}, } out := summarizeTopics(topics) if len(out) != 2 || out[0].Name != "a" || out[1].Name != "b" { diff --git a/pkg/broker/coordinator.go b/pkg/broker/coordinator.go index 15588faa..a77908e8 100644 --- a/pkg/broker/coordinator.go +++ b/pkg/broker/coordinator.go @@ -25,6 +25,8 @@ import ( "sync" "time" + "github.com/twmb/franz-go/pkg/kmsg" + metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" @@ -112,20 +114,18 @@ func NewGroupCoordinator(store metadata.Store, broker protocol.MetadataBroker, c return c } -func (c *GroupCoordinator) FindCoordinatorResponse(correlationID int32, errorCode int16) *protocol.FindCoordinatorResponse { - return &protocol.FindCoordinatorResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: errorCode, - NodeID: c.broker.NodeID, - Host: c.broker.Host, - Port: c.broker.Port, - } +func (c *GroupCoordinator) FindCoordinatorResponse(errorCode int16) *kmsg.FindCoordinatorResponse { + resp := kmsg.NewPtrFindCoordinatorResponse() + resp.ErrorCode = errorCode + resp.NodeID = c.broker.NodeID + resp.Host = c.broker.Host + resp.Port = c.broker.Port + return resp } -func (c *GroupCoordinator) JoinGroup(ctx context.Context, req *protocol.JoinGroupRequest, correlationID int32) (*protocol.JoinGroupResponse, error) { +func (c *GroupCoordinator) JoinGroup(ctx context.Context, req *kmsg.JoinGroupRequest) (*kmsg.JoinGroupResponse, error) { c.mu.Lock() - state, err := c.ensureGroup(ctx, req.GroupID) + state, err := c.ensureGroup(ctx, req.Group) if err != nil { c.mu.Unlock() return nil, err @@ -135,7 +135,7 @@ func (c *GroupCoordinator) JoinGroup(ctx context.Context, req *protocol.JoinGrou state.protocolName = req.Protocols[0].Name } - timeout := time.Duration(req.RebalanceTimeoutMs) * time.Millisecond + timeout := time.Duration(req.RebalanceTimeoutMillis) * time.Millisecond if timeout <= 0 { timeout = defaultRebalanceTimeout } @@ -143,14 +143,14 @@ func (c *GroupCoordinator) JoinGroup(ctx context.Context, req *protocol.JoinGrou memberID := req.MemberID member, exists := state.members[memberID] if memberID == "" || member == nil { - memberID = c.newMemberID(req.GroupID) + memberID = c.newMemberID(req.Group) member = &memberState{} state.members[memberID] = member exists = false } - if req.SessionTimeoutMs > 0 { - member.sessionTimeout = time.Duration(req.SessionTimeoutMs) * time.Millisecond + if req.SessionTimeoutMillis > 0 { + member.sessionTimeout = time.Duration(req.SessionTimeoutMillis) * time.Millisecond } else if member.sessionTimeout == 0 { member.sessionTimeout = defaultSessionTimeout } @@ -178,81 +178,64 @@ func (c *GroupCoordinator) JoinGroup(ctx context.Context, req *protocol.JoinGrou ready = state.completeIfReady() } - var members []protocol.JoinGroupMember + resp := kmsg.NewPtrJoinGroupResponse() + resp.Generation = state.generationID + resp.Protocol = &state.protocolName + resp.LeaderID = state.leaderID + resp.MemberID = memberID + resp.ErrorCode = protocol.REBALANCE_IN_PROGRESS + if ready && memberID == state.leaderID { - members = c.encodeMemberSubscriptions(state) + resp.Members = c.encodeMemberSubscriptions(state) } else { - members = []protocol.JoinGroupMember{} + resp.Members = []kmsg.JoinGroupResponseMember{} } - resp := &protocol.JoinGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - GenerationID: state.generationID, - ProtocolName: state.protocolName, - LeaderID: state.leaderID, - MemberID: memberID, - Members: members, - ErrorCode: protocol.REBALANCE_IN_PROGRESS, - } if ready { resp.ErrorCode = protocol.NONE } - if err := c.persistGroupLocked(ctx, req.GroupID, state); err != nil { + if err := c.persistGroupLocked(ctx, req.Group, state); err != nil { resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } c.mu.Unlock() return resp, nil } -func (c *GroupCoordinator) SyncGroup(ctx context.Context, req *protocol.SyncGroupRequest, correlationID int32) (*protocol.SyncGroupResponse, error) { +func (c *GroupCoordinator) SyncGroup(ctx context.Context, req *kmsg.SyncGroupRequest) (*kmsg.SyncGroupResponse, error) { c.mu.Lock() - state, err := c.loadGroupIfMissing(ctx, req.GroupID) + state, err := c.loadGroupIfMissing(ctx, req.Group) if err != nil { c.mu.Unlock() return nil, err } + + mkErrResp := func(code int16) *kmsg.SyncGroupResponse { + r := kmsg.NewPtrSyncGroupResponse() + r.ErrorCode = code + return r + } + if state == nil { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - }, nil + return mkErrResp(protocol.UNKNOWN_MEMBER_ID), nil } - if req.GenerationID != state.generationID { + if req.Generation != state.generationID { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.ILLEGAL_GENERATION, - }, nil + return mkErrResp(protocol.ILLEGAL_GENERATION), nil } if _, ok := state.members[req.MemberID]; !ok { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - }, nil + return mkErrResp(protocol.UNKNOWN_MEMBER_ID), nil } if state.state == groupStatePreparingRebalance { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.REBALANCE_IN_PROGRESS, - }, nil + return mkErrResp(protocol.REBALANCE_IN_PROGRESS), nil } if state.state == groupStateCompletingRebalance && len(state.assignments) == 0 { if req.MemberID != state.leaderID { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.REBALANCE_IN_PROGRESS, - }, nil + return mkErrResp(protocol.REBALANCE_IN_PROGRESS), nil } state.assignments = c.assignPartitions(ctx, state) state.markStable() @@ -261,129 +244,95 @@ func (c *GroupCoordinator) SyncGroup(ctx context.Context, req *protocol.SyncGrou assignments := state.assignments[req.MemberID] if assignments == nil && state.state != groupStateStable { c.mu.Unlock() - return &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.REBALANCE_IN_PROGRESS, - }, nil + return mkErrResp(protocol.REBALANCE_IN_PROGRESS), nil } - var protocolTypePtr *string + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = protocol.NONE if state.protocolType != "" { - pt := state.protocolType - protocolTypePtr = &pt + resp.ProtocolType = &state.protocolType } - var protocolNamePtr *string if state.protocolName != "" { - pn := state.protocolName - protocolNamePtr = &pn - } - resp := &protocol.SyncGroupResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.NONE, - ProtocolType: protocolTypePtr, - ProtocolName: protocolNamePtr, - Assignment: encodeAssignment(assignments), - } - if err := c.persistGroupLocked(ctx, req.GroupID, state); err != nil { + resp.Protocol = &state.protocolName + } + resp.MemberAssignment = encodeAssignment(assignments) + + if err := c.persistGroupLocked(ctx, req.Group, state); err != nil { resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } c.mu.Unlock() return resp, nil } -func (c *GroupCoordinator) Heartbeat(ctx context.Context, req *protocol.HeartbeatRequest, correlationID int32) *protocol.HeartbeatResponse { +func (c *GroupCoordinator) Heartbeat(ctx context.Context, req *kmsg.HeartbeatRequest) *kmsg.HeartbeatResponse { c.mu.Lock() - state, err := c.loadGroupIfMissing(ctx, req.GroupID) + state, err := c.loadGroupIfMissing(ctx, req.Group) + + mkResp := func(code int16) *kmsg.HeartbeatResponse { + r := kmsg.NewPtrHeartbeatResponse() + r.ErrorCode = code + return r + } + if err != nil { c.mu.Unlock() - return &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - } + return mkResp(protocol.UNKNOWN_SERVER_ERROR) } if state == nil { c.mu.Unlock() - return &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - } + return mkResp(protocol.UNKNOWN_MEMBER_ID) } member := state.members[req.MemberID] if member == nil { c.mu.Unlock() - return &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - } + return mkResp(protocol.UNKNOWN_MEMBER_ID) } - if req.GenerationID != state.generationID { + if req.Generation != state.generationID { c.mu.Unlock() - return &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.ILLEGAL_GENERATION, - } + return mkResp(protocol.ILLEGAL_GENERATION) } if state.state != groupStateStable { c.mu.Unlock() - return &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.REBALANCE_IN_PROGRESS, - } + return mkResp(protocol.REBALANCE_IN_PROGRESS) } member.lastHeartbeat = time.Now() - resp := &protocol.HeartbeatResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.NONE, - } - if err := c.persistGroupLocked(ctx, req.GroupID, state); err != nil { + resp := mkResp(protocol.NONE) + if err := c.persistGroupLocked(ctx, req.Group, state); err != nil { resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } c.mu.Unlock() return resp } -func (c *GroupCoordinator) LeaveGroup(ctx context.Context, req *protocol.LeaveGroupRequest, correlationID int32) *protocol.LeaveGroupResponse { +func (c *GroupCoordinator) LeaveGroup(ctx context.Context, req *kmsg.LeaveGroupRequest) *kmsg.LeaveGroupResponse { c.mu.Lock() - state, err := c.loadGroupIfMissing(ctx, req.GroupID) + state, err := c.loadGroupIfMissing(ctx, req.Group) + + mkResp := func(code int16) *kmsg.LeaveGroupResponse { + r := kmsg.NewPtrLeaveGroupResponse() + r.ErrorCode = code + return r + } + if err != nil { c.mu.Unlock() - return &protocol.LeaveGroupResponse{ - CorrelationID: correlationID, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - } + return mkResp(protocol.UNKNOWN_SERVER_ERROR) } if state == nil { c.mu.Unlock() - return &protocol.LeaveGroupResponse{ - CorrelationID: correlationID, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - } + return mkResp(protocol.UNKNOWN_MEMBER_ID) } if _, ok := state.members[req.MemberID]; !ok { c.mu.Unlock() - return &protocol.LeaveGroupResponse{ - CorrelationID: correlationID, - ErrorCode: protocol.UNKNOWN_MEMBER_ID, - } + return mkResp(protocol.UNKNOWN_MEMBER_ID) } delete(state.members, req.MemberID) delete(state.assignments, req.MemberID) if len(state.members) == 0 { - delete(c.groups, req.GroupID) - resp := &protocol.LeaveGroupResponse{ - CorrelationID: correlationID, - ErrorCode: protocol.NONE, - } - if err := c.persistGroupLocked(ctx, req.GroupID, nil); err != nil { + delete(c.groups, req.Group) + resp := mkResp(protocol.NONE) + if err := c.persistGroupLocked(ctx, req.Group, nil); err != nil { resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } c.mu.Unlock() @@ -393,20 +342,17 @@ func (c *GroupCoordinator) LeaveGroup(ctx context.Context, req *protocol.LeaveGr state.leaderID = "" } state.startRebalance(0) - resp := &protocol.LeaveGroupResponse{ - CorrelationID: correlationID, - ErrorCode: protocol.NONE, - } - if err := c.persistGroupLocked(ctx, req.GroupID, state); err != nil { + resp := mkResp(protocol.NONE) + if err := c.persistGroupLocked(ctx, req.Group, state); err != nil { resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR } c.mu.Unlock() return resp } -func (c *GroupCoordinator) OffsetCommit(ctx context.Context, req *protocol.OffsetCommitRequest, correlationID int32) (*protocol.OffsetCommitResponse, error) { +func (c *GroupCoordinator) OffsetCommit(ctx context.Context, req *kmsg.OffsetCommitRequest) (*kmsg.OffsetCommitResponse, error) { c.mu.Lock() - state, err := c.loadGroupIfMissing(ctx, req.GroupID) + state, err := c.loadGroupIfMissing(ctx, req.Group) if err != nil { c.mu.Unlock() return nil, err @@ -417,171 +363,164 @@ func (c *GroupCoordinator) OffsetCommit(ctx context.Context, req *protocol.Offse groupErr = protocol.UNKNOWN_MEMBER_ID } else if _, ok := state.members[req.MemberID]; !ok { groupErr = protocol.UNKNOWN_MEMBER_ID - } else if req.GenerationID != state.generationID { + } else if req.Generation != state.generationID { groupErr = protocol.ILLEGAL_GENERATION } c.mu.Unlock() - results := make([]protocol.OffsetCommitTopicResponse, 0, len(req.Topics)) + resp := kmsg.NewPtrOffsetCommitResponse() + resp.Topics = make([]kmsg.OffsetCommitResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - partitions := make([]protocol.OffsetCommitPartitionResponse, 0, len(topic.Partitions)) + topicResp := kmsg.NewOffsetCommitResponseTopic() + topicResp.Topic = topic.Topic + topicResp.Partitions = make([]kmsg.OffsetCommitResponseTopicPartition, 0, len(topic.Partitions)) for _, part := range topic.Partitions { code := groupErr if code == protocol.NONE { - if err := c.store.CommitConsumerOffset(ctx, req.GroupID, topic.Name, part.Partition, part.Offset, part.Metadata); err != nil { + meta := "" + if part.Metadata != nil { + meta = *part.Metadata + } + if err := c.store.CommitConsumerOffset(ctx, req.Group, topic.Topic, part.Partition, part.Offset, meta); err != nil { code = protocol.UNKNOWN_SERVER_ERROR } } - partitions = append(partitions, protocol.OffsetCommitPartitionResponse{ - Partition: part.Partition, - ErrorCode: code, - }) + partResp := kmsg.NewOffsetCommitResponseTopicPartition() + partResp.Partition = part.Partition + partResp.ErrorCode = code + topicResp.Partitions = append(topicResp.Partitions, partResp) } - results = append(results, protocol.OffsetCommitTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return &protocol.OffsetCommitResponse{ - CorrelationID: correlationID, - Topics: results, - }, nil + resp.Topics = append(resp.Topics, topicResp) + } + return resp, nil } -func (c *GroupCoordinator) OffsetFetch(ctx context.Context, req *protocol.OffsetFetchRequest, correlationID int32) (*protocol.OffsetFetchResponse, error) { - topicResponses := make([]protocol.OffsetFetchTopicResponse, 0, len(req.Topics)) +func (c *GroupCoordinator) OffsetFetch(ctx context.Context, req *kmsg.OffsetFetchRequest) (*kmsg.OffsetFetchResponse, error) { + resp := kmsg.NewPtrOffsetFetchResponse() + resp.ErrorCode = protocol.NONE + resp.Topics = make([]kmsg.OffsetFetchResponseTopic, 0, len(req.Topics)) for _, topic := range req.Topics { - partitions := make([]protocol.OffsetFetchPartitionResponse, 0, len(topic.Partitions)) - for _, part := range topic.Partitions { - offset, metadataStr, err := c.store.FetchConsumerOffset(ctx, req.GroupID, topic.Name, part.Partition) - code := protocol.NONE + topicResp := kmsg.NewOffsetFetchResponseTopic() + topicResp.Topic = topic.Topic + topicResp.Partitions = make([]kmsg.OffsetFetchResponseTopicPartition, 0, len(topic.Partitions)) + for _, partID := range topic.Partitions { + offset, metadataStr, err := c.store.FetchConsumerOffset(ctx, req.Group, topic.Topic, partID) + code := int16(protocol.NONE) if err != nil { code = protocol.UNKNOWN_SERVER_ERROR } - leaderEpoch := int32(-1) - metaVal := metadataStr - partitions = append(partitions, protocol.OffsetFetchPartitionResponse{ - Partition: part.Partition, - Offset: offset, - LeaderEpoch: leaderEpoch, - Metadata: &metaVal, - ErrorCode: code, - }) + partResp := kmsg.NewOffsetFetchResponseTopicPartition() + partResp.Partition = partID + partResp.Offset = offset + partResp.LeaderEpoch = -1 + partResp.Metadata = &metadataStr + partResp.ErrorCode = code + topicResp.Partitions = append(topicResp.Partitions, partResp) } - topicResponses = append(topicResponses, protocol.OffsetFetchTopicResponse{ - Name: topic.Name, - Partitions: partitions, - }) - } - return &protocol.OffsetFetchResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - Topics: topicResponses, - ErrorCode: protocol.NONE, - }, nil + resp.Topics = append(resp.Topics, topicResp) + } + return resp, nil } -func (c *GroupCoordinator) DescribeGroups(ctx context.Context, req *protocol.DescribeGroupsRequest, correlationID int32) (*protocol.DescribeGroupsResponse, error) { - groups := make([]protocol.DescribeGroupsResponseGroup, 0, len(req.Groups)) +func (c *GroupCoordinator) DescribeGroups(ctx context.Context, req *kmsg.DescribeGroupsRequest) (*kmsg.DescribeGroupsResponse, error) { + resp := kmsg.NewPtrDescribeGroupsResponse() + resp.Groups = make([]kmsg.DescribeGroupsResponseGroup, 0, len(req.Groups)) for _, groupID := range req.Groups { group, err := c.store.FetchConsumerGroup(ctx, groupID) if err != nil { - groups = append(groups, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - GroupID: groupID, - }) + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + g.Group = groupID + resp.Groups = append(resp.Groups, g) continue } if group == nil { - groups = append(groups, protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.GROUP_ID_NOT_FOUND, - GroupID: groupID, - }) + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.GROUP_ID_NOT_FOUND + g.Group = groupID + resp.Groups = append(resp.Groups, g) continue } - groups = append(groups, buildDescribeGroup(group, req.IncludeAuthorizedOperations)) + resp.Groups = append(resp.Groups, buildDescribeGroup(group, req.IncludeAuthorizedOperations)) } - return &protocol.DescribeGroupsResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - Groups: groups, - }, nil + return resp, nil } -func (c *GroupCoordinator) ListGroups(ctx context.Context, req *protocol.ListGroupsRequest, correlationID int32) (*protocol.ListGroupsResponse, error) { +func (c *GroupCoordinator) ListGroups(ctx context.Context, req *kmsg.ListGroupsRequest) (*kmsg.ListGroupsResponse, error) { groups, err := c.store.ListConsumerGroups(ctx) if err != nil { - return &protocol.ListGroupsResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.UNKNOWN_SERVER_ERROR, - Groups: nil, - }, nil - } - entries := make([]protocol.ListGroupsResponseGroup, 0, len(groups)) + resp := kmsg.NewPtrListGroupsResponse() + resp.ErrorCode = protocol.UNKNOWN_SERVER_ERROR + return resp, nil + } + + var statesFilter, typesFilter []string + if len(req.StatesFilter) > 0 { + statesFilter = req.StatesFilter + } + if len(req.TypesFilter) > 0 { + typesFilter = req.TypesFilter + } + + resp := kmsg.NewPtrListGroupsResponse() + resp.ErrorCode = protocol.NONE + resp.Groups = make([]kmsg.ListGroupsResponseGroup, 0, len(groups)) for _, group := range groups { state := kafkaGroupState(group.GetState()) - if !matchesGroupStateFilter(state, req.StatesFilter) { + if !matchesGroupStateFilter(state, statesFilter) { continue } groupType := "classic" - if !matchesGroupTypeFilter(groupType, req.TypesFilter) { + if !matchesGroupTypeFilter(groupType, typesFilter) { continue } protocolType := group.GetProtocolType() if protocolType == "" { protocolType = "consumer" } - entries = append(entries, protocol.ListGroupsResponseGroup{ - GroupID: group.GetGroupId(), - ProtocolType: protocolType, - GroupState: state, - GroupType: groupType, - }) - } - return &protocol.ListGroupsResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - ErrorCode: protocol.NONE, - Groups: entries, - }, nil + entry := kmsg.NewListGroupsResponseGroup() + entry.Group = group.GetGroupId() + entry.ProtocolType = protocolType + entry.GroupState = state + entry.GroupType = groupType + resp.Groups = append(resp.Groups, entry) + } + return resp, nil } -func (c *GroupCoordinator) DeleteGroups(ctx context.Context, req *protocol.DeleteGroupsRequest, correlationID int32) (*protocol.DeleteGroupsResponse, error) { - results := make([]protocol.DeleteGroupsResponseGroup, 0, len(req.Groups)) +func (c *GroupCoordinator) DeleteGroups(ctx context.Context, req *kmsg.DeleteGroupsRequest) (*kmsg.DeleteGroupsResponse, error) { + resp := kmsg.NewPtrDeleteGroupsResponse() + resp.Groups = make([]kmsg.DeleteGroupsResponseGroup, 0, len(req.Groups)) for _, groupID := range req.Groups { - result := protocol.DeleteGroupsResponseGroup{Group: groupID} + result := kmsg.NewDeleteGroupsResponseGroup() + result.Group = groupID if strings.TrimSpace(groupID) == "" { result.ErrorCode = protocol.INVALID_REQUEST - results = append(results, result) + resp.Groups = append(resp.Groups, result) continue } group, err := c.store.FetchConsumerGroup(ctx, groupID) if err != nil { result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR - results = append(results, result) + resp.Groups = append(resp.Groups, result) continue } if group == nil { result.ErrorCode = protocol.GROUP_ID_NOT_FOUND c.deleteGroupState(groupID) - results = append(results, result) + resp.Groups = append(resp.Groups, result) continue } if err := c.store.DeleteConsumerGroup(ctx, groupID); err != nil { result.ErrorCode = protocol.UNKNOWN_SERVER_ERROR - results = append(results, result) + resp.Groups = append(resp.Groups, result) continue } c.deleteGroupState(groupID) result.ErrorCode = protocol.NONE - results = append(results, result) + resp.Groups = append(resp.Groups, result) } - return &protocol.DeleteGroupsResponse{ - CorrelationID: correlationID, - ThrottleMs: 0, - Groups: results, - }, nil + return resp, nil } func (c *GroupCoordinator) deleteGroupState(groupID string) { @@ -672,45 +611,44 @@ func buildConsumerGroup(groupID string, state *groupState) *metadatapb.ConsumerG return group } -func buildDescribeGroup(group *metadatapb.ConsumerGroup, includeAuthorized bool) protocol.DescribeGroupsResponseGroup { +func buildDescribeGroup(group *metadatapb.ConsumerGroup, includeAuthorized bool) kmsg.DescribeGroupsResponseGroup { protocolType := group.GetProtocolType() if protocolType == "" { protocolType = "consumer" } protocolName := group.GetProtocol() - members := make([]protocol.DescribeGroupsResponseGroupMember, 0, len(group.Members)) + memberIDs := make([]string, 0, len(group.Members)) for memberID := range group.Members { memberIDs = append(memberIDs, memberID) } sort.Strings(memberIDs) + + members := make([]kmsg.DescribeGroupsResponseGroupMember, 0, len(group.Members)) for _, memberID := range memberIDs { member := group.Members[memberID] if member == nil { continue } - members = append(members, protocol.DescribeGroupsResponseGroupMember{ - MemberID: memberID, - InstanceID: nil, - ClientID: member.ClientId, - ClientHost: member.ClientHost, - ProtocolMetadata: nil, - MemberAssignment: nil, - }) + m := kmsg.NewDescribeGroupsResponseGroupMember() + m.MemberID = memberID + m.ClientID = member.ClientId + m.ClientHost = member.ClientHost + members = append(members, m) } authorizedOps := int32(-2147483648) if includeAuthorized { authorizedOps = 0 } - return protocol.DescribeGroupsResponseGroup{ - ErrorCode: protocol.NONE, - GroupID: group.GetGroupId(), - State: kafkaGroupState(group.GetState()), - ProtocolType: protocolType, - Protocol: protocolName, - Members: members, - AuthorizedOperations: authorizedOps, - } + g := kmsg.NewDescribeGroupsResponseGroup() + g.ErrorCode = protocol.NONE + g.Group = group.GetGroupId() + g.State = kafkaGroupState(group.GetState()) + g.ProtocolType = protocolType + g.Protocol = protocolName + g.Members = members + g.AuthorizedOperations = authorizedOps + return g } func restoreGroupState(group *metadatapb.ConsumerGroup) *groupState { @@ -844,7 +782,7 @@ func (c *GroupCoordinator) newMemberID(group string) string { return fmt.Sprintf("%s-%d", group, rand.Int63()) } -func (c *GroupCoordinator) parseSubscriptionTopics(protocols []protocol.JoinGroupProtocol) []string { +func (c *GroupCoordinator) parseSubscriptionTopics(protocols []kmsg.JoinGroupRequestProtocol) []string { if len(protocols) == 0 { return nil } @@ -893,15 +831,15 @@ func (c *GroupCoordinator) encodeSubscription(topics []string) []byte { return buf } -func (c *GroupCoordinator) encodeMemberSubscriptions(state *groupState) []protocol.JoinGroupMember { +func (c *GroupCoordinator) encodeMemberSubscriptions(state *groupState) []kmsg.JoinGroupResponseMember { ids := state.sortedMembers() - members := make([]protocol.JoinGroupMember, 0, len(ids)) + members := make([]kmsg.JoinGroupResponseMember, 0, len(ids)) for _, id := range ids { member := state.members[id] - members = append(members, protocol.JoinGroupMember{ - MemberID: id, - Metadata: c.encodeSubscription(member.topics), - }) + m := kmsg.NewJoinGroupResponseMember() + m.MemberID = id + m.ProtocolMetadata = c.encodeSubscription(member.topics) + members = append(members, m) } return members } @@ -998,13 +936,13 @@ func (c *GroupCoordinator) collectTopicPartitions(ctx context.Context, state *gr for _, topic := range meta.Topics { partitions := make([]int32, 0, len(topic.Partitions)) for _, p := range topic.Partitions { - partitions = append(partitions, p.PartitionIndex) + partitions = append(partitions, p.Partition) } if len(partitions) == 0 { partitions = []int32{0} } sort.Slice(partitions, func(i, j int) bool { return partitions[i] < partitions[j] }) - result[topic.Name] = partitions + result[*topic.Topic] = partitions } return result } diff --git a/pkg/broker/coordinator_test.go b/pkg/broker/coordinator_test.go index 4037adf8..387cee47 100644 --- a/pkg/broker/coordinator_test.go +++ b/pkg/broker/coordinator_test.go @@ -17,9 +17,12 @@ package broker import ( "context" + "encoding/binary" "testing" "time" + "github.com/twmb/franz-go/pkg/kmsg" + metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" @@ -85,20 +88,20 @@ func TestCoordinatorListDescribeGroups(t *testing.T) { } coord := NewGroupCoordinator(store, protocol.MetadataBroker{NodeID: 1, Host: "127.0.0.1", Port: 9092}, nil) - listResp, err := coord.ListGroups(context.Background(), &protocol.ListGroupsRequest{ - StatesFilter: []string{"Stable"}, - TypesFilter: []string{"classic"}, - }, 1) + listReq := kmsg.NewPtrListGroupsRequest() + listReq.StatesFilter = []string{"Stable"} + listReq.TypesFilter = []string{"classic"} + listResp, err := coord.ListGroups(context.Background(), listReq) if err != nil { t.Fatalf("ListGroups: %v", err) } - if len(listResp.Groups) != 1 || listResp.Groups[0].GroupID != "group-1" { + if len(listResp.Groups) != 1 || listResp.Groups[0].Group != "group-1" { t.Fatalf("unexpected list response: %#v", listResp.Groups) } - describeResp, err := coord.DescribeGroups(context.Background(), &protocol.DescribeGroupsRequest{ - Groups: []string{"group-1"}, - }, 2) + describeReq := kmsg.NewPtrDescribeGroupsRequest() + describeReq.Groups = []string{"group-1"} + describeResp, err := coord.DescribeGroups(context.Background(), describeReq) if err != nil { t.Fatalf("DescribeGroups: %v", err) } @@ -107,6 +110,263 @@ func TestCoordinatorListDescribeGroups(t *testing.T) { } } +func TestSubscriptionRoundTrip(t *testing.T) { + store := metadata.NewInMemoryStore(metadata.ClusterMetadata{}) + coord := NewGroupCoordinator(store, protocol.MetadataBroker{}, nil) + + topics := []string{"orders", "events", "metrics"} + encoded := coord.encodeSubscription(topics) + parsed := coord.parseSubscriptionTopics([]kmsg.JoinGroupRequestProtocol{ + {Name: "range", Metadata: encoded}, + }) + if len(parsed) != len(topics) { + t.Fatalf("topic count: got %d want %d", len(parsed), len(topics)) + } + for i, topic := range topics { + if parsed[i] != topic { + t.Fatalf("topic[%d]: got %q want %q", i, parsed[i], topic) + } + } +} + +func TestParseSubscriptionTopics_Malformed(t *testing.T) { + store := metadata.NewInMemoryStore(metadata.ClusterMetadata{}) + coord := NewGroupCoordinator(store, protocol.MetadataBroker{}, nil) + + if got := coord.parseSubscriptionTopics(nil); got != nil { + t.Fatalf("expected nil for empty protocols, got %v", got) + } + if got := coord.parseSubscriptionTopics([]kmsg.JoinGroupRequestProtocol{ + {Metadata: []byte{0}}, + }); got != nil { + t.Fatalf("expected nil for too-short metadata, got %v", got) + } + if got := coord.parseSubscriptionTopics([]kmsg.JoinGroupRequestProtocol{ + {Metadata: []byte{0, 0, 0}}, + }); got != nil { + t.Fatalf("expected nil for truncated count, got %v", got) + } +} + +func TestEncodeAssignment(t *testing.T) { + topics := []assignmentTopic{ + {Name: "orders", Partitions: []int32{0, 1, 2}}, + {Name: "events", Partitions: []int32{0}}, + } + encoded := encodeAssignment(topics) + + // Parse manually to verify wire format. + off := 0 + version := int16(binary.BigEndian.Uint16(encoded[off:])) + off += 2 + if version != 0 { + t.Fatalf("expected version 0, got %d", version) + } + topicCount := int(binary.BigEndian.Uint32(encoded[off:])) + off += 4 + if topicCount != 2 { + t.Fatalf("expected 2 topics, got %d", topicCount) + } + for _, want := range topics { + nameLen := int(binary.BigEndian.Uint16(encoded[off:])) + off += 2 + name := string(encoded[off : off+nameLen]) + off += nameLen + if name != want.Name { + t.Fatalf("topic name: got %q want %q", name, want.Name) + } + partCount := int(binary.BigEndian.Uint32(encoded[off:])) + off += 4 + if partCount != len(want.Partitions) { + t.Fatalf("partition count for %s: got %d want %d", name, partCount, len(want.Partitions)) + } + for j, wantPart := range want.Partitions { + gotPart := int32(binary.BigEndian.Uint32(encoded[off:])) + off += 4 + if gotPart != wantPart { + t.Fatalf("%s partition[%d]: got %d want %d", name, j, gotPart, wantPart) + } + } + } +} + +func TestMemberSubscribes(t *testing.T) { + member := &memberState{topics: []string{"orders", "events"}} + if !memberSubscribes(member, "orders") { + t.Fatal("expected true for subscribed topic") + } + if memberSubscribes(member, "metrics") { + t.Fatal("expected false for unsubscribed topic") + } + if memberSubscribes(nil, "orders") { + t.Fatal("expected false for nil member") + } +} + +func TestGroupState_RebalanceLifecycle(t *testing.T) { + now := time.Now() + state := &groupState{ + state: groupStateStable, + members: make(map[string]*memberState), + assignments: make(map[string][]assignmentTopic), + } + state.members["m-1"] = &memberState{ + topics: []string{"orders"}, lastHeartbeat: now, joinGeneration: 0, + } + state.members["m-2"] = &memberState{ + topics: []string{"orders"}, lastHeartbeat: now, joinGeneration: 0, + } + + // startRebalance increments generation, sets preparing state, picks leader. + state.startRebalance(10 * time.Second) + if state.state != groupStatePreparingRebalance { + t.Fatalf("expected PreparingRebalance, got %d", state.state) + } + if state.generationID != 1 { + t.Fatalf("expected generation 1, got %d", state.generationID) + } + if state.leaderID == "" { + t.Fatal("expected leader to be set") + } + // Both members should have joinGeneration reset to 0. + for id, m := range state.members { + if m.joinGeneration != 0 { + t.Fatalf("member %s joinGeneration should be 0", id) + } + } + + // completeIfReady should return false — no members have joined this generation. + if state.completeIfReady() { + t.Fatal("expected completeIfReady=false before members join") + } + + // Simulate both members joining. + for _, m := range state.members { + m.joinGeneration = state.generationID + } + if !state.completeIfReady() { + t.Fatal("expected completeIfReady=true after all members join") + } + if state.state != groupStateCompletingRebalance { + t.Fatalf("expected CompletingRebalance, got %d", state.state) + } + + // markStable transitions to stable. + state.markStable() + if state.state != groupStateStable { + t.Fatalf("expected Stable, got %d", state.state) + } +} + +func TestGroupState_RemoveExpiredMembers(t *testing.T) { + now := time.Now() + state := &groupState{ + state: groupStateStable, + leaderID: "m-1", + members: make(map[string]*memberState), + assignments: make(map[string][]assignmentTopic), + } + state.members["m-1"] = &memberState{ + lastHeartbeat: now.Add(-60 * time.Second), sessionTimeout: 30 * time.Second, + } + state.members["m-2"] = &memberState{ + lastHeartbeat: now, sessionTimeout: 30 * time.Second, + } + + changed := state.removeExpiredMembers(now) + if !changed { + t.Fatal("expected change after removing expired member") + } + if _, ok := state.members["m-1"]; ok { + t.Fatal("expired member m-1 should be removed") + } + if _, ok := state.members["m-2"]; !ok { + t.Fatal("active member m-2 should remain") + } + if state.leaderID != "" { + t.Fatalf("leader should be cleared when leader is removed, got %q", state.leaderID) + } +} + +func TestGroupState_DropRebalanceLaggers(t *testing.T) { + now := time.Now() + state := &groupState{ + state: groupStatePreparingRebalance, + generationID: 3, + rebalanceDeadline: now.Add(-1 * time.Second), + members: make(map[string]*memberState), + assignments: make(map[string][]assignmentTopic), + } + state.members["m-joined"] = &memberState{joinGeneration: 3} + state.members["m-lagging"] = &memberState{joinGeneration: 2} + + changed := state.dropRebalanceLaggers(now) + if !changed { + t.Fatal("expected laggers to be dropped") + } + if _, ok := state.members["m-lagging"]; ok { + t.Fatal("lagging member should be removed") + } + if _, ok := state.members["m-joined"]; !ok { + t.Fatal("joined member should remain") + } + + // Before deadline: no drops. + state.members["m-new-lagger"] = &memberState{joinGeneration: 2} + state.rebalanceDeadline = now.Add(10 * time.Second) + if state.dropRebalanceLaggers(now) { + t.Fatal("should not drop before deadline") + } +} + +func TestGroupState_StartRebalance_EmptyGroup(t *testing.T) { + state := &groupState{ + state: groupStateStable, + members: make(map[string]*memberState), + assignments: make(map[string][]assignmentTopic), + leaderID: "old-leader", + } + state.startRebalance(0) + if state.state != groupStateEmpty { + t.Fatalf("expected Empty for group with no members, got %d", state.state) + } + if state.leaderID != "" { + t.Fatalf("expected leader cleared, got %q", state.leaderID) + } +} + +func TestSortedMembers(t *testing.T) { + state := &groupState{ + members: map[string]*memberState{ + "charlie": {}, + "alpha": {}, + "bravo": {}, + }, + } + got := state.sortedMembers() + want := []string{"alpha", "bravo", "charlie"} + if len(got) != len(want) { + t.Fatalf("len: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("member[%d]: got %q want %q", i, got[i], want[i]) + } + } +} + +func TestBumpRebalanceDeadline(t *testing.T) { + state := &groupState{} + state.bumpRebalanceDeadline(0) + if state.rebalanceTimeout != defaultRebalanceTimeout { + t.Fatalf("expected default timeout, got %s", state.rebalanceTimeout) + } + state.bumpRebalanceDeadline(15 * time.Second) + if state.rebalanceTimeout != 15*time.Second { + t.Fatalf("expected 15s, got %s", state.rebalanceTimeout) + } +} + func TestCoordinatorDeleteGroups(t *testing.T) { store := metadata.NewInMemoryStore(metadata.ClusterMetadata{}) group := &metadatapb.ConsumerGroup{GroupId: "group-1", State: "stable"} @@ -115,9 +375,9 @@ func TestCoordinatorDeleteGroups(t *testing.T) { } coord := NewGroupCoordinator(store, protocol.MetadataBroker{NodeID: 1, Host: "127.0.0.1", Port: 9092}, nil) - resp, err := coord.DeleteGroups(context.Background(), &protocol.DeleteGroupsRequest{ - Groups: []string{"group-1", "missing"}, - }, 3) + deleteReq := kmsg.NewPtrDeleteGroupsRequest() + deleteReq.Groups = []string{"group-1", "missing"} + resp, err := coord.DeleteGroups(context.Background(), deleteReq) if err != nil { t.Fatalf("DeleteGroups: %v", err) } diff --git a/pkg/broker/server.go b/pkg/broker/server.go index 23ec3e1f..2efdfa4e 100644 --- a/pkg/broker/server.go +++ b/pkg/broker/server.go @@ -23,21 +23,23 @@ import ( "net" "sync" + "github.com/twmb/franz-go/pkg/kmsg" + "github.com/KafScale/platform/pkg/protocol" ) // Handler processes parsed Kafka protocol requests and returns the response payload. type Handler interface { - Handle(ctx context.Context, header *protocol.RequestHeader, req protocol.Request) ([]byte, error) + Handle(ctx context.Context, header *protocol.RequestHeader, req kmsg.Request) ([]byte, error) } // Server implements minimal Kafka TCP handling for milestone 1. type Server struct { - Addr string - Handler Handler - ConnContextFunc ConnContextFunc - listener net.Listener - wg sync.WaitGroup + Addr string + Handler Handler + ConnContextFunc ConnContextFunc + listener net.Listener + wg sync.WaitGroup } // ConnContextFunc can wrap a connection and attach connection-scoped context data. diff --git a/pkg/broker/server_test.go b/pkg/broker/server_test.go index 5b75cb52..3e985f9f 100644 --- a/pkg/broker/server_test.go +++ b/pkg/broker/server_test.go @@ -27,30 +27,29 @@ import ( "syscall" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) type testHandler struct{} -func (h *testHandler) Handle(ctx context.Context, header *protocol.RequestHeader, req protocol.Request) ([]byte, error) { +func (h *testHandler) Handle(ctx context.Context, header *protocol.RequestHeader, req kmsg.Request) ([]byte, error) { switch req.(type) { - case *protocol.ApiVersionsRequest: - return protocol.EncodeApiVersionsResponse(&protocol.ApiVersionsResponse{ - CorrelationID: header.CorrelationID, - Versions: []protocol.ApiVersion{ - {APIKey: protocol.APIKeyApiVersion, MinVersion: 0, MaxVersion: 0}, - }, - }, header.APIVersion) - case *protocol.MetadataRequest: - return protocol.EncodeMetadataResponse(&protocol.MetadataResponse{ - CorrelationID: header.CorrelationID, - Brokers: []protocol.MetadataBroker{ - {NodeID: 1, Host: "localhost", Port: 9092}, - }, - ControllerID: 1, - Topics: []protocol.MetadataTopic{ - {Name: "orders"}, - }, - }, header.APIVersion) + case *kmsg.ApiVersionsRequest: + resp := kmsg.NewPtrApiVersionsResponse() + resp.ApiKeys = []kmsg.ApiVersionsResponseApiKey{ + {ApiKey: protocol.APIKeyApiVersion, MinVersion: 0, MaxVersion: 0}, + } + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil + case *kmsg.MetadataRequest: + resp := kmsg.NewPtrMetadataResponse() + resp.Brokers = []kmsg.MetadataResponseBroker{ + {NodeID: 1, Host: "localhost", Port: 9092}, + } + resp.ControllerID = 1 + topic := kmsg.NewMetadataResponseTopic() + topic.Topic = kmsg.StringPtr("orders") + resp.Topics = []kmsg.MetadataResponseTopic{topic} + return protocol.EncodeResponse(header.CorrelationID, header.APIVersion, resp), nil default: return nil, errors.New("unsupported api") } diff --git a/pkg/metadata/etcd_store.go b/pkg/metadata/etcd_store.go index faeba759..1d37132b 100644 --- a/pkg/metadata/etcd_store.go +++ b/pkg/metadata/etcd_store.go @@ -39,13 +39,17 @@ type EtcdStoreConfig struct { DialTimeout time.Duration } +// etcdError wraps an error for storage in an atomic.Value, which requires a +// consistent concrete type across all Store calls. +type etcdError struct{ err error } + // EtcdStore uses etcd for offset persistence while delegating metadata to an in-memory snapshot. type EtcdStore struct { client *clientv3.Client metadata *InMemoryStore cancel context.CancelFunc available int32 - lastError atomic.Value + lastError atomic.Value // stores etcdError } func (s *EtcdStore) EtcdClient() *clientv3.Client { @@ -103,7 +107,7 @@ func (s *EtcdStore) recordEtcdResult(err error) { return } atomic.StoreInt32(&s.available, 0) - s.lastError.Store(err) + s.lastError.Store(etcdError{err}) } // NextOffset reads the last committed offset from etcd and returns the next offset to assign. @@ -374,8 +378,8 @@ func (s *EtcdStore) CreatePartitions(ctx context.Context, topic string, partitio part := updated.Topics[0].Partitions[i] state := &metadatapb.PartitionState{ Topic: topic, - Partition: part.PartitionIndex, - LeaderBroker: fmt.Sprintf("%d", part.LeaderID), + Partition: part.Partition, + LeaderBroker: fmt.Sprintf("%d", part.Leader), LeaderEpoch: part.LeaderEpoch, LogStartOffset: 0, LogEndOffset: 0, @@ -387,7 +391,7 @@ func (s *EtcdStore) CreatePartitions(ctx context.Context, topic string, partitio return err } ctx, cancel := context.WithTimeout(ctx, 3*time.Second) - _, err = s.client.Put(ctx, PartitionStateKey(topic, part.PartitionIndex), string(payload)) + _, err = s.client.Put(ctx, PartitionStateKey(topic, part.Partition), string(payload)) cancel() if err != nil { s.recordEtcdResult(err) @@ -424,7 +428,7 @@ func (s *EtcdStore) partitionExists(ctx context.Context, topic string, partition return false, nil } for _, part := range t.Partitions { - if part.PartitionIndex == partition { + if part.Partition == partition { return true, nil } } @@ -441,7 +445,7 @@ func (s *EtcdStore) DeleteTopic(ctx context.Context, name string) error { } var found bool for _, topic := range state.Topics { - if topic.Name == name { + if *topic.Topic == name { found = true break } diff --git a/pkg/metadata/etcd_store_test.go b/pkg/metadata/etcd_store_test.go index 8674366c..503678ae 100644 --- a/pkg/metadata/etcd_store_test.go +++ b/pkg/metadata/etcd_store_test.go @@ -27,6 +27,7 @@ import ( "github.com/KafScale/platform/internal/testutil" metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestEtcdStoreCreateTopicPersistsSnapshot(t *testing.T) { @@ -116,10 +117,10 @@ func TestEtcdStoreDeleteTopicRemovesOffsets(t *testing.T) { ControllerID: 1, Topics: []protocol.MetadataTopic{ { - Name: "orders", + Topic: kmsg.StringPtr("orders"), Partitions: []protocol.MetadataPartition{ - {PartitionIndex: 0, LeaderID: 1, ReplicaNodes: []int32{1}, ISRNodes: []int32{1}}, - {PartitionIndex: 1, LeaderID: 1, ReplicaNodes: []int32{1}, ISRNodes: []int32{1}}, + {Partition: 0, Leader: 1, Replicas: []int32{1}, ISR: []int32{1}}, + {Partition: 1, Leader: 1, Replicas: []int32{1}, ISR: []int32{1}}, }, }, }, @@ -278,7 +279,7 @@ func topicExists(meta *ClusterMetadata, topic string) bool { return false } for _, t := range meta.Topics { - if t.Name == topic && t.ErrorCode == 0 { + if *t.Topic == topic && t.ErrorCode == 0 { return true } } diff --git a/pkg/metadata/store.go b/pkg/metadata/store.go index a3e83a0a..2b8dee38 100644 --- a/pkg/metadata/store.go +++ b/pkg/metadata/store.go @@ -26,6 +26,7 @@ import ( metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) // Store exposes read-only access to cluster metadata used by Kafka protocol handlers. @@ -153,7 +154,7 @@ func filterTopics(all []protocol.MetadataTopic, requested []string) []protocol.M } index := make(map[string]protocol.MetadataTopic, len(all)) for _, topic := range all { - index[topic.Name] = topic + index[*topic.Topic] = topic } result := make([]protocol.MetadataTopic, 0, len(requested)) for _, name := range requested { @@ -162,7 +163,7 @@ func filterTopics(all []protocol.MetadataTopic, requested []string) []protocol.M } else { result = append(result, protocol.MetadataTopic{ ErrorCode: 3, // UNKNOWN_TOPIC_OR_PARTITION - Name: name, + Topic: kmsg.StringPtr(name), }) } } @@ -195,16 +196,17 @@ func cloneTopics(topics []protocol.MetadataTopic) []protocol.MetadataTopic { out := make([]protocol.MetadataTopic, len(topics)) for i, topic := range topics { topicID := topic.TopicID + name := *topic.Topic if topicID == ([16]byte{}) { - topicID = TopicIDForName(topic.Name) + topicID = TopicIDForName(name) } out[i] = protocol.MetadataTopic{ - ErrorCode: topic.ErrorCode, - Name: topic.Name, - TopicID: topicID, - IsInternal: topic.IsInternal, - Partitions: clonePartitions(topic.Partitions), - TopicAuthorizedOperations: topic.TopicAuthorizedOperations, + ErrorCode: topic.ErrorCode, + Topic: kmsg.StringPtr(name), + TopicID: topicID, + IsInternal: topic.IsInternal, + Partitions: clonePartitions(topic.Partitions), + AuthorizedOperations: topic.AuthorizedOperations, } } return out @@ -218,11 +220,11 @@ func clonePartitions(parts []protocol.MetadataPartition) []protocol.MetadataPart for i, part := range parts { out[i] = protocol.MetadataPartition{ ErrorCode: part.ErrorCode, - PartitionIndex: part.PartitionIndex, - LeaderID: part.LeaderID, + Partition: part.Partition, + Leader: part.Leader, LeaderEpoch: part.LeaderEpoch, - ReplicaNodes: cloneInt32Slice(part.ReplicaNodes), - ISRNodes: cloneInt32Slice(part.ISRNodes), + Replicas: cloneInt32Slice(part.Replicas), + ISR: cloneInt32Slice(part.ISR), OfflineReplicas: cloneInt32Slice(part.OfflineReplicas), } } @@ -298,7 +300,7 @@ func (s *InMemoryStore) CreateTopic(ctx context.Context, spec TopicSpec) (*proto s.mu.Lock() defer s.mu.Unlock() for _, topic := range s.state.Topics { - if topic.Name == spec.Name { + if *topic.Topic == spec.Name { return nil, ErrTopicExists } } @@ -309,14 +311,14 @@ func (s *InMemoryStore) CreateTopic(ctx context.Context, spec TopicSpec) (*proto partitions := make([]protocol.MetadataPartition, spec.NumPartitions) for i := range partitions { partitions[i] = protocol.MetadataPartition{ - PartitionIndex: int32(i), - LeaderID: leaderID, - ReplicaNodes: []int32{leaderID}, - ISRNodes: []int32{leaderID}, + Partition: int32(i), + Leader: leaderID, + Replicas: []int32{leaderID}, + ISR: []int32{leaderID}, } } newTopic := protocol.MetadataTopic{ - Name: spec.Name, + Topic: kmsg.StringPtr(spec.Name), TopicID: TopicIDForName(spec.Name), IsInternal: false, Partitions: partitions, @@ -328,11 +330,11 @@ func (s *InMemoryStore) CreateTopic(ctx context.Context, spec TopicSpec) (*proto func topicHasPartition(topics []protocol.MetadataTopic, name string, partition int32) bool { for _, topic := range topics { - if topic.Name != name { + if *topic.Topic != name { continue } for _, part := range topic.Partitions { - if part.PartitionIndex == partition { + if part.Partition == partition { return true } } @@ -358,7 +360,7 @@ func (s *InMemoryStore) FetchTopicConfig(ctx context.Context, topic string) (*me s.mu.RLock() defer s.mu.RUnlock() for _, entry := range s.state.Topics { - if entry.Name != topic { + if *entry.Topic != topic { continue } if cfg, ok := s.topicConfigs[topic]; ok { @@ -383,7 +385,7 @@ func (s *InMemoryStore) UpdateTopicConfig(ctx context.Context, cfg *metadatapb.T defer s.mu.Unlock() var topic *protocol.MetadataTopic for i := range s.state.Topics { - if s.state.Topics[i].Name == cfg.Name { + if *s.state.Topics[i].Topic == cfg.Name { topic = &s.state.Topics[i] break } @@ -412,7 +414,7 @@ func (s *InMemoryStore) CreatePartitions(ctx context.Context, topic string, part defer s.mu.Unlock() var target *protocol.MetadataTopic for i := range s.state.Topics { - if s.state.Topics[i].Name == topic { + if *s.state.Topics[i].Topic == topic { target = &s.state.Topics[i] break } @@ -427,10 +429,10 @@ func (s *InMemoryStore) CreatePartitions(ctx context.Context, topic string, part leaderID := s.defaultLeaderID() for i := current; i < partitionCount; i++ { target.Partitions = append(target.Partitions, protocol.MetadataPartition{ - PartitionIndex: i, - LeaderID: leaderID, - ReplicaNodes: []int32{leaderID}, - ISRNodes: []int32{leaderID}, + Partition: i, + Leader: leaderID, + Replicas: []int32{leaderID}, + ISR: []int32{leaderID}, }) } cfg, ok := s.topicConfigs[topic] @@ -453,7 +455,7 @@ func (s *InMemoryStore) DeleteTopic(ctx context.Context, name string) error { defer s.mu.Unlock() index := -1 for i, topic := range s.state.Topics { - if topic.Name == name { + if *topic.Topic == name { index = i break } @@ -634,10 +636,10 @@ func defaultTopicConfigFromTopic(topic *protocol.MetadataTopic, replicationFacto return &metadatapb.TopicConfig{} } if replicationFactor <= 0 && len(topic.Partitions) > 0 { - replicationFactor = int16(len(topic.Partitions[0].ReplicaNodes)) + replicationFactor = int16(len(topic.Partitions[0].Replicas)) } return &metadatapb.TopicConfig{ - Name: topic.Name, + Name: *topic.Topic, Partitions: int32(len(topic.Partitions)), ReplicationFactor: int32(replicationFactor), RetentionMs: -1, diff --git a/pkg/metadata/store_test.go b/pkg/metadata/store_test.go index aaf56309..f38e9ad4 100644 --- a/pkg/metadata/store_test.go +++ b/pkg/metadata/store_test.go @@ -22,6 +22,7 @@ import ( metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestInMemoryStoreMetadata_AllTopics(t *testing.T) { @@ -32,8 +33,8 @@ func TestInMemoryStoreMetadata_AllTopics(t *testing.T) { }, ControllerID: 1, Topics: []protocol.MetadataTopic{ - {Name: "orders"}, - {Name: "payments"}, + {Topic: kmsg.StringPtr("orders")}, + {Topic: kmsg.StringPtr("payments")}, }, ClusterID: &clusterID, }) @@ -57,7 +58,7 @@ func TestInMemoryStoreMetadata_AllTopics(t *testing.T) { func TestInMemoryStoreMetadata_FilterTopics(t *testing.T) { store := NewInMemoryStore(ClusterMetadata{ Topics: []protocol.MetadataTopic{ - {Name: "orders"}, + {Topic: kmsg.StringPtr("orders")}, }, }) @@ -225,7 +226,7 @@ func TestInMemoryStoreCreateDeleteTopic(t *testing.T) { if err != nil { t.Fatalf("CreateTopic: %v", err) } - if topic == nil || topic.Name != "orders" { + if topic == nil || *topic.Topic != "orders" { t.Fatalf("unexpected topic: %#v", topic) } if _, err := store.CreateTopic(ctx, TopicSpec{Name: "orders", NumPartitions: 1}); err == nil { diff --git a/pkg/operator/snapshot.go b/pkg/operator/snapshot.go index f8241575..2844bb9e 100644 --- a/pkg/operator/snapshot.go +++ b/pkg/operator/snapshot.go @@ -30,6 +30,7 @@ import ( kafscalev1alpha1 "github.com/KafScale/platform/api/v1alpha1" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/twmb/franz-go/pkg/kmsg" ) const ( @@ -79,16 +80,18 @@ func mergeExistingSnapshot(ctx context.Context, endpoints []string, next metadat } seen := make(map[string]struct{}, len(next.Topics)) for _, topic := range next.Topics { - if topic.Name == "" { + name := *topic.Topic + if name == "" { continue } - seen[topic.Name] = struct{}{} + seen[name] = struct{}{} } for _, topic := range existing.Topics { - if topic.Name == "" || topic.ErrorCode != 0 { + name := *topic.Topic + if name == "" || topic.ErrorCode != 0 { continue } - if _, ok := seen[topic.Name]; ok { + if _, ok := seen[name]; ok { continue } next.Topics = append(next.Topics, topic) @@ -169,14 +172,14 @@ func BuildClusterMetadata(cluster *kafscalev1alpha1.KafscaleCluster, topics []ka leader = replicaIDs[i%int32(len(replicaIDs))] } partitions[i] = protocol.MetadataPartition{ - PartitionIndex: i, - LeaderID: leader, - ReplicaNodes: replicaIDs, - ISRNodes: replicaIDs, + Partition: i, + Leader: leader, + Replicas: replicaIDs, + ISR: replicaIDs, } } metaTopics = append(metaTopics, protocol.MetadataTopic{ - Name: topic.Name, + Topic: kmsg.StringPtr(topic.Name), TopicID: metadata.TopicIDForName(topic.Name), IsInternal: false, Partitions: partitions, @@ -289,16 +292,18 @@ func mergeSnapshots(next, existing metadata.ClusterMetadata) metadata.ClusterMet } seen := make(map[string]struct{}, len(next.Topics)) for _, topic := range next.Topics { - if topic.Name == "" { + name := *topic.Topic + if name == "" { continue } - seen[topic.Name] = struct{}{} + seen[name] = struct{}{} } for _, topic := range existing.Topics { - if topic.Name == "" || topic.ErrorCode != 0 { + name := *topic.Topic + if name == "" || topic.ErrorCode != 0 { continue } - if _, ok := seen[topic.Name]; ok { + if _, ok := seen[name]; ok { continue } next.Topics = append(next.Topics, topic) diff --git a/pkg/operator/snapshot_test.go b/pkg/operator/snapshot_test.go index 8fb00979..561cb268 100644 --- a/pkg/operator/snapshot_test.go +++ b/pkg/operator/snapshot_test.go @@ -69,17 +69,17 @@ func TestBuildClusterMetadata(t *testing.T) { if meta.ClusterName == nil || *meta.ClusterName != "prod" { t.Fatalf("unexpected cluster name: %v", meta.ClusterName) } - if len(meta.Topics) != 1 || meta.Topics[0].Name != "orders" { + if len(meta.Topics) != 1 || *meta.Topics[0].Topic != "orders" { t.Fatalf("expected orders topic, got %+v", meta.Topics) } if len(meta.Topics[0].Partitions) != 2 { t.Fatalf("expected 2 partitions, got %d", len(meta.Topics[0].Partitions)) } for _, part := range meta.Topics[0].Partitions { - if len(part.ReplicaNodes) != int(replicas) { + if len(part.Replicas) != int(replicas) { t.Fatalf("partition %+v replica mismatch", part) } - if len(part.ISRNodes) != len(part.ReplicaNodes) { + if len(part.ISR) != len(part.Replicas) { t.Fatalf("partition %+v ISR mismatch", part) } } diff --git a/pkg/protocol/api.go b/pkg/protocol/api.go index a3fafc45..b2375784 100644 --- a/pkg/protocol/api.go +++ b/pkg/protocol/api.go @@ -15,10 +15,11 @@ package protocol -// API keys supported by Kafscale in milestone 1. +// API keys supported by KafScale. const ( APIKeyProduce int16 = 0 APIKeyFetch int16 = 1 + APIKeyListOffsets int16 = 2 APIKeyMetadata int16 = 3 APIKeyOffsetCommit int16 = 8 APIKeyOffsetFetch int16 = 9 @@ -33,16 +34,8 @@ const ( APIKeyCreateTopics int16 = 19 APIKeyDeleteTopics int16 = 20 APIKeyOffsetForLeaderEpoch int16 = 23 - APIKeyListOffsets int16 = 2 APIKeyDescribeConfigs int16 = 32 APIKeyAlterConfigs int16 = 33 APIKeyCreatePartitions int16 = 37 APIKeyDeleteGroups int16 = 42 ) - -// ApiVersion describes the supported version range for an API. -type ApiVersion struct { - APIKey int16 - MinVersion int16 - MaxVersion int16 -} diff --git a/pkg/protocol/encoding.go b/pkg/protocol/encoding.go index bf5bf3cc..7df13609 100644 --- a/pkg/protocol/encoding.go +++ b/pkg/protocol/encoding.go @@ -20,6 +20,8 @@ import ( "fmt" ) +// byteReader is a minimal binary reader used to parse Kafka request headers +// and response header fields that kmsg does not handle. type byteReader struct { buf []byte pos int @@ -58,85 +60,6 @@ func (r *byteReader) Int32() (int32, error) { return int32(binary.BigEndian.Uint32(b)), nil } -func (r *byteReader) Int64() (int64, error) { - b, err := r.read(8) - if err != nil { - return 0, err - } - return int64(binary.BigEndian.Uint64(b)), nil -} - -func (r *byteReader) UUID() ([16]byte, error) { - b, err := r.read(16) - if err != nil { - return [16]byte{}, err - } - var id [16]byte - copy(id[:], b) - return id, nil -} - -func (r *byteReader) Bool() (bool, error) { - b, err := r.read(1) - if err != nil { - return false, err - } - switch b[0] { - case 0: - return false, nil - case 1: - return true, nil - default: - return false, fmt.Errorf("invalid bool: %d", b[0]) - } -} - -func (r *byteReader) String() (string, error) { - l, err := r.Int16() - if err != nil { - return "", err - } - if l < 0 { - return "", fmt.Errorf("invalid string length: %d", l) - } - b, err := r.read(int(l)) - if err != nil { - return "", err - } - return string(b), nil -} - -func (r *byteReader) CompactString() (string, error) { - length, err := r.compactLength() - if err != nil { - return "", err - } - if length < 0 { - return "", fmt.Errorf("compact string is null") - } - b, err := r.read(length) - if err != nil { - return "", err - } - return string(b), nil -} - -func (r *byteReader) CompactNullableString() (*string, error) { - length, err := r.compactLength() - if err != nil { - return nil, err - } - if length < 0 { - return nil, nil - } - b, err := r.read(length) - if err != nil { - return nil, err - } - str := string(b) - return &str, nil -} - func (r *byteReader) NullableString() (*string, error) { l, err := r.Int16() if err != nil { @@ -156,14 +79,40 @@ func (r *byteReader) NullableString() (*string, error) { return &str, nil } -func (r *byteReader) Int8() (int8, error) { - b, err := r.read(1) +func (r *byteReader) UVarint() (uint64, error) { + val, n := binary.Uvarint(r.buf[r.pos:]) + if n <= 0 { + return 0, fmt.Errorf("read uvarint: %d", n) + } + r.pos += n + return val, nil +} + +func (r *byteReader) SkipTaggedFields() error { + count, err := r.UVarint() if err != nil { - return 0, err + return err } - return int8(b[0]), nil + for i := uint64(0); i < count; i++ { + if _, err := r.UVarint(); err != nil { + return err + } + size, err := r.UVarint() + if err != nil { + return err + } + if size == 0 { + continue + } + if _, err := r.read(int(size)); err != nil { + return err + } + } + return nil } +// byteWriter is a minimal binary writer used in tests to construct Kafka +// request headers that kmsg does not handle. type byteWriter struct { buf []byte } @@ -188,28 +137,6 @@ func (w *byteWriter) Int32(v int32) { w.write(tmp[:]) } -func (w *byteWriter) Int64(v int64) { - var tmp [8]byte - binary.BigEndian.PutUint64(tmp[:], uint64(v)) - w.write(tmp[:]) -} - -func (w *byteWriter) UUID(id [16]byte) { - w.write(id[:]) -} - -func (w *byteWriter) Bool(v bool) { - if v { - w.write([]byte{1}) - } else { - w.write([]byte{0}) - } -} - -func (w *byteWriter) Int8(v int8) { - w.write([]byte{byte(v)}) -} - func (w *byteWriter) String(v string) { if v == "" { w.Int16(0) @@ -230,124 +157,12 @@ func (w *byteWriter) NullableString(v *string) { w.String(*v) } -func (w *byteWriter) CompactString(v string) { - w.compactLength(len(v)) - if len(v) > 0 { - w.write([]byte(v)) - } -} - -func (w *byteWriter) CompactNullableString(v *string) { - if v == nil { - w.compactLength(-1) - return - } - w.CompactString(*v) -} - -func (r *byteReader) Bytes() ([]byte, error) { - length, err := r.Int32() - if err != nil { - return nil, err - } - if length < 0 { - return nil, fmt.Errorf("invalid bytes length %d", length) - } - b, err := r.read(int(length)) - if err != nil { - return nil, err - } - return b, nil -} - -func (r *byteReader) CompactBytes() ([]byte, error) { - length, err := r.compactLength() - if err != nil { - return nil, err - } - if length < 0 { - return nil, nil - } - return r.read(length) -} - -func (w *byteWriter) BytesWithLength(b []byte) { - w.Int32(int32(len(b))) - w.write(b) -} - -func (w *byteWriter) CompactBytes(b []byte) { - if b == nil { - w.compactLength(-1) - return - } - w.compactLength(len(b)) - if len(b) > 0 { - w.write(b) - } -} - -func (w *byteWriter) Bytes() []byte { - return w.buf -} - -func (r *byteReader) UVarint() (uint64, error) { - val, n := binary.Uvarint(r.buf[r.pos:]) - if n <= 0 { - return 0, fmt.Errorf("read uvarint: %d", n) - } - r.pos += n - return val, nil -} - func (w *byteWriter) UVarint(v uint64) { var tmp [binary.MaxVarintLen64]byte n := binary.PutUvarint(tmp[:], v) w.write(tmp[:n]) } -func (r *byteReader) CompactArrayLen() (int32, error) { - val, err := r.UVarint() - if err != nil { - return 0, err - } - if val == 0 { - return -1, nil - } - return int32(val - 1), nil -} - -func (w *byteWriter) CompactArrayLen(length int) { - if length < 0 { - w.UVarint(0) - return - } - w.UVarint(uint64(length) + 1) -} - -func (r *byteReader) SkipTaggedFields() error { - count, err := r.UVarint() - if err != nil { - return err - } - for i := uint64(0); i < count; i++ { - if _, err := r.UVarint(); err != nil { - return err - } - size, err := r.UVarint() - if err != nil { - return err - } - if size == 0 { - continue - } - if _, err := r.read(int(size)); err != nil { - return err - } - } - return nil -} - func (w *byteWriter) WriteTaggedFields(count int) { if count == 0 { w.UVarint(0) @@ -356,22 +171,6 @@ func (w *byteWriter) WriteTaggedFields(count int) { w.UVarint(uint64(count)) } -func (r *byteReader) compactLength() (int, error) { - val, err := r.UVarint() - if err != nil { - return 0, err - } - if val == 0 { - return -1, nil - } - length := int(val - 1) - return length, nil -} - -func (w *byteWriter) compactLength(length int) { - if length < 0 { - w.UVarint(0) - return - } - w.UVarint(uint64(length) + 1) +func (w *byteWriter) Bytes() []byte { + return w.buf } diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index 445744a8..3a2d7b37 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -17,9 +17,12 @@ package protocol import ( "fmt" + + "github.com/twmb/franz-go/pkg/kmsg" ) -// RequestHeader matches Kafka RequestHeader v1 (simplified without tagged fields). +// RequestHeader carries the parsed Kafka request header fields. +// kmsg does not handle headers — they must be parsed/written manually. type RequestHeader struct { APIKey int16 APIVersion int16 @@ -27,1920 +30,71 @@ type RequestHeader struct { ClientID *string } -// Request is implemented by concrete protocol requests. -type Request interface { - APIKey() int16 -} - -// ApiVersionsRequest describes the ApiVersions call. -type ApiVersionsRequest struct { - ClientSoftwareName string - ClientSoftwareVersion string -} - -func (ApiVersionsRequest) APIKey() int16 { return APIKeyApiVersion } - -// ProduceRequest is a simplified representation of Kafka ProduceRequest v9. -type ProduceRequest struct { - Acks int16 - TimeoutMs int32 - TransactionalID *string - Topics []ProduceTopic -} - -type ProduceTopic struct { - Name string - Partitions []ProducePartition -} - -type ProducePartition struct { - Partition int32 - Records []byte -} - -func (ProduceRequest) APIKey() int16 { return APIKeyProduce } - -// FetchRequest represents a subset of Kafka FetchRequest v13. -type FetchRequest struct { - ReplicaID int32 - MaxWaitMs int32 - MinBytes int32 - MaxBytes int32 - IsolationLevel int8 - SessionID int32 - SessionEpoch int32 - Topics []FetchTopicRequest -} - -type FetchTopicRequest struct { - Name string - TopicID [16]byte - Partitions []FetchPartitionRequest -} - -type FetchPartitionRequest struct { - Partition int32 - FetchOffset int64 - MaxBytes int32 -} - -func (FetchRequest) APIKey() int16 { return APIKeyFetch } - -// MetadataRequest asks for cluster metadata. Empty Topics means "all". -type MetadataRequest struct { - Topics []string - TopicIDs [][16]byte - AllowAutoTopicCreation bool - IncludeClusterAuthOps bool - IncludeTopicAuthOps bool -} - -func (MetadataRequest) APIKey() int16 { return APIKeyMetadata } - -type CreateTopicConfig struct { - Name string - NumPartitions int32 - ReplicationFactor int16 -} - -type CreateTopicsRequest struct { - Topics []CreateTopicConfig - TimeoutMs int32 - ValidateOnly bool -} - -func (CreateTopicsRequest) APIKey() int16 { return APIKeyCreateTopics } - -type DeleteTopicsRequest struct { - TopicNames []string - TimeoutMs int32 -} - -func (DeleteTopicsRequest) APIKey() int16 { return APIKeyDeleteTopics } - -type ListOffsetsPartition struct { - Partition int32 - Timestamp int64 - MaxNumOffsets int32 - CurrentLeaderEpoch int32 -} - -type ListOffsetsTopic struct { - Name string - Partitions []ListOffsetsPartition -} - -type ListOffsetsRequest struct { - ReplicaID int32 - IsolationLevel int8 - Topics []ListOffsetsTopic -} - -func (ListOffsetsRequest) APIKey() int16 { return APIKeyListOffsets } - -// FindCoordinatorRequest targets a group coordinator lookup. -type FindCoordinatorRequest struct { - KeyType int8 - Key string -} - -func (FindCoordinatorRequest) APIKey() int16 { return APIKeyFindCoordinator } - -type JoinGroupProtocol struct { - Name string - Metadata []byte -} - -type JoinGroupRequest struct { - GroupID string - SessionTimeoutMs int32 - RebalanceTimeoutMs int32 - MemberID string - ProtocolType string - Protocols []JoinGroupProtocol -} - -func (JoinGroupRequest) APIKey() int16 { return APIKeyJoinGroup } - -type SyncGroupAssignment struct { - MemberID string - Assignment []byte -} - -type SyncGroupRequest struct { - GroupID string - GenerationID int32 - MemberID string - Assignments []SyncGroupAssignment -} - -func (SyncGroupRequest) APIKey() int16 { return APIKeySyncGroup } - -type HeartbeatRequest struct { - GroupID string - GenerationID int32 - MemberID string - InstanceID *string -} - -func (HeartbeatRequest) APIKey() int16 { return APIKeyHeartbeat } - -type LeaveGroupRequest struct { - GroupID string - MemberID string -} - -func (LeaveGroupRequest) APIKey() int16 { return APIKeyLeaveGroup } - -type OffsetCommitPartition struct { - Partition int32 - Offset int64 - Metadata string -} - -type OffsetCommitTopic struct { - Name string - Partitions []OffsetCommitPartition -} - -type OffsetCommitRequest struct { - GroupID string - GenerationID int32 - MemberID string - RetentionMs int64 - Topics []OffsetCommitTopic -} - -func (OffsetCommitRequest) APIKey() int16 { return APIKeyOffsetCommit } - -type OffsetFetchPartition struct { - Partition int32 -} - -type OffsetFetchTopic struct { - Name string - Partitions []OffsetFetchPartition -} - -type OffsetFetchRequest struct { - GroupID string - Topics []OffsetFetchTopic -} - -func (OffsetFetchRequest) APIKey() int16 { return APIKeyOffsetFetch } - -type OffsetForLeaderEpochPartition struct { - Partition int32 - CurrentLeaderEpoch int32 - LeaderEpoch int32 -} - -type OffsetForLeaderEpochTopic struct { - Name string - Partitions []OffsetForLeaderEpochPartition -} - -type OffsetForLeaderEpochRequest struct { - ReplicaID int32 - Topics []OffsetForLeaderEpochTopic -} - -func (OffsetForLeaderEpochRequest) APIKey() int16 { return APIKeyOffsetForLeaderEpoch } - -type DescribeConfigsResource struct { - ResourceType int8 - ResourceName string - ConfigNames []string -} - -type DescribeConfigsRequest struct { - Resources []DescribeConfigsResource - IncludeSynonyms bool - IncludeDocumentation bool -} - -func (DescribeConfigsRequest) APIKey() int16 { return APIKeyDescribeConfigs } - -type AlterConfigsResourceConfig struct { - Name string - Value *string -} - -type AlterConfigsResource struct { - ResourceType int8 - ResourceName string - Configs []AlterConfigsResourceConfig -} - -type AlterConfigsRequest struct { - Resources []AlterConfigsResource - ValidateOnly bool -} - -func (AlterConfigsRequest) APIKey() int16 { return APIKeyAlterConfigs } - -type CreatePartitionsAssignment struct { - Replicas []int32 -} - -type CreatePartitionsTopic struct { - Name string - Count int32 - Assignments []CreatePartitionsAssignment -} - -type CreatePartitionsRequest struct { - Topics []CreatePartitionsTopic - TimeoutMs int32 - ValidateOnly bool -} - -func (CreatePartitionsRequest) APIKey() int16 { return APIKeyCreatePartitions } - -type DeleteGroupsRequest struct { - Groups []string -} - -func (DeleteGroupsRequest) APIKey() int16 { return APIKeyDeleteGroups } - -// DescribeGroupsRequest asks for metadata about consumer groups. -type DescribeGroupsRequest struct { - Groups []string - IncludeAuthorizedOperations bool -} - -func (DescribeGroupsRequest) APIKey() int16 { return APIKeyDescribeGroups } - -// ListGroupsRequest enumerates consumer groups with optional filters. -type ListGroupsRequest struct { - StatesFilter []string - TypesFilter []string -} - -func (ListGroupsRequest) APIKey() int16 { return APIKeyListGroups } - -func isFlexibleRequest(apiKey, version int16) bool { - switch apiKey { - case APIKeyApiVersion: - return version >= 3 - case APIKeyProduce: - return version >= 9 - case APIKeyMetadata: - return version >= 9 - case APIKeyFetch: - return version >= 12 - case APIKeyFindCoordinator: - return version >= 3 - case APIKeySyncGroup: - return version >= 4 - case APIKeyHeartbeat: - return version >= 4 - case APIKeyListGroups: - return version >= 3 - case APIKeyDescribeGroups: - return version >= 5 - case APIKeyOffsetForLeaderEpoch: - return version >= 4 - case APIKeyDescribeConfigs: - return version >= 4 - case APIKeyAlterConfigs: - return version >= 2 - case APIKeyCreatePartitions: - return version >= 2 - case APIKeyDeleteGroups: - return version >= 2 - default: - return false - } -} - -func compactArrayLenNonNull(r *byteReader) (int32, error) { - n, err := r.CompactArrayLen() - if err != nil { - return 0, err - } - if n < 0 { - return 0, fmt.Errorf("compact array is null") - } - return n, nil -} - -// ParseRequestHeader decodes the header portion from raw bytes. -func ParseRequestHeader(b []byte) (*RequestHeader, *byteReader, error) { - reader := newByteReader(b) - apiKey, err := reader.Int16() +// ParseRequestHeader decodes the header portion from raw frame bytes and +// returns the header plus the remaining body bytes. +func ParseRequestHeader(b []byte) (*RequestHeader, []byte, error) { + r := newByteReader(b) + apiKey, err := r.Int16() if err != nil { return nil, nil, fmt.Errorf("read api key: %w", err) } - version, err := reader.Int16() + version, err := r.Int16() if err != nil { return nil, nil, fmt.Errorf("read api version: %w", err) } - correlationID, err := reader.Int32() + correlationID, err := r.Int32() if err != nil { return nil, nil, fmt.Errorf("read correlation id: %w", err) } - clientID, err := reader.NullableString() + clientID, err := r.NullableString() if err != nil { return nil, nil, fmt.Errorf("read client id: %w", err) } - if isFlexibleRequest(apiKey, version) { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip header tags: %w", err) + + // Flexible versions (KIP-482) add tagged fields to the header. + req := kmsg.RequestForKey(apiKey) + if req != nil { + req.SetVersion(version) + if req.IsFlexible() { + if err := r.SkipTaggedFields(); err != nil { + return nil, nil, fmt.Errorf("skip header tags: %w", err) + } } } - return &RequestHeader{ + + header := &RequestHeader{ APIKey: apiKey, APIVersion: version, CorrelationID: correlationID, ClientID: clientID, - }, reader, nil + } + return header, b[r.pos:], nil } -// ParseRequest decodes a request header and body from bytes. -func ParseRequest(b []byte) (*RequestHeader, Request, error) { - header, reader, err := ParseRequestHeader(b) +// ParseRequest decodes a full request frame (header + body) into the parsed +// header and a concrete kmsg request type. +func ParseRequest(b []byte) (*RequestHeader, kmsg.Request, error) { + header, body, err := ParseRequestHeader(b) if err != nil { return nil, nil, err } - flexible := isFlexibleRequest(header.APIKey, header.APIVersion) - - var req Request - switch header.APIKey { - case APIKeyApiVersion: - apiReq := &ApiVersionsRequest{} - if header.APIVersion >= 3 { - var err error - if flexible { - apiReq.ClientSoftwareName, err = reader.CompactString() - } else { - apiReq.ClientSoftwareName, err = reader.String() - } - if err != nil { - return nil, nil, err - } - if flexible { - apiReq.ClientSoftwareVersion, err = reader.CompactString() - } else { - apiReq.ClientSoftwareVersion, err = reader.String() - } - if err != nil { - return nil, nil, err - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, err - } - } - } - req = apiReq - case APIKeyProduce: - var transactionalID *string - var err error - if header.APIVersion >= 3 { - if flexible { - transactionalID, err = reader.CompactNullableString() - } else { - transactionalID, err = reader.NullableString() - } - if err != nil { - return nil, nil, fmt.Errorf("read produce transactional id: %w", err) - } - } - acks, err := reader.Int16() - if err != nil { - return nil, nil, fmt.Errorf("read produce acks: %w", err) - } - timeout, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read produce timeout: %w", err) - } - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(reader) - } else { - topicCount, err = reader.Int32() - if topicCount < 0 { - return nil, nil, fmt.Errorf("read produce topic count: invalid %d", topicCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read produce topic count: %w", err) - } - topics := make([]ProduceTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read produce topic name: %w", err) - } - var partitionCount int32 - if flexible { - partitionCount, err = compactArrayLenNonNull(reader) - } else { - partitionCount, err = reader.Int32() - if partitionCount < 0 { - return nil, nil, fmt.Errorf("read produce partition count: invalid %d", partitionCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read produce partition count: %w", err) - } - partitions := make([]ProducePartition, 0, partitionCount) - for j := int32(0); j < partitionCount; j++ { - index, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read produce partition index: %w", err) - } - var records []byte - if flexible { - records, err = reader.CompactBytes() - } else { - records, err = reader.Bytes() - } - if err != nil { - return nil, nil, fmt.Errorf("read produce records: %w", err) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip partition tags: %w", err) - } - } - partitions = append(partitions, ProducePartition{ - Partition: index, - Records: records, - }) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip topic tags: %w", err) - } - } - topics = append(topics, ProduceTopic{Name: name, Partitions: partitions}) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip produce tags: %w", err) - } - } - req = &ProduceRequest{ - Acks: acks, - TimeoutMs: timeout, - TransactionalID: transactionalID, - Topics: topics, - } - case APIKeyMetadata: - var topics []string - var topicIDs [][16]byte - var count int32 - var err error - if flexible { - count, err = reader.CompactArrayLen() - } else { - count, err = reader.Int32() - } - if err != nil { - return nil, nil, fmt.Errorf("read metadata topic count: %w", err) - } - if count >= 0 { - topics = make([]string, 0, count) - topicIDs = make([][16]byte, 0, count) - for i := int32(0); i < count; i++ { - if header.APIVersion >= 10 { - id, err := reader.UUID() - if err != nil { - return nil, nil, fmt.Errorf("read metadata topic[%d] id: %w", i, err) - } - var namePtr *string - if flexible { - namePtr, err = reader.CompactNullableString() - } else { - namePtr, err = reader.NullableString() - } - if err != nil { - return nil, nil, fmt.Errorf("read metadata topic[%d] name: %w", i, err) - } - if namePtr != nil { - topics = append(topics, *namePtr) - } - topicIDs = append(topicIDs, id) - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip metadata topic[%d] tags: %w", i, err) - } - } - } else { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read metadata topic[%d]: %w", i, err) - } - topics = append(topics, name) - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip metadata topic[%d] tags: %w", i, err) - } - } - } - } - } - allowAutoTopicCreation := true - if header.APIVersion >= 4 { - if allowAutoTopicCreation, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read metadata allow auto topic creation: %w", err) - } - } - includeClusterAuthOps := false - includeTopicAuthOps := false - if header.APIVersion >= 8 && header.APIVersion <= 10 { - if includeClusterAuthOps, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read metadata include cluster auth ops: %w", err) - } - } - if header.APIVersion >= 8 { - if includeTopicAuthOps, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read metadata include topic auth ops: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip metadata tags: %w", err) - } - } - req = &MetadataRequest{ - Topics: topics, - TopicIDs: topicIDs, - AllowAutoTopicCreation: allowAutoTopicCreation, - IncludeClusterAuthOps: includeClusterAuthOps, - IncludeTopicAuthOps: includeTopicAuthOps, - } - case APIKeyCreateTopics: - topicCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - configs := make([]CreateTopicConfig, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - partitions, err := reader.Int32() - if err != nil { - return nil, nil, err - } - repl, err := reader.Int16() - if err != nil { - return nil, nil, err - } - // Configs map (ignored) - cfgCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - for j := int32(0); j < cfgCount; j++ { - if _, err := reader.String(); err != nil { - return nil, nil, err - } - if _, err := reader.String(); err != nil { - return nil, nil, err - } - } - configs = append(configs, CreateTopicConfig{Name: name, NumPartitions: partitions, ReplicationFactor: repl}) - } - timeoutMs, err := reader.Int32() - if err != nil { - return nil, nil, err - } - validateOnly := false - if header.APIVersion >= 1 { - if validateOnly, err = reader.Bool(); err != nil { - return nil, nil, err - } - } - req = &CreateTopicsRequest{Topics: configs, TimeoutMs: timeoutMs, ValidateOnly: validateOnly} - case APIKeyDeleteTopics: - count, err := reader.Int32() - if err != nil { - return nil, nil, err - } - names := make([]string, 0, count) - for i := int32(0); i < count; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - names = append(names, name) - } - timeoutMs, err := reader.Int32() - if err != nil { - return nil, nil, err - } - req = &DeleteTopicsRequest{TopicNames: names, TimeoutMs: timeoutMs} - case APIKeyListOffsets: - replicaID, err := reader.Int32() - if err != nil { - return nil, nil, err - } - isolationLevel := int8(0) - if header.APIVersion >= 2 { - if isolationLevel, err = reader.Int8(); err != nil { - return nil, nil, err - } - } - topicCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - topics := make([]ListOffsetsTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - partCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - parts := make([]ListOffsetsPartition, 0, partCount) - for j := int32(0); j < partCount; j++ { - partition, err := reader.Int32() - if err != nil { - return nil, nil, err - } - leaderEpoch := int32(-1) - if header.APIVersion >= 4 { - if leaderEpoch, err = reader.Int32(); err != nil { - return nil, nil, err - } - } - timestamp, err := reader.Int64() - if err != nil { - return nil, nil, err - } - maxOffsets := int32(1) - if header.APIVersion == 0 { - maxOffsets, err = reader.Int32() - if err != nil { - return nil, nil, err - } - } - parts = append(parts, ListOffsetsPartition{ - Partition: partition, - Timestamp: timestamp, - MaxNumOffsets: maxOffsets, - CurrentLeaderEpoch: leaderEpoch, - }) - } - topics = append(topics, ListOffsetsTopic{Name: name, Partitions: parts}) - } - req = &ListOffsetsRequest{ReplicaID: replicaID, IsolationLevel: isolationLevel, Topics: topics} - case APIKeyFetch: - version := header.APIVersion - replicaID, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read fetch replica id: %w", err) - } - maxWaitMs, err := reader.Int32() - if err != nil { - return nil, nil, err - } - minBytes, err := reader.Int32() - if err != nil { - return nil, nil, err - } - var maxBytes int32 - if version >= 3 { - maxBytes, err = reader.Int32() - if err != nil { - return nil, nil, err - } - } - isolationLevel := int8(0) - if version >= 4 { - if isolationLevel, err = reader.Int8(); err != nil { - return nil, nil, err - } - } - sessionID := int32(0) - sessionEpoch := int32(0) - if version >= 7 { - if sessionID, err = reader.Int32(); err != nil { - return nil, nil, err - } - if sessionEpoch, err = reader.Int32(); err != nil { - return nil, nil, err - } - } - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(reader) - } else { - topicCount, err = reader.Int32() - if topicCount < 0 { - return nil, nil, fmt.Errorf("fetch topic count invalid %d", topicCount) - } - } - if err != nil { - return nil, nil, err - } - - topics := make([]FetchTopicRequest, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var ( - name string - topicID [16]byte - ) - if version >= 12 { - topicID, err = reader.UUID() - if err != nil { - return nil, nil, err - } - } else { - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, err - } - } - var partCount int32 - if flexible { - partCount, err = compactArrayLenNonNull(reader) - } else { - partCount, err = reader.Int32() - if partCount < 0 { - return nil, nil, fmt.Errorf("fetch partition count invalid %d", partCount) - } - } - if err != nil { - return nil, nil, err - } - partitions := make([]FetchPartitionRequest, 0, partCount) - for j := int32(0); j < partCount; j++ { - partitionID, err := reader.Int32() - if err != nil { - return nil, nil, err - } - if version >= 9 { - if _, err := reader.Int32(); err != nil { // leader epoch - return nil, nil, err - } - } - fetchOffset, err := reader.Int64() - if err != nil { - return nil, nil, err - } - if version >= 12 { - if _, err := reader.Int32(); err != nil { // last fetched epoch - return nil, nil, err - } - } - if version >= 5 { - if _, err := reader.Int64(); err != nil { // log start offset - return nil, nil, err - } - } - maxBytes, err := reader.Int32() - if err != nil { - return nil, nil, err - } - partitions = append(partitions, FetchPartitionRequest{ - Partition: partitionID, - FetchOffset: fetchOffset, - MaxBytes: maxBytes, - }) - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip fetch partition tags: %w", err) - } - } - } - topics = append(topics, FetchTopicRequest{ - Name: name, - TopicID: topicID, - Partitions: partitions, - }) - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip fetch topic tags: %w", err) - } - } - } - if version >= 7 { - var forgottenCount int32 - if flexible { - forgottenCount, err = reader.CompactArrayLen() - } else { - forgottenCount, err = reader.Int32() - } - if err != nil { - return nil, nil, fmt.Errorf("read forgotten topics count: %w", err) - } - if forgottenCount > 0 { - for i := int32(0); i < forgottenCount; i++ { - if version >= 12 { - if _, err := reader.UUID(); err != nil { - return nil, nil, fmt.Errorf("read forgotten topic id: %w", err) - } - } else { - if _, err := reader.String(); err != nil { - return nil, nil, fmt.Errorf("read forgotten topic name: %w", err) - } - } - var partCount int32 - if flexible { - partCount, err = reader.CompactArrayLen() - } else { - partCount, err = reader.Int32() - } - if err != nil { - return nil, nil, fmt.Errorf("read forgotten partitions: %w", err) - } - for j := int32(0); j < partCount; j++ { - if _, err := reader.Int32(); err != nil { - return nil, nil, fmt.Errorf("read forgotten partition: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip forgotten topic tags: %w", err) - } - } - } - } - } - if version >= 11 { - if flexible { - if _, err := reader.CompactNullableString(); err != nil { - return nil, nil, fmt.Errorf("read rack id: %w", err) - } - } else { - if _, err := reader.NullableString(); err != nil { - return nil, nil, fmt.Errorf("read rack id: %w", err) - } - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip fetch request tags: %w", err) - } - } - req = &FetchRequest{ - ReplicaID: replicaID, - MaxWaitMs: maxWaitMs, - MinBytes: minBytes, - MaxBytes: maxBytes, - IsolationLevel: isolationLevel, - SessionID: sessionID, - SessionEpoch: sessionEpoch, - Topics: topics, - } - case APIKeyFindCoordinator: - var key string - if flexible { - key, err = reader.CompactString() - } else { - key, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read coordinator key: %w", err) - } - var keyType int8 - if header.APIVersion >= 1 { - if keyType, err = reader.Int8(); err != nil { - return nil, nil, fmt.Errorf("read coordinator key type: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip coordinator tags: %w", err) - } - } - req = &FindCoordinatorRequest{KeyType: keyType, Key: key} - case APIKeyJoinGroup: - groupID, err := reader.String() - if err != nil { - return nil, nil, err - } - sessionTimeout, err := reader.Int32() - if err != nil { - return nil, nil, err - } - rebalanceTimeout, err := reader.Int32() - if err != nil { - return nil, nil, err - } - memberID, err := reader.String() - if err != nil { - return nil, nil, err - } - protocolType, err := reader.String() - if err != nil { - return nil, nil, err - } - protocolCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - protocols := make([]JoinGroupProtocol, 0, protocolCount) - for i := int32(0); i < protocolCount; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - meta, err := reader.Bytes() - if err != nil { - return nil, nil, err - } - protocols = append(protocols, JoinGroupProtocol{Name: name, Metadata: meta}) - } - req = &JoinGroupRequest{ - GroupID: groupID, - SessionTimeoutMs: sessionTimeout, - RebalanceTimeoutMs: rebalanceTimeout, - MemberID: memberID, - ProtocolType: protocolType, - Protocols: protocols, - } - case APIKeySyncGroup: - var groupID string - var err error - if flexible { - groupID, err = reader.CompactString() - } else { - groupID, err = reader.String() - } - if err != nil { - return nil, nil, err - } - generationID, err := reader.Int32() - if err != nil { - return nil, nil, err - } - var memberID string - if flexible { - memberID, err = reader.CompactString() - } else { - memberID, err = reader.String() - } - if err != nil { - return nil, nil, err - } - if header.APIVersion >= 3 { - if flexible { - if _, err := reader.CompactNullableString(); err != nil { - return nil, nil, err - } - } else { - if _, err := reader.NullableString(); err != nil { - return nil, nil, err - } - } - } - if header.APIVersion >= 5 { - if flexible { - if _, err := reader.CompactNullableString(); err != nil { - return nil, nil, err - } - if _, err := reader.CompactNullableString(); err != nil { - return nil, nil, err - } - } else { - if _, err := reader.NullableString(); err != nil { - return nil, nil, err - } - if _, err := reader.NullableString(); err != nil { - return nil, nil, err - } - } - } - var assignCount int32 - if flexible { - if assignCount, err = compactArrayLenNonNull(reader); err != nil { - return nil, nil, err - } - } else { - assignCount, err = reader.Int32() - if err != nil { - return nil, nil, err - } - } - assignments := make([]SyncGroupAssignment, 0, assignCount) - for i := int32(0); i < assignCount; i++ { - var mid string - if flexible { - mid, err = reader.CompactString() - } else { - mid, err = reader.String() - } - if err != nil { - return nil, nil, err - } - var data []byte - if flexible { - data, err = reader.CompactBytes() - } else { - data, err = reader.Bytes() - } - if err != nil { - return nil, nil, err - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip sync assignment tags: %w", err) - } - } - assignments = append(assignments, SyncGroupAssignment{MemberID: mid, Assignment: data}) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip sync group tags: %w", err) - } - } - req = &SyncGroupRequest{ - GroupID: groupID, - GenerationID: generationID, - MemberID: memberID, - Assignments: assignments, - } - case APIKeyHeartbeat: - var err error - var groupID string - if flexible { - groupID, err = reader.CompactString() - } else { - groupID, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read heartbeat group id: %w", err) - } - generationID, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read heartbeat generation: %w", err) - } - var memberID string - if flexible { - memberID, err = reader.CompactString() - } else { - memberID, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read heartbeat member id: %w", err) - } - var instanceID *string - if header.APIVersion >= 3 { - if flexible { - instanceID, err = reader.CompactNullableString() - } else { - instanceID, err = reader.NullableString() - } - if err != nil { - return nil, nil, fmt.Errorf("read heartbeat group instance id: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip heartbeat tags: %w", err) - } - } - req = &HeartbeatRequest{ - GroupID: groupID, - GenerationID: generationID, - MemberID: memberID, - InstanceID: instanceID, - } - case APIKeyLeaveGroup: - groupID, err := reader.String() - if err != nil { - return nil, nil, err - } - memberID, err := reader.String() - if err != nil { - return nil, nil, err - } - req = &LeaveGroupRequest{ - GroupID: groupID, - MemberID: memberID, - } - case APIKeyOffsetCommit: - version := header.APIVersion - if version != 3 { - return nil, nil, fmt.Errorf("offset commit version %d not supported", version) - } - groupID, err := reader.String() - if err != nil { - return nil, nil, err - } - generationID, err := reader.Int32() - if err != nil { - return nil, nil, err - } - memberID, err := reader.String() - if err != nil { - return nil, nil, err - } - var retentionMs int64 - if version >= 2 && version <= 4 { - retentionMs, err = reader.Int64() - if err != nil { - return nil, nil, err - } - } - topicCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - topics := make([]OffsetCommitTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - partCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - partitions := make([]OffsetCommitPartition, 0, partCount) - for j := int32(0); j < partCount; j++ { - partition, err := reader.Int32() - if err != nil { - return nil, nil, err - } - offset, err := reader.Int64() - if err != nil { - return nil, nil, err - } - metaPtr, err := reader.NullableString() - if err != nil { - return nil, nil, err - } - meta := "" - if metaPtr != nil { - meta = *metaPtr - } - partitions = append(partitions, OffsetCommitPartition{ - Partition: partition, - Offset: offset, - Metadata: meta, - }) - } - topics = append(topics, OffsetCommitTopic{Name: name, Partitions: partitions}) - } - req = &OffsetCommitRequest{ - GroupID: groupID, - GenerationID: generationID, - MemberID: memberID, - RetentionMs: retentionMs, - Topics: topics, - } - case APIKeyOffsetFetch: - groupID, err := reader.String() - if err != nil { - return nil, nil, err - } - topicCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - topics := make([]OffsetFetchTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - name, err := reader.String() - if err != nil { - return nil, nil, err - } - partCount, err := reader.Int32() - if err != nil { - return nil, nil, err - } - partitions := make([]OffsetFetchPartition, 0, partCount) - for j := int32(0); j < partCount; j++ { - partition, err := reader.Int32() - if err != nil { - return nil, nil, err - } - partitions = append(partitions, OffsetFetchPartition{Partition: partition}) - } - topics = append(topics, OffsetFetchTopic{Name: name, Partitions: partitions}) - } - req = &OffsetFetchRequest{ - GroupID: groupID, - Topics: topics, - } - case APIKeyOffsetForLeaderEpoch: - replicaID := int32(-2) - if header.APIVersion >= 3 { - if replicaID, err = reader.Int32(); err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch replica id: %w", err) - } - } - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(reader) - } else { - topicCount, err = reader.Int32() - if topicCount < 0 { - return nil, nil, fmt.Errorf("offset for leader epoch topic count invalid %d", topicCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch topic count: %w", err) - } - topics := make([]OffsetForLeaderEpochTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch topic[%d]: %w", i, err) - } - var partCount int32 - if flexible { - partCount, err = compactArrayLenNonNull(reader) - } else { - partCount, err = reader.Int32() - if partCount < 0 { - return nil, nil, fmt.Errorf("offset for leader epoch partition count invalid %d", partCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch partition count: %w", err) - } - partitions := make([]OffsetForLeaderEpochPartition, 0, partCount) - for j := int32(0); j < partCount; j++ { - partitionID, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch partition: %w", err) - } - currentLeaderEpoch := int32(-1) - if header.APIVersion >= 2 { - if currentLeaderEpoch, err = reader.Int32(); err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch current leader epoch: %w", err) - } - } - leaderEpoch, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read offset for leader epoch leader epoch: %w", err) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip offset for leader epoch partition tags: %w", err) - } - } - partitions = append(partitions, OffsetForLeaderEpochPartition{ - Partition: partitionID, - CurrentLeaderEpoch: currentLeaderEpoch, - LeaderEpoch: leaderEpoch, - }) - } - topics = append(topics, OffsetForLeaderEpochTopic{Name: name, Partitions: partitions}) - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip offset for leader epoch topic tags: %w", err) - } - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip offset for leader epoch tags: %w", err) - } - } - req = &OffsetForLeaderEpochRequest{ - ReplicaID: replicaID, - Topics: topics, - } - case APIKeyDescribeConfigs: - var resourceCount int32 - if flexible { - resourceCount, err = compactArrayLenNonNull(reader) - } else { - resourceCount, err = reader.Int32() - if resourceCount < 0 { - return nil, nil, fmt.Errorf("describe configs resource count invalid %d", resourceCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read describe configs resource count: %w", err) - } - resources := make([]DescribeConfigsResource, 0, resourceCount) - for i := int32(0); i < resourceCount; i++ { - resourceType, err := reader.Int8() - if err != nil { - return nil, nil, fmt.Errorf("read describe configs resource type: %w", err) - } - var resourceName string - if flexible { - resourceName, err = reader.CompactString() - } else { - resourceName, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read describe configs resource name: %w", err) - } - var configCount int32 - if flexible { - configCount, err = reader.CompactArrayLen() - } else { - configCount, err = reader.Int32() - } - if err != nil { - return nil, nil, fmt.Errorf("read describe configs config count: %w", err) - } - var configNames []string - if configCount >= 0 { - configNames = make([]string, 0, configCount) - for j := int32(0); j < configCount; j++ { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read describe configs config name: %w", err) - } - configNames = append(configNames, name) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip describe configs resource tags: %w", err) - } - } - resources = append(resources, DescribeConfigsResource{ - ResourceType: resourceType, - ResourceName: resourceName, - ConfigNames: configNames, - }) - } - includeSynonyms := false - if header.APIVersion >= 1 { - if includeSynonyms, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read describe configs include synonyms: %w", err) - } - } - includeDocumentation := false - if header.APIVersion >= 3 { - if includeDocumentation, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read describe configs include docs: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip describe configs tags: %w", err) - } - } - req = &DescribeConfigsRequest{ - Resources: resources, - IncludeSynonyms: includeSynonyms, - IncludeDocumentation: includeDocumentation, - } - case APIKeyAlterConfigs: - var resourceCount int32 - if flexible { - resourceCount, err = compactArrayLenNonNull(reader) - } else { - resourceCount, err = reader.Int32() - if resourceCount < 0 { - return nil, nil, fmt.Errorf("alter configs resource count invalid %d", resourceCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read alter configs resource count: %w", err) - } - resources := make([]AlterConfigsResource, 0, resourceCount) - for i := int32(0); i < resourceCount; i++ { - resourceType, err := reader.Int8() - if err != nil { - return nil, nil, fmt.Errorf("read alter configs resource type: %w", err) - } - var resourceName string - if flexible { - resourceName, err = reader.CompactString() - } else { - resourceName, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read alter configs resource name: %w", err) - } - var configCount int32 - if flexible { - configCount, err = compactArrayLenNonNull(reader) - } else { - configCount, err = reader.Int32() - if configCount < 0 { - return nil, nil, fmt.Errorf("alter configs config count invalid %d", configCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read alter configs config count: %w", err) - } - configs := make([]AlterConfigsResourceConfig, 0, configCount) - for j := int32(0); j < configCount; j++ { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read alter configs config name: %w", err) - } - var value *string - if flexible { - value, err = reader.CompactNullableString() - } else { - value, err = reader.NullableString() - } - if err != nil { - return nil, nil, fmt.Errorf("read alter configs config value: %w", err) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip alter configs config tags: %w", err) - } - } - configs = append(configs, AlterConfigsResourceConfig{Name: name, Value: value}) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip alter configs resource tags: %w", err) - } - } - resources = append(resources, AlterConfigsResource{ - ResourceType: resourceType, - ResourceName: resourceName, - Configs: configs, - }) - } - validateOnly, err := reader.Bool() - if err != nil { - return nil, nil, fmt.Errorf("read alter configs validate only: %w", err) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip alter configs tags: %w", err) - } - } - req = &AlterConfigsRequest{ - Resources: resources, - ValidateOnly: validateOnly, - } - case APIKeyCreatePartitions: - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(reader) - } else { - topicCount, err = reader.Int32() - if topicCount < 0 { - return nil, nil, fmt.Errorf("create partitions topic count invalid %d", topicCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic count: %w", err) - } - topics := make([]CreatePartitionsTopic, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var name string - if flexible { - name, err = reader.CompactString() - } else { - name, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic[%d] name: %w", i, err) - } - count, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic[%d] count: %w", i, err) - } - var assignmentCount int32 - if flexible { - assignmentCount, err = reader.CompactArrayLen() - } else { - assignmentCount, err = reader.Int32() - } - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic[%d] assignment count: %w", i, err) - } - var assignments []CreatePartitionsAssignment - if assignmentCount >= 0 { - assignments = make([]CreatePartitionsAssignment, 0, assignmentCount) - for j := int32(0); j < assignmentCount; j++ { - var replicaCount int32 - if flexible { - replicaCount, err = compactArrayLenNonNull(reader) - } else { - replicaCount, err = reader.Int32() - if replicaCount < 0 { - return nil, nil, fmt.Errorf("create partitions topic[%d] assignment[%d] replica count invalid %d", i, j, replicaCount) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic[%d] assignment[%d] replica count: %w", i, j, err) - } - replicas := make([]int32, 0, replicaCount) - for k := int32(0); k < replicaCount; k++ { - replica, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read create partitions topic[%d] assignment[%d] replica[%d]: %w", i, j, k, err) - } - replicas = append(replicas, replica) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip create partitions assignment tags: %w", err) - } - } - assignments = append(assignments, CreatePartitionsAssignment{Replicas: replicas}) - } - } else if assignmentCount < -1 { - return nil, nil, fmt.Errorf("create partitions topic[%d] assignment count invalid %d", i, assignmentCount) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip create partitions topic tags: %w", err) - } - } - topics = append(topics, CreatePartitionsTopic{ - Name: name, - Count: count, - Assignments: assignments, - }) - } - timeoutMs, err := reader.Int32() - if err != nil { - return nil, nil, fmt.Errorf("read create partitions timeout: %w", err) - } - validateOnly, err := reader.Bool() - if err != nil { - return nil, nil, fmt.Errorf("read create partitions validate only: %w", err) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip create partitions tags: %w", err) - } - } - req = &CreatePartitionsRequest{ - Topics: topics, - TimeoutMs: timeoutMs, - ValidateOnly: validateOnly, - } - case APIKeyDeleteGroups: - var count int32 - if flexible { - count, err = compactArrayLenNonNull(reader) - } else { - count, err = reader.Int32() - if count < 0 { - return nil, nil, fmt.Errorf("delete groups count invalid %d", count) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read delete groups count: %w", err) - } - groups := make([]string, 0, count) - for i := int32(0); i < count; i++ { - var group string - if flexible { - group, err = reader.CompactString() - } else { - group, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read delete groups[%d]: %w", i, err) - } - groups = append(groups, group) - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip delete groups tags: %w", err) - } - } - req = &DeleteGroupsRequest{Groups: groups} - case APIKeyDescribeGroups: - var count int32 - if flexible { - count, err = compactArrayLenNonNull(reader) - } else { - count, err = reader.Int32() - if count < 0 { - return nil, nil, fmt.Errorf("describe groups count invalid %d", count) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read describe groups count: %w", err) - } - groups := make([]string, 0, count) - for i := int32(0); i < count; i++ { - var group string - if flexible { - group, err = reader.CompactString() - } else { - group, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read describe group[%d]: %w", i, err) - } - groups = append(groups, group) - } - includeAuthorizedOperations := false - if header.APIVersion >= 3 { - if includeAuthorizedOperations, err = reader.Bool(); err != nil { - return nil, nil, fmt.Errorf("read describe groups include auth ops: %w", err) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip describe groups tags: %w", err) - } - } - req = &DescribeGroupsRequest{ - Groups: groups, - IncludeAuthorizedOperations: includeAuthorizedOperations, - } - case APIKeyListGroups: - var ( - states []string - types []string - ) - if header.APIVersion >= 4 { - var count int32 - if flexible { - count, err = reader.CompactArrayLen() - } else { - count, err = reader.Int32() - if count < 0 { - return nil, nil, fmt.Errorf("list groups states count invalid %d", count) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read list groups states count: %w", err) - } - if count < 0 { - count = 0 - } - states = make([]string, 0, count) - for i := int32(0); i < count; i++ { - var state string - if flexible { - state, err = reader.CompactString() - } else { - state, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read list groups state[%d]: %w", i, err) - } - states = append(states, state) - } - } - if header.APIVersion >= 5 { - var count int32 - if flexible { - count, err = reader.CompactArrayLen() - } else { - count, err = reader.Int32() - if count < 0 { - return nil, nil, fmt.Errorf("list groups types count invalid %d", count) - } - } - if err != nil { - return nil, nil, fmt.Errorf("read list groups types count: %w", err) - } - if count < 0 { - count = 0 - } - types = make([]string, 0, count) - for i := int32(0); i < count; i++ { - var groupType string - if flexible { - groupType, err = reader.CompactString() - } else { - groupType, err = reader.String() - } - if err != nil { - return nil, nil, fmt.Errorf("read list groups type[%d]: %w", i, err) - } - types = append(types, groupType) - } - } - if flexible { - if err := reader.SkipTaggedFields(); err != nil { - return nil, nil, fmt.Errorf("skip list groups tags: %w", err) - } - } - req = &ListGroupsRequest{StatesFilter: states, TypesFilter: types} - default: - return nil, nil, fmt.Errorf("unsupported api key %d", header.APIKey) - } - - return header, req, nil -} - -// EncodeFetchRequest encodes a fetch request. Mirrors ParseRequest's fetch case. -func EncodeFetchRequest(header *RequestHeader, req *FetchRequest, version int16) ([]byte, error) { - w := newByteWriter(256) - flexible := isFlexibleRequest(APIKeyFetch, version) - - w.Int16(header.APIKey) - w.Int16(header.APIVersion) - w.Int32(header.CorrelationID) - w.NullableString(header.ClientID) - if flexible { - w.WriteTaggedFields(0) - } - - w.Int32(req.ReplicaID) - w.Int32(req.MaxWaitMs) - w.Int32(req.MinBytes) - if version >= 3 { - w.Int32(req.MaxBytes) - } - if version >= 4 { - w.Int8(req.IsolationLevel) - } - if version >= 7 { - w.Int32(req.SessionID) - w.Int32(req.SessionEpoch) - } - - if flexible { - w.CompactArrayLen(len(req.Topics)) - } else { - w.Int32(int32(len(req.Topics))) - } - for _, topic := range req.Topics { - if version >= 12 { - w.UUID(topic.TopicID) - } else { - if flexible { - w.CompactString(topic.Name) - } else { - w.String(topic.Name) - } - } - if flexible { - w.CompactArrayLen(len(topic.Partitions)) - } else { - w.Int32(int32(len(topic.Partitions))) - } - for _, part := range topic.Partitions { - w.Int32(part.Partition) - if version >= 9 { - w.Int32(-1) // leader epoch (unknown) - } - w.Int64(part.FetchOffset) - if version >= 12 { - w.Int32(-1) // last fetched epoch (unknown) - } - if version >= 5 { - w.Int64(-1) // log start offset - } - w.Int32(part.MaxBytes) - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - - // Forgotten topics (empty) - if version >= 7 { - if flexible { - w.CompactArrayLen(0) - } else { - w.Int32(0) - } - } - - // Rack ID (v11+) - if version >= 11 { - if flexible { - w.CompactString("") - } else { - w.String("") - } - } - - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil + return ParseRequestBody(header, body) } -// EncodeProduceRequest serializes a RequestHeader + ProduceRequest into wire-format -// bytes suitable for WriteFrame. The encoding mirrors what ParseRequest expects. -func EncodeProduceRequest(header *RequestHeader, req *ProduceRequest, version int16) ([]byte, error) { - w := newByteWriter(256) - flexible := isFlexibleRequest(APIKeyProduce, version) - - w.Int16(header.APIKey) - w.Int16(header.APIVersion) - w.Int32(header.CorrelationID) - w.NullableString(header.ClientID) - if flexible { - w.WriteTaggedFields(0) +// ParseRequestBody decodes the body bytes into a concrete kmsg request using +// the API key and version from the already-parsed header. Use this when the +// header has already been parsed via ParseRequestHeader to avoid re-parsing. +func ParseRequestBody(header *RequestHeader, body []byte) (*RequestHeader, kmsg.Request, error) { + req := kmsg.RequestForKey(header.APIKey) + if req == nil { + return nil, nil, fmt.Errorf("unsupported api key %d", header.APIKey) } + req.SetVersion(header.APIVersion) - // Body: transactional_id (v3+), acks, timeout, topics. - if version >= 3 { - if flexible { - w.CompactNullableString(req.TransactionalID) - } else { - w.NullableString(req.TransactionalID) - } + if err := req.ReadFrom(body); err != nil { + return nil, nil, fmt.Errorf("decode %s v%d: %w", + kmsg.NameForKey(header.APIKey), header.APIVersion, err) } - w.Int16(req.Acks) - w.Int32(req.TimeoutMs) - if flexible { - w.CompactArrayLen(len(req.Topics)) - } else { - w.Int32(int32(len(req.Topics))) - } - for _, topic := range req.Topics { - if flexible { - w.CompactString(topic.Name) - } else { - w.String(topic.Name) - } - if flexible { - w.CompactArrayLen(len(topic.Partitions)) - } else { - w.Int32(int32(len(topic.Partitions))) - } - for _, part := range topic.Partitions { - w.Int32(part.Partition) - if flexible { - w.CompactBytes(part.Records) - } else { - w.BytesWithLength(part.Records) - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil + return header, req, nil } diff --git a/pkg/protocol/request_test.go b/pkg/protocol/request_test.go index d18afab2..92b17f41 100644 --- a/pkg/protocol/request_test.go +++ b/pkg/protocol/request_test.go @@ -21,488 +21,30 @@ import ( "github.com/twmb/franz-go/pkg/kmsg" ) -func TestParseApiVersionsRequest(t *testing.T) { - w := newByteWriter(16) - w.Int16(APIKeyApiVersion) - w.Int16(0) - w.Int32(42) - w.NullableString(nil) - - header, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - if header.APIKey != APIKeyApiVersion || header.CorrelationID != 42 { - t.Fatalf("unexpected header: %#v", header) - } - if _, ok := req.(*ApiVersionsRequest); !ok { - t.Fatalf("expected ApiVersionsRequest got %T", req) - } -} - -func TestParseApiVersionsRequestV3(t *testing.T) { - w := newByteWriter(32) - w.Int16(APIKeyApiVersion) - w.Int16(3) - w.Int32(7) - w.NullableString(nil) - w.WriteTaggedFields(0) - w.CompactString("kgo") - w.CompactString("1.0.0") - w.WriteTaggedFields(0) - - header, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - apiReq, ok := req.(*ApiVersionsRequest) - if !ok { - t.Fatalf("expected ApiVersionsRequest got %T", req) - } - if header.APIVersion != 3 { - t.Fatalf("unexpected api versions request version %d", header.APIVersion) - } - if apiReq.ClientSoftwareName != "kgo" || apiReq.ClientSoftwareVersion != "1.0.0" { - t.Fatalf("unexpected client info: %#v", apiReq) - } -} - -func TestParseMetadataRequest(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyMetadata) - w.Int16(0) - w.Int32(7) - clientID := "client-1" - w.NullableString(&clientID) - w.Int32(2) - w.String("orders") - w.String("payments") - - header, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - metaReq, ok := req.(*MetadataRequest) - if !ok { - t.Fatalf("expected MetadataRequest got %T", req) - } - if len(metaReq.Topics) != 2 || metaReq.Topics[0] != "orders" { - t.Fatalf("unexpected topics: %#v", metaReq.Topics) - } - if header.ClientID == nil || *header.ClientID != "client-1" { - t.Fatalf("client id mismatch: %#v", header.ClientID) - } -} - -func TestParseListOffsetsRequestV0(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyListOffsets) - w.Int16(0) - w.Int32(23) - w.NullableString(nil) - w.Int32(-1) - w.Int32(1) - w.String("orders") - w.Int32(1) - w.Int32(0) - w.Int64(-1) - w.Int32(1) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*ListOffsetsRequest) - if !ok { - t.Fatalf("expected ListOffsetsRequest got %T", req) - } - if parsed.ReplicaID != -1 || len(parsed.Topics) != 1 { - t.Fatalf("unexpected list offsets request: %#v", parsed) - } - part := parsed.Topics[0].Partitions[0] - if part.Partition != 0 || part.Timestamp != -1 || part.MaxNumOffsets != 1 { - t.Fatalf("unexpected list offsets partition: %#v", part) - } -} - -func TestParseListOffsetsRequestV2(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyListOffsets) - w.Int16(2) - w.Int32(24) - w.NullableString(nil) - w.Int32(-1) - w.Int8(1) - w.Int32(1) - w.String("orders") - w.Int32(1) - w.Int32(0) - w.Int64(-2) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*ListOffsetsRequest) - if !ok { - t.Fatalf("expected ListOffsetsRequest got %T", req) - } - if parsed.ReplicaID != -1 || parsed.IsolationLevel != 1 { - t.Fatalf("unexpected list offsets request: %#v", parsed) - } - part := parsed.Topics[0].Partitions[0] - if part.Partition != 0 || part.Timestamp != -2 || part.MaxNumOffsets != 1 || part.CurrentLeaderEpoch != -1 { - t.Fatalf("unexpected list offsets partition: %#v", part) - } -} - -func TestParseListOffsetsRequestV4(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyListOffsets) - w.Int16(4) - w.Int32(25) - w.NullableString(nil) - w.Int32(-1) - w.Int8(0) - w.Int32(1) - w.String("orders") - w.Int32(1) - w.Int32(0) - w.Int32(3) - w.Int64(-1) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*ListOffsetsRequest) - if !ok { - t.Fatalf("expected ListOffsetsRequest got %T", req) - } - part := parsed.Topics[0].Partitions[0] - if part.CurrentLeaderEpoch != 3 || part.Timestamp != -1 { - t.Fatalf("unexpected list offsets partition: %#v", part) - } -} - -func TestParseCreateTopicsRequestV1(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyCreateTopics) - w.Int16(1) - w.Int32(11) - w.NullableString(nil) - w.Int32(1) - w.String("orders") - w.Int32(3) - w.Int16(1) - w.Int32(0) - w.Int32(15000) - w.Bool(true) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*CreateTopicsRequest) - if !ok { - t.Fatalf("expected CreateTopicsRequest got %T", req) - } - if parsed.TimeoutMs != 15000 || !parsed.ValidateOnly { - t.Fatalf("unexpected create topics request: %#v", parsed) - } - if len(parsed.Topics) != 1 || parsed.Topics[0].Name != "orders" { - t.Fatalf("unexpected create topics: %#v", parsed.Topics) - } -} - -func TestParseDeleteTopicsRequestV1(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyDeleteTopics) - w.Int16(1) - w.Int32(12) - w.NullableString(nil) - w.Int32(2) - w.String("orders") - w.String("payments") - w.Int32(12000) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*DeleteTopicsRequest) - if !ok { - t.Fatalf("expected DeleteTopicsRequest got %T", req) - } - if parsed.TimeoutMs != 12000 || len(parsed.TopicNames) != 2 { - t.Fatalf("unexpected delete topics request: %#v", parsed) - } -} - -func TestParseProduceRequest(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyProduce) - w.Int16(9) - w.Int32(100) - clientID := "producer-1" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.CompactNullableString(nil) - w.Int16(1) // acks - w.Int32(1500) - w.CompactArrayLen(1) // topic count - w.CompactString("orders") - w.CompactArrayLen(1) // partitions - w.Int32(0) // partition id - batch := []byte("record") // placeholder bytes - w.CompactBytes(batch) - // partition tagged fields (count=1, tag=0, len=1, val=0x7f) - w.UVarint(1) - w.UVarint(0) - w.UVarint(1) - w.write([]byte{0x7f}) - w.WriteTaggedFields(0) // topic tags - w.WriteTaggedFields(0) // request tags - // fmt.Printf(\"% x\\n\", w.Bytes()) - - header, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - if header.APIKey != APIKeyProduce { - t.Fatalf("unexpected api key %d", header.APIKey) - } - produceReq, ok := req.(*ProduceRequest) - if !ok { - t.Fatalf("expected ProduceRequest got %T", req) - } - if produceReq.Acks != 1 || len(produceReq.Topics) != 1 { - t.Fatalf("produce data mismatch: %#v", produceReq) - } - if string(produceReq.Topics[0].Partitions[0].Records) != "record" { - t.Fatalf("records mismatch") - } -} - -func TestParseProduceRequestInvalidCompactArray(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyProduce) - w.Int16(9) - w.Int32(1) - w.NullableString(nil) - w.WriteTaggedFields(0) - w.CompactNullableString(nil) - w.Int16(1) - w.Int32(100) - w.UVarint(0) // compact array len => null - - if _, _, err := ParseRequest(w.Bytes()); err == nil { - t.Fatalf("expected error for null topic array") - } -} - -func TestParseDescribeGroupsRequestV5(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyDescribeGroups) - w.Int16(5) - w.Int32(11) - w.NullableString(nil) - w.WriteTaggedFields(0) // header tags - w.CompactArrayLen(2) - w.CompactString("group-1") - w.CompactString("group-2") - w.Bool(true) - w.WriteTaggedFields(0) // request tags - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*DescribeGroupsRequest) - if !ok { - t.Fatalf("expected DescribeGroupsRequest got %T", req) - } - if len(parsed.Groups) != 2 || parsed.Groups[0] != "group-1" { - t.Fatalf("unexpected groups: %#v", parsed.Groups) - } - if !parsed.IncludeAuthorizedOperations { - t.Fatalf("expected IncludeAuthorizedOperations true") - } -} - -func TestParseListGroupsRequestV5(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyListGroups) - w.Int16(5) - w.Int32(12) - w.NullableString(nil) - w.WriteTaggedFields(0) // header tags - w.CompactArrayLen(1) - w.CompactString("Stable") - w.CompactArrayLen(1) - w.CompactString("classic") - w.WriteTaggedFields(0) // request tags - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*ListGroupsRequest) - if !ok { - t.Fatalf("expected ListGroupsRequest got %T", req) - } - if len(parsed.StatesFilter) != 1 || parsed.StatesFilter[0] != "Stable" { - t.Fatalf("unexpected states filter: %#v", parsed.StatesFilter) - } - if len(parsed.TypesFilter) != 1 || parsed.TypesFilter[0] != "classic" { - t.Fatalf("unexpected types filter: %#v", parsed.TypesFilter) - } -} - -func TestParseOffsetForLeaderEpochRequestV3(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyOffsetForLeaderEpoch) - w.Int16(3) - w.Int32(21) - w.NullableString(nil) - w.Int32(-1) // replica id - w.Int32(1) // topic count - w.String("logs") - w.Int32(1) // partition count - w.Int32(0) // partition - w.Int32(1) // current leader epoch - w.Int32(1) // leader epoch - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*OffsetForLeaderEpochRequest) - if !ok { - t.Fatalf("expected OffsetForLeaderEpochRequest got %T", req) - } - if parsed.ReplicaID != -1 || len(parsed.Topics) != 1 || parsed.Topics[0].Name != "logs" { - t.Fatalf("unexpected offset for leader epoch request: %#v", parsed) - } -} - -func TestParseDescribeConfigsRequestV4(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyDescribeConfigs) - w.Int16(4) - w.Int32(31) - w.NullableString(nil) - w.WriteTaggedFields(0) - w.CompactArrayLen(1) - w.Int8(ConfigResourceTopic) - w.CompactString("orders") - w.CompactArrayLen(2) - w.CompactString("retention.ms") - w.CompactString("segment.bytes") - w.WriteTaggedFields(0) // resource tags - w.Bool(false) // include synonyms - w.Bool(false) // include docs - w.WriteTaggedFields(0) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*DescribeConfigsRequest) - if !ok { - t.Fatalf("expected DescribeConfigsRequest got %T", req) - } - if len(parsed.Resources) != 1 || parsed.Resources[0].ResourceName != "orders" { - t.Fatalf("unexpected describe configs request: %#v", parsed) - } -} -func TestParseAlterConfigsRequestV1(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyAlterConfigs) - w.Int16(1) - w.Int32(41) - w.NullableString(nil) - w.Int32(1) // resource count - w.Int8(ConfigResourceTopic) - w.String("orders") - w.Int32(1) - w.String("retention.ms") - value := "1000" - w.NullableString(&value) - w.Bool(false) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*AlterConfigsRequest) - if !ok { - t.Fatalf("expected AlterConfigsRequest got %T", req) - } - if len(parsed.Resources) != 1 || parsed.Resources[0].ResourceName != "orders" { - t.Fatalf("unexpected alter configs request: %#v", parsed) - } -} - -func TestParseCreatePartitionsRequestV3(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyCreatePartitions) - w.Int16(3) - w.Int32(55) - w.NullableString(nil) - w.WriteTaggedFields(0) - w.CompactArrayLen(1) - w.CompactString("orders") - w.Int32(6) - w.CompactArrayLen(-1) // assignments null - w.WriteTaggedFields(0) - w.Int32(15000) - w.Bool(false) - w.WriteTaggedFields(0) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*CreatePartitionsRequest) - if !ok { - t.Fatalf("expected CreatePartitionsRequest got %T", req) - } - if len(parsed.Topics) != 1 || parsed.Topics[0].Name != "orders" || parsed.Topics[0].Count != 6 { - t.Fatalf("unexpected create partitions request: %#v", parsed) - } - if parsed.ValidateOnly { - t.Fatalf("expected ValidateOnly false") - } -} - -func TestParseDeleteGroupsRequestV2(t *testing.T) { - w := newByteWriter(64) - w.Int16(APIKeyDeleteGroups) - w.Int16(2) - w.Int32(57) - w.NullableString(nil) - w.WriteTaggedFields(0) - w.CompactArrayLen(2) - w.CompactString("group-1") - w.CompactString("group-2") - w.WriteTaggedFields(0) - - _, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - parsed, ok := req.(*DeleteGroupsRequest) - if !ok { - t.Fatalf("expected DeleteGroupsRequest got %T", req) - } - if len(parsed.Groups) != 2 || parsed.Groups[1] != "group-2" { - t.Fatalf("unexpected delete groups request: %#v", parsed) +// buildRequestFrame prepends a Kafka request header to a kmsg-encoded body. +// This mirrors what a real Kafka client does: header first, then the body +// serialized by kmsg. +func buildRequestFrame(apiKey, version int16, correlationID int32, clientID *string, body []byte) []byte { + w := newByteWriter(len(body) + 32) + w.Int16(apiKey) + w.Int16(version) + w.Int32(correlationID) + w.NullableString(clientID) + + // Flexible versions (KIP-482) include tagged fields in the header. + req := kmsg.RequestForKey(apiKey) + if req != nil { + req.SetVersion(version) + if req.IsFlexible() { + w.WriteTaggedFields(0) + } } + w.write(body) + return w.Bytes() } -func TestParseProduceRequestFranzEncoding(t *testing.T) { +func TestParseRequest_Produce(t *testing.T) { req := kmsg.NewPtrProduceRequest() req.Version = 9 req.Acks = 1 @@ -514,209 +56,117 @@ func TestParseProduceRequestFranzEncoding(t *testing.T) { part.Records = []byte("record batch payload") topic.Partitions = append(topic.Partitions, part) req.Topics = append(req.Topics, topic) - body := req.AppendTo(nil) - w := newByteWriter(len(body) + 16) - w.Int16(APIKeyProduce) - w.Int16(9) - w.Int32(42) - clientID := "kgo" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.write(body) - - header, parsed, err := ParseRequest(w.Bytes()) + frame := buildRequestFrame(APIKeyProduce, 9, 42, kmsg.StringPtr("kgo"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } - if header.APIKey != APIKeyProduce { - t.Fatalf("unexpected api key %d", header.APIKey) + if header.APIKey != APIKeyProduce || header.CorrelationID != 42 { + t.Fatalf("unexpected header: %+v", header) } - produceReq, ok := parsed.(*ProduceRequest) + produceReq, ok := parsed.(*kmsg.ProduceRequest) if !ok { - t.Fatalf("expected ProduceRequest got %T", parsed) + t.Fatalf("expected *kmsg.ProduceRequest got %T", parsed) } - if len(produceReq.Topics) != 1 || len(produceReq.Topics[0].Partitions) != 1 { - t.Fatalf("unexpected partitions: %#v", produceReq.Topics) - } - if produceReq.Topics[0].Partitions[0].Partition != 0 { - t.Fatalf("expected partition 0 got %d", produceReq.Topics[0].Partitions[0].Partition) + if produceReq.Acks != 1 || len(produceReq.Topics) != 1 { + t.Fatalf("produce data mismatch: acks=%d topics=%d", produceReq.Acks, len(produceReq.Topics)) } if string(produceReq.Topics[0].Partitions[0].Records) != "record batch payload" { - t.Fatalf("records mismatch: %q", produceReq.Topics[0].Partitions[0].Records) + t.Fatalf("records mismatch") } } -func TestParseFetchRequestV13(t *testing.T) { - var topicID [16]byte - for i := range topicID { - topicID[i] = byte(i + 1) - } - w := newByteWriter(256) - w.Int16(APIKeyFetch) - w.Int16(13) - w.Int32(9) - clientID := "client" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.Int32(0) // replica id - w.Int32(500) // max wait ms - w.Int32(1) // min bytes - w.Int32(1048576) // max bytes - w.Int8(0) // isolation level - w.Int32(0) // session id - w.Int32(0) // session epoch - w.CompactArrayLen(1) - w.UUID(topicID) - w.CompactArrayLen(1) - w.Int32(0) // partition - w.Int32(-1) // current leader epoch - w.Int64(0) // fetch offset - w.Int32(-1) // last fetched epoch - w.Int64(0) // log start offset - w.Int32(1048576) - w.WriteTaggedFields(0) // partition tags - w.WriteTaggedFields(0) // topic tags - w.CompactArrayLen(0) // forgotten topics - w.CompactNullableString(nil) - w.WriteTaggedFields(0) // request tags - - header, req, err := ParseRequest(w.Bytes()) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - if header.APIKey != APIKeyFetch || header.APIVersion != 13 { - t.Fatalf("unexpected header: %#v", header) - } - fetchReq, ok := req.(*FetchRequest) - if !ok { - t.Fatalf("expected FetchRequest got %T", req) - } - if len(fetchReq.Topics) != 1 { - t.Fatalf("expected 1 topic got %d", len(fetchReq.Topics)) - } - if fetchReq.Topics[0].TopicID != topicID { - t.Fatalf("unexpected topic id %v", fetchReq.Topics[0].TopicID) - } - if fetchReq.Topics[0].Name != "" { - t.Fatalf("expected empty topic name got %q", fetchReq.Topics[0].Name) - } - if len(fetchReq.Topics[0].Partitions) != 1 { - t.Fatalf("expected 1 partition got %d", len(fetchReq.Topics[0].Partitions)) +func TestParseRequest_Metadata(t *testing.T) { + req := kmsg.NewPtrMetadataRequest() + req.Version = 12 + req.AllowAutoTopicCreation = true + req.Topics = []kmsg.MetadataRequestTopic{ + {Topic: kmsg.StringPtr("orders-3eb53935-0")}, } -} - -func TestParseMetadataRequestV12TaggedFields(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyMetadata) - w.Int16(12) - w.Int32(42) - clientID := "kgo" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.CompactArrayLen(2) - w.UUID([16]byte{}) - w.CompactNullableString(strPtr("orders-0")) - w.WriteTaggedFields(0) - w.UUID([16]byte{}) - w.CompactNullableString(strPtr("orders-1")) - w.WriteTaggedFields(0) - w.Bool(true) - w.Bool(false) - w.WriteTaggedFields(0) - header, req, err := ParseRequest(w.Bytes()) + frame := buildRequestFrame(APIKeyMetadata, 12, 1, kmsg.StringPtr("kgo"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } if header.APIKey != APIKeyMetadata || header.APIVersion != 12 { - t.Fatalf("unexpected header: %#v", header) + t.Fatalf("unexpected header: %+v", header) } - metaReq, ok := req.(*MetadataRequest) + metaReq, ok := parsed.(*kmsg.MetadataRequest) if !ok { - t.Fatalf("expected MetadataRequest got %T", req) + t.Fatalf("expected *kmsg.MetadataRequest got %T", parsed) } - if len(metaReq.Topics) != 2 { - t.Fatalf("expected 2 topics got %d", len(metaReq.Topics)) + if len(metaReq.Topics) != 1 || metaReq.Topics[0].Topic == nil || *metaReq.Topics[0].Topic != "orders-3eb53935-0" { + t.Fatalf("unexpected topics: %+v", metaReq.Topics) } if !metaReq.AllowAutoTopicCreation { - t.Fatalf("expected allow auto topic creation true") - } - if metaReq.IncludeClusterAuthOps || metaReq.IncludeTopicAuthOps { - t.Fatalf("expected auth ops false") + t.Fatalf("expected AllowAutoTopicCreation true") } } -func TestParseMetadataRequestFranzEncoding(t *testing.T) { - req := kmsg.NewPtrMetadataRequest() - req.Version = 12 - req.AllowAutoTopicCreation = true - req.IncludeTopicAuthorizedOperations = false - req.Topics = []kmsg.MetadataRequestTopic{ - {Topic: strPtr("orders-3eb53935-0")}, - } - - formatter := kmsg.NewRequestFormatter(kmsg.FormatterClientID("kgo")) - payload := formatter.AppendRequest(nil, req, 1) - payload = payload[4:] // drop the length prefix to match ParseRequest input +func TestParseRequest_FindCoordinator(t *testing.T) { + req := kmsg.NewPtrFindCoordinatorRequest() + req.Version = 3 + req.CoordinatorKey = "franz-e2e-consumer" - header, parsed, err := ParseRequest(payload) + frame := buildRequestFrame(APIKeyFindCoordinator, 3, 1, kmsg.StringPtr("kgo"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } - if header.APIKey != APIKeyMetadata || header.APIVersion != 12 { - t.Fatalf("unexpected header: %#v", header) + if header.APIKey != APIKeyFindCoordinator { + t.Fatalf("unexpected api key %d", header.APIKey) } - metaReq, ok := parsed.(*MetadataRequest) + findReq, ok := parsed.(*kmsg.FindCoordinatorRequest) if !ok { - t.Fatalf("expected MetadataRequest got %T", parsed) + t.Fatalf("expected *kmsg.FindCoordinatorRequest got %T", parsed) } - if len(metaReq.Topics) != 1 || metaReq.Topics[0] != "orders-3eb53935-0" { - t.Fatalf("unexpected topics: %#v", metaReq.Topics) - } - if !metaReq.AllowAutoTopicCreation { - t.Fatalf("expected allow auto topic creation true") - } - if metaReq.IncludeClusterAuthOps || metaReq.IncludeTopicAuthOps { - t.Fatalf("expected auth ops false") + if findReq.CoordinatorKey != "franz-e2e-consumer" { + t.Fatalf("unexpected coordinator key %q", findReq.CoordinatorKey) } } -func TestParseFindCoordinatorFlexible(t *testing.T) { - req := kmsg.NewPtrFindCoordinatorRequest() - req.Version = 3 - req.CoordinatorKey = "franz-e2e-consumer" - body := req.AppendTo(nil) - - w := newByteWriter(len(body) + 16) - w.Int16(APIKeyFindCoordinator) - w.Int16(3) - w.Int32(1) - clientID := "kgo" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.write(body) +func TestParseRequest_Fetch(t *testing.T) { + var topicID [16]byte + for i := range topicID { + topicID[i] = byte(i + 1) + } + req := kmsg.NewPtrFetchRequest() + req.Version = 13 + req.MaxWaitMillis = 500 + req.MinBytes = 1 + req.MaxBytes = 1048576 + topic := kmsg.NewFetchRequestTopic() + topic.TopicID = topicID + part := kmsg.NewFetchRequestTopicPartition() + part.Partition = 0 + part.FetchOffset = 42 + part.PartitionMaxBytes = 1048576 + topic.Partitions = append(topic.Partitions, part) + req.Topics = append(req.Topics, topic) - header, parsed, err := ParseRequest(w.Bytes()) + frame := buildRequestFrame(APIKeyFetch, 13, 9, kmsg.StringPtr("client"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } - if header.APIKey != APIKeyFindCoordinator { - t.Fatalf("unexpected api key %d", header.APIKey) + if header.APIKey != APIKeyFetch || header.APIVersion != 13 { + t.Fatalf("unexpected header: %+v", header) } - findReq, ok := parsed.(*FindCoordinatorRequest) + fetchReq, ok := parsed.(*kmsg.FetchRequest) if !ok { - t.Fatalf("expected FindCoordinatorRequest got %T", parsed) + t.Fatalf("expected *kmsg.FetchRequest got %T", parsed) } - if findReq.Key != "franz-e2e-consumer" { - t.Fatalf("unexpected coordinator key %q", findReq.Key) + if len(fetchReq.Topics) != 1 || fetchReq.Topics[0].TopicID != topicID { + t.Fatalf("unexpected topics: %+v", fetchReq.Topics) } - if findReq.KeyType != 0 { - t.Fatalf("unexpected key type %d", findReq.KeyType) + if fetchReq.Topics[0].Partitions[0].FetchOffset != 42 { + t.Fatalf("unexpected fetch offset %d", fetchReq.Topics[0].Partitions[0].FetchOffset) } } -func TestParseOffsetCommitRequestV3(t *testing.T) { +func TestParseRequest_OffsetCommit(t *testing.T) { req := kmsg.NewPtrOffsetCommitRequest() req.Version = 3 req.Group = "group-1" @@ -732,305 +182,95 @@ func TestParseOffsetCommitRequestV3(t *testing.T) { part.Metadata = &meta topic.Partitions = append(topic.Partitions, part) req.Topics = append(req.Topics, topic) - body := req.AppendTo(nil) - w := newByteWriter(len(body) + 16) - w.Int16(APIKeyOffsetCommit) - w.Int16(3) - w.Int32(7) - clientID := "kgo" - w.NullableString(&clientID) - w.write(body) - - header, parsed, err := ParseRequest(w.Bytes()) + // OffsetCommit v3 is pre-flexible (no tagged fields in header). + frame := buildRequestFrame(APIKeyOffsetCommit, 3, 7, kmsg.StringPtr("kgo"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } if header.APIKey != APIKeyOffsetCommit { t.Fatalf("unexpected api key %d", header.APIKey) } - commitReq, ok := parsed.(*OffsetCommitRequest) + commitReq, ok := parsed.(*kmsg.OffsetCommitRequest) if !ok { - t.Fatalf("expected OffsetCommitRequest got %T", parsed) - } - if commitReq.GroupID != "group-1" || commitReq.GenerationID != 4 { - t.Fatalf("unexpected group data: %#v", commitReq) + t.Fatalf("expected *kmsg.OffsetCommitRequest got %T", parsed) } - if len(commitReq.Topics) != 1 || len(commitReq.Topics[0].Partitions) != 1 { - t.Fatalf("unexpected partitions: %#v", commitReq.Topics) + if commitReq.Group != "group-1" || commitReq.Generation != 4 { + t.Fatalf("unexpected group data: group=%q gen=%d", commitReq.Group, commitReq.Generation) } - if got := commitReq.Topics[0].Partitions[0]; got.Offset != 100 || got.Metadata != "checkpoint" { - t.Fatalf("unexpected partition data: %#v", got) + if len(commitReq.Topics) != 1 || commitReq.Topics[0].Partitions[0].Offset != 100 { + t.Fatalf("unexpected partition data") } } -func TestParseSyncGroupFlexible(t *testing.T) { +func TestParseRequest_SyncGroup(t *testing.T) { req := kmsg.NewPtrSyncGroupRequest() req.Version = 4 req.Group = "franz-e2e-consumer" req.Generation = 1 req.MemberID = "member-1" req.GroupAssignment = []kmsg.SyncGroupRequestGroupAssignment{ - { - MemberID: "member-1", - MemberAssignment: []byte{0x00, 0x01}, - }, + {MemberID: "member-1", MemberAssignment: []byte{0x00, 0x01}}, } - body := req.AppendTo(nil) - w := newByteWriter(len(body) + 16) - w.Int16(APIKeySyncGroup) - w.Int16(4) - w.Int32(9) - clientID := "kgo" - w.NullableString(&clientID) - w.WriteTaggedFields(0) - w.write(body) - - header, parsed, err := ParseRequest(w.Bytes()) + frame := buildRequestFrame(APIKeySyncGroup, 4, 9, kmsg.StringPtr("kgo"), req.AppendTo(nil)) + header, parsed, err := ParseRequest(frame) if err != nil { t.Fatalf("ParseRequest: %v", err) } if header.APIKey != APIKeySyncGroup { t.Fatalf("unexpected api key %d", header.APIKey) } - syncReq, ok := parsed.(*SyncGroupRequest) + syncReq, ok := parsed.(*kmsg.SyncGroupRequest) if !ok { - t.Fatalf("expected SyncGroupRequest got %T", parsed) - } - if syncReq.GroupID != "franz-e2e-consumer" { - t.Fatalf("unexpected group id %q", syncReq.GroupID) + t.Fatalf("expected *kmsg.SyncGroupRequest got %T", parsed) } - if len(syncReq.Assignments) != 1 || syncReq.Assignments[0].MemberID != "member-1" { - t.Fatalf("unexpected assignments %#v", syncReq.Assignments) + if syncReq.Group != "franz-e2e-consumer" { + t.Fatalf("unexpected group id %q", syncReq.Group) } - if len(syncReq.Assignments[0].Assignment) != 2 { - t.Fatalf("unexpected assignment payload") + if len(syncReq.GroupAssignment) != 1 || syncReq.GroupAssignment[0].MemberID != "member-1" { + t.Fatalf("unexpected assignments %+v", syncReq.GroupAssignment) } } -func TestParseFetchRequest(t *testing.T) { - w := newByteWriter(128) - w.Int16(APIKeyFetch) - w.Int16(11) - w.Int32(9) // correlation - clientID := "consumer" - w.NullableString(&clientID) - w.Int32(1) // replica id - w.Int32(0) // max wait - w.Int32(0) // min bytes - w.Int32(1024) - w.Int8(0) - w.Int32(0) // session id - w.Int32(0) // session epoch - w.Int32(1) // topic count - w.String("orders") - w.Int32(1) // partition count - w.Int32(0) // partition - w.Int32(0) // leader epoch - w.Int64(0) // fetch offset - w.Int64(0) // log start offset - w.Int32(1024) - w.Int32(0) // forgotten topics count - w.NullableString(nil) +func TestParseRequestHeader_ReturnsRemainingBody(t *testing.T) { + req := kmsg.NewPtrApiVersionsRequest() + req.Version = 3 + req.ClientSoftwareName = "kgo" + req.ClientSoftwareVersion = "1.0.0" + body := req.AppendTo(nil) - header, req, err := ParseRequest(w.Bytes()) + frame := buildRequestFrame(APIKeyApiVersion, 3, 7, kmsg.StringPtr("kgo"), body) + header, remaining, err := ParseRequestHeader(frame) if err != nil { - t.Fatalf("ParseRequest: %v", err) + t.Fatalf("ParseRequestHeader: %v", err) } - if header.APIKey != APIKeyFetch { - t.Fatalf("expected fetch api key got %d", header.APIKey) + if header.APIKey != APIKeyApiVersion || header.APIVersion != 3 || header.CorrelationID != 7 { + t.Fatalf("unexpected header: %+v", header) } - fetchReq, ok := req.(*FetchRequest) - if !ok { - t.Fatalf("expected FetchRequest got %T", req) - } - if len(fetchReq.Topics) != 1 || len(fetchReq.Topics[0].Partitions) != 1 { - t.Fatalf("unexpected fetch data: %#v", fetchReq.Topics) + if len(remaining) != len(body) { + t.Fatalf("remaining body length: got %d, want %d", len(remaining), len(body)) } } -func TestEncodeFetchRequest_RoundTrip(t *testing.T) { - tests := []struct { - name string - version int16 - req *FetchRequest - topicID [16]byte - }{ - { - name: "v11 name-based", - version: 11, - req: &FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - IsolationLevel: 0, - SessionID: 0, - SessionEpoch: -1, - Topics: []FetchTopicRequest{ - { - Name: "orders", - Partitions: []FetchPartitionRequest{ - {Partition: 0, FetchOffset: 10, MaxBytes: 1048576}, - {Partition: 1, FetchOffset: 20, MaxBytes: 1048576}, - }, - }, - { - Name: "events", - Partitions: []FetchPartitionRequest{ - {Partition: 0, FetchOffset: 0, MaxBytes: 524288}, - }, - }, - }, - }, - }, - { - name: "v13 topic-id-based", - version: 13, - topicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - req: &FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - IsolationLevel: 1, - SessionID: 42, - SessionEpoch: 3, - Topics: []FetchTopicRequest{ - { - TopicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - Partitions: []FetchPartitionRequest{ - {Partition: 0, FetchOffset: 100, MaxBytes: 1048576}, - }, - }, - }, - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - header := &RequestHeader{ - APIKey: APIKeyFetch, - APIVersion: tc.version, - CorrelationID: 42, - ClientID: strPtr("test-client"), - } - encoded, err := EncodeFetchRequest(header, tc.req, tc.version) - if err != nil { - t.Fatalf("EncodeFetchRequest: %v", err) - } - - parsedHeader, parsedReq, err := ParseRequest(encoded) - if err != nil { - t.Fatalf("ParseRequest: %v", err) - } - if parsedHeader.APIKey != APIKeyFetch { - t.Fatalf("expected APIKeyFetch, got %d", parsedHeader.APIKey) - } - if parsedHeader.CorrelationID != 42 { - t.Fatalf("expected correlation 42, got %d", parsedHeader.CorrelationID) - } +func TestParseRequest_UnsupportedAPIKey(t *testing.T) { + w := newByteWriter(16) + w.Int16(9999) + w.Int16(0) + w.Int32(1) + w.NullableString(nil) - fetchReq, ok := parsedReq.(*FetchRequest) - if !ok { - t.Fatalf("expected *FetchRequest, got %T", parsedReq) - } - if fetchReq.MaxWaitMs != tc.req.MaxWaitMs { - t.Fatalf("MaxWaitMs: got %d, want %d", fetchReq.MaxWaitMs, tc.req.MaxWaitMs) - } - if fetchReq.SessionID != tc.req.SessionID { - t.Fatalf("SessionID: got %d, want %d", fetchReq.SessionID, tc.req.SessionID) - } - if len(fetchReq.Topics) != len(tc.req.Topics) { - t.Fatalf("topic count: got %d, want %d", len(fetchReq.Topics), len(tc.req.Topics)) - } - for ti, topic := range fetchReq.Topics { - wantTopic := tc.req.Topics[ti] - if tc.version >= 12 { - if topic.TopicID != wantTopic.TopicID { - t.Fatalf("topic[%d] ID mismatch", ti) - } - } else { - if topic.Name != wantTopic.Name { - t.Fatalf("topic[%d] name: got %q, want %q", ti, topic.Name, wantTopic.Name) - } - } - if len(topic.Partitions) != len(wantTopic.Partitions) { - t.Fatalf("topic[%d] partition count: got %d, want %d", ti, len(topic.Partitions), len(wantTopic.Partitions)) - } - for pi, part := range topic.Partitions { - wantPart := wantTopic.Partitions[pi] - if part.Partition != wantPart.Partition { - t.Fatalf("topic[%d] part[%d] id: got %d, want %d", ti, pi, part.Partition, wantPart.Partition) - } - if part.FetchOffset != wantPart.FetchOffset { - t.Fatalf("topic[%d] part[%d] offset: got %d, want %d", ti, pi, part.FetchOffset, wantPart.FetchOffset) - } - if part.MaxBytes != wantPart.MaxBytes { - t.Fatalf("topic[%d] part[%d] maxBytes: got %d, want %d", ti, pi, part.MaxBytes, wantPart.MaxBytes) - } - } - } - }) + _, _, err := ParseRequest(w.Bytes()) + if err == nil { + t.Fatalf("expected error for unsupported api key") } } -func TestEncodeFetchRequest_KmsgValidation(t *testing.T) { - // Encode a v13 request and validate it parses with franz-go's kmsg. - header := &RequestHeader{ - APIKey: APIKeyFetch, - APIVersion: 13, - CorrelationID: 99, - ClientID: strPtr("kmsg-test"), - } - topicID := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - req := &FetchRequest{ - ReplicaID: -1, - MaxWaitMs: 500, - MinBytes: 1, - MaxBytes: 1048576, - IsolationLevel: 0, - SessionID: 0, - SessionEpoch: -1, - Topics: []FetchTopicRequest{ - { - TopicID: topicID, - Partitions: []FetchPartitionRequest{ - {Partition: 0, FetchOffset: 42, MaxBytes: 1048576}, - }, - }, - }, - } - encoded, err := EncodeFetchRequest(header, req, 13) - if err != nil { - t.Fatalf("EncodeFetchRequest: %v", err) - } - - // Use ParseRequestHeader to find where the body starts (same as the real code). - _, reader, err := ParseRequestHeader(encoded) - if err != nil { - t.Fatalf("ParseRequestHeader: %v", err) - } - bodyStart := len(encoded) - reader.remaining() - - kmsgReq := kmsg.NewPtrFetchRequest() - kmsgReq.Version = 13 - if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil { - t.Fatalf("kmsg.ReadFrom: %v", err) - } - if len(kmsgReq.Topics) != 1 { - t.Fatalf("expected 1 topic, got %d", len(kmsgReq.Topics)) - } - if kmsgReq.Topics[0].TopicID != topicID { - t.Fatalf("topic ID mismatch") - } - if len(kmsgReq.Topics[0].Partitions) != 1 { - t.Fatalf("expected 1 partition, got %d", len(kmsgReq.Topics[0].Partitions)) - } - if kmsgReq.Topics[0].Partitions[0].FetchOffset != 42 { - t.Fatalf("fetch offset: got %d, want 42", kmsgReq.Topics[0].Partitions[0].FetchOffset) +func TestParseRequest_TruncatedHeader(t *testing.T) { + _, _, err := ParseRequest([]byte{0x00, 0x03}) + if err == nil { + t.Fatalf("expected error for truncated header") } } @@ -1068,9 +308,9 @@ func TestProduceMultiPartitionFranzCompat(t *testing.T) { if err != nil { t.Fatalf("ParseRequest: %v", err) } - got, ok := parsed.(*ProduceRequest) + got, ok := parsed.(*kmsg.ProduceRequest) if !ok { - t.Fatalf("expected *ProduceRequest, got %T", parsed) + t.Fatalf("expected *kmsg.ProduceRequest, got %T", parsed) } if len(got.Topics) != 1 { t.Fatalf("topic count: got %d want 1", len(got.Topics)) @@ -1089,57 +329,4 @@ func TestProduceMultiPartitionFranzCompat(t *testing.T) { } }) - t.Run("kafscale-encode-franz-parse", func(t *testing.T) { - header := &RequestHeader{ - APIKey: APIKeyProduce, - APIVersion: 9, - CorrelationID: 66, - ClientID: strPtr("test"), - } - req := &ProduceRequest{ - Acks: -1, - TimeoutMs: 3000, - Topics: []ProduceTopic{ - { - Name: "orders", - Partitions: []ProducePartition{ - {Partition: 0, Records: []byte{1, 2}}, - {Partition: 1, Records: []byte{3, 4}}, - {Partition: 2, Records: []byte{5, 6}}, - }, - }, - }, - } - encoded, err := EncodeProduceRequest(header, req, 9) - if err != nil { - t.Fatalf("encode: %v", err) - } - - _, reader, err := ParseRequestHeader(encoded) - if err != nil { - t.Fatalf("ParseRequestHeader: %v", err) - } - bodyStart := len(encoded) - reader.remaining() - - kmsgReq := kmsg.NewPtrProduceRequest() - kmsgReq.Version = 9 - if err := kmsgReq.ReadFrom(encoded[bodyStart:]); err != nil { - t.Fatalf("kmsg.ReadFrom: %v", err) - } - if len(kmsgReq.Topics) != 1 { - t.Fatalf("topic count: got %d want 1", len(kmsgReq.Topics)) - } - if len(kmsgReq.Topics[0].Partitions) != 3 { - t.Fatalf("partition count: got %d want 3", len(kmsgReq.Topics[0].Partitions)) - } - for pi, part := range kmsgReq.Topics[0].Partitions { - if part.Partition != int32(pi) { - t.Fatalf("part[%d] index: got %d want %d", pi, part.Partition, pi) - } - want := []byte{byte(pi*2 + 1), byte(pi*2 + 2)} - if string(part.Records) != string(want) { - t.Fatalf("part[%d] records: got %x want %x", pi, part.Records, want) - } - } - }) } diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index 8f3c27e8..94b66fa8 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -15,1770 +15,180 @@ package protocol -import "fmt" - -// ApiVersionsResponse describes server capabilities. -type ApiVersionsResponse struct { - CorrelationID int32 - ErrorCode int16 - ThrottleMs int32 - Versions []ApiVersion -} - -// MetadataBroker describes a broker in Metadata response. -type MetadataBroker struct { - NodeID int32 - Host string - Port int32 - Rack *string -} - -// MetadataTopic describes a topic in Metadata response. -type MetadataTopic struct { - ErrorCode int16 - Name string - TopicID [16]byte - IsInternal bool - Partitions []MetadataPartition - TopicAuthorizedOperations int32 -} - -// MetadataPartition describes partition metadata. -type MetadataPartition struct { - ErrorCode int16 - PartitionIndex int32 - LeaderID int32 - LeaderEpoch int32 - ReplicaNodes []int32 - ISRNodes []int32 - OfflineReplicas []int32 -} - -// MetadataResponse holds topic + broker info. -type MetadataResponse struct { - CorrelationID int32 - ThrottleMs int32 - Brokers []MetadataBroker - ClusterID *string - ControllerID int32 - Topics []MetadataTopic - ClusterAuthorizedOperations int32 -} - -// ProduceResponse contains per-partition acknowledgement info. -type ProduceResponse struct { - CorrelationID int32 - Topics []ProduceTopicResponse - ThrottleMs int32 -} - -type ProduceTopicResponse struct { - Name string - Partitions []ProducePartitionResponse -} - -type ProducePartitionResponse struct { - Partition int32 - ErrorCode int16 - BaseOffset int64 - LogAppendTimeMs int64 - LogStartOffset int64 -} - -// FetchResponse represents data returned to consumers. -type FetchResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 - SessionID int32 - Topics []FetchTopicResponse -} - -type FetchTopicResponse struct { - Name string - TopicID [16]byte - Partitions []FetchPartitionResponse -} - -type FetchAbortedTransaction struct { - ProducerID int64 - FirstOffset int64 -} - -type FetchPartitionResponse struct { - Partition int32 - ErrorCode int16 - HighWatermark int64 - LastStableOffset int64 - LogStartOffset int64 - PreferredReadReplica int32 - RecordSet []byte - AbortedTransactions []FetchAbortedTransaction -} - -type CreateTopicResult struct { - Name string - ErrorCode int16 - ErrorMessage string -} - -type CreateTopicsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []CreateTopicResult -} - -type DeleteTopicResult struct { - Name string - ErrorCode int16 - ErrorMessage string -} - -type DeleteTopicsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []DeleteTopicResult -} - -type ListOffsetsPartitionResponse struct { - Partition int32 - ErrorCode int16 - Timestamp int64 - Offset int64 - LeaderEpoch int32 - OldStyleOffsets []int64 -} - -type ListOffsetsTopicResponse struct { - Name string - Partitions []ListOffsetsPartitionResponse -} - -type ListOffsetsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []ListOffsetsTopicResponse -} - -type FindCoordinatorResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 - ErrorMessage *string - NodeID int32 - Host string - Port int32 -} - -type JoinGroupMember struct { - MemberID string - InstanceID *string - Metadata []byte -} - -type JoinGroupResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 - GenerationID int32 - ProtocolName string - LeaderID string - MemberID string - Members []JoinGroupMember -} - -type SyncGroupResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 - ProtocolType *string - ProtocolName *string - Assignment []byte -} - -type HeartbeatResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 -} - -type LeaveGroupResponse struct { - CorrelationID int32 - ErrorCode int16 -} - -type OffsetCommitPartitionResponse struct { - Partition int32 - ErrorCode int16 -} - -type OffsetCommitTopicResponse struct { - Name string - Partitions []OffsetCommitPartitionResponse -} - -type OffsetCommitResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []OffsetCommitTopicResponse -} - -type OffsetFetchPartitionResponse struct { - Partition int32 - Offset int64 - LeaderEpoch int32 - Metadata *string - ErrorCode int16 -} - -type OffsetFetchTopicResponse struct { - Name string - Partitions []OffsetFetchPartitionResponse -} - -type OffsetFetchResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []OffsetFetchTopicResponse - ErrorCode int16 -} - -type OffsetForLeaderEpochPartitionResponse struct { - Partition int32 - ErrorCode int16 - LeaderEpoch int32 - EndOffset int64 -} - -type OffsetForLeaderEpochTopicResponse struct { - Name string - Partitions []OffsetForLeaderEpochPartitionResponse -} - -type OffsetForLeaderEpochResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []OffsetForLeaderEpochTopicResponse -} - -type DescribeGroupsResponseGroupMember struct { - MemberID string - InstanceID *string - ClientID string - ClientHost string - ProtocolMetadata []byte - MemberAssignment []byte -} - -type DescribeGroupsResponseGroup struct { - ErrorCode int16 - GroupID string - State string - ProtocolType string - Protocol string - Members []DescribeGroupsResponseGroupMember - AuthorizedOperations int32 -} - -type DescribeGroupsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Groups []DescribeGroupsResponseGroup -} - -type ListGroupsResponseGroup struct { - GroupID string - ProtocolType string - GroupState string - GroupType string -} - -type ListGroupsResponse struct { - CorrelationID int32 - ThrottleMs int32 - ErrorCode int16 - Groups []ListGroupsResponseGroup -} - -type DescribeConfigsResponseConfigSynonym struct { - Name string - Value *string - Source int8 -} - -type DescribeConfigsResponseConfig struct { - Name string - Value *string - ReadOnly bool - IsDefault bool - Source int8 - IsSensitive bool - Synonyms []DescribeConfigsResponseConfigSynonym - ConfigType int8 - Documentation *string -} - -type DescribeConfigsResponseResource struct { - ErrorCode int16 - ErrorMessage *string - ResourceType int8 - ResourceName string - Configs []DescribeConfigsResponseConfig -} - -type DescribeConfigsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Resources []DescribeConfigsResponseResource -} - -type AlterConfigsResponseResource struct { - ErrorCode int16 - ErrorMessage *string - ResourceType int8 - ResourceName string -} - -type AlterConfigsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Resources []AlterConfigsResponseResource -} - -type CreatePartitionsResponseTopic struct { - Name string - ErrorCode int16 - ErrorMessage *string -} - -type CreatePartitionsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Topics []CreatePartitionsResponseTopic -} - -type DeleteGroupsResponseGroup struct { - Group string - ErrorCode int16 -} - -type DeleteGroupsResponse struct { - CorrelationID int32 - ThrottleMs int32 - Groups []DeleteGroupsResponseGroup -} - -// EncodeApiVersionsResponse renders bytes ready to send on the wire. -func EncodeApiVersionsResponse(resp *ApiVersionsResponse, version int16) ([]byte, error) { - if version < 0 || version > 4 { - return nil, fmt.Errorf("api versions response version %d not supported", version) - } - w := newByteWriter(64) - w.Int32(resp.CorrelationID) - w.Int16(resp.ErrorCode) - if version >= 3 { - w.CompactArrayLen(len(resp.Versions)) - } else { - w.Int32(int32(len(resp.Versions))) - } - for _, v := range resp.Versions { - w.Int16(v.APIKey) - w.Int16(v.MinVersion) - w.Int16(v.MaxVersion) - if version >= 3 { - w.WriteTaggedFields(0) - } - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - if version >= 3 { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeMetadataResponse renders bytes for metadata responses. version should match -// the Metadata request version that triggered this response. -func EncodeMetadataResponse(resp *MetadataResponse, version int16) ([]byte, error) { - if version < 0 || version > 12 { - return nil, fmt.Errorf("metadata response version %d not supported", version) - } - flexible := version >= 9 - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 3 { - w.Int32(resp.ThrottleMs) - } - if flexible { - w.CompactArrayLen(len(resp.Brokers)) - } else { - w.Int32(int32(len(resp.Brokers))) - } - for _, b := range resp.Brokers { - w.Int32(b.NodeID) - if flexible { - w.CompactString(b.Host) - } else { - w.String(b.Host) - } - w.Int32(b.Port) - if version >= 1 { - if flexible { - w.CompactNullableString(b.Rack) - } else { - w.NullableString(b.Rack) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if version >= 2 { - if flexible { - w.CompactNullableString(resp.ClusterID) - } else { - w.NullableString(resp.ClusterID) - } - } - if version >= 1 { - w.Int32(resp.ControllerID) - } - if flexible { - w.CompactArrayLen(len(resp.Topics)) - } else { - w.Int32(int32(len(resp.Topics))) - } - for _, t := range resp.Topics { - w.Int16(t.ErrorCode) - if version >= 10 { - var namePtr *string - if t.Name != "" { - namePtr = &t.Name - } - if flexible { - w.CompactNullableString(namePtr) - } else { - w.NullableString(namePtr) - } - w.UUID(t.TopicID) - if version >= 1 { - w.Bool(t.IsInternal) - } - } else { - if flexible { - w.CompactString(t.Name) - } else { - w.String(t.Name) - } - if version >= 1 { - w.Bool(t.IsInternal) - } - } - if flexible { - w.CompactArrayLen(len(t.Partitions)) - } else { - w.Int32(int32(len(t.Partitions))) - } - for _, p := range t.Partitions { - w.Int16(p.ErrorCode) - w.Int32(p.PartitionIndex) - w.Int32(p.LeaderID) - if version >= 7 { - w.Int32(p.LeaderEpoch) - } - if flexible { - w.CompactArrayLen(len(p.ReplicaNodes)) - } else { - w.Int32(int32(len(p.ReplicaNodes))) - } - for _, replica := range p.ReplicaNodes { - w.Int32(replica) - } - if flexible { - w.CompactArrayLen(len(p.ISRNodes)) - } else { - w.Int32(int32(len(p.ISRNodes))) - } - for _, isr := range p.ISRNodes { - w.Int32(isr) - } - if version >= 5 { - if flexible { - w.CompactArrayLen(len(p.OfflineReplicas)) - } else { - w.Int32(int32(len(p.OfflineReplicas))) - } - for _, offline := range p.OfflineReplicas { - w.Int32(offline) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if version >= 8 { - w.Int32(t.TopicAuthorizedOperations) - } - if flexible { - w.WriteTaggedFields(0) - } - } - if version >= 8 { - w.Int32(resp.ClusterAuthorizedOperations) - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeProduceResponse renders bytes for produce responses. -func EncodeProduceResponse(resp *ProduceResponse, version int16) ([]byte, error) { - w := newByteWriter(128) - flexible := version >= 9 - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if flexible { - w.CompactArrayLen(len(resp.Topics)) - } else { - w.Int32(int32(len(resp.Topics))) - } - for _, topic := range resp.Topics { - if flexible { - w.CompactString(topic.Name) - } else { - w.String(topic.Name) - } - if flexible { - w.CompactArrayLen(len(topic.Partitions)) - } else { - w.Int32(int32(len(topic.Partitions))) - } - for _, p := range topic.Partitions { - w.Int32(p.Partition) - w.Int16(p.ErrorCode) - w.Int64(p.BaseOffset) - if version >= 3 { - w.Int64(p.LogAppendTimeMs) - } - if version >= 5 { - w.Int64(p.LogStartOffset) - } - if version >= 8 { - if flexible { - w.CompactArrayLen(0) // error_records - w.CompactNullableString(nil) - } else { - w.Int32(0) // error_records - w.NullableString(nil) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// ParseProduceResponse decodes a produce response from wire-format bytes. -// This is the inverse of EncodeProduceResponse and supports the same version range. -func ParseProduceResponse(payload []byte, version int16) (*ProduceResponse, error) { - r := newByteReader(payload) - flexible := version >= 9 - - corrID, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read correlation id: %w", err) - } - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip response header tags: %w", err) - } - } - - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(r) - } else { - topicCount, err = r.Int32() - } - if err != nil { - return nil, fmt.Errorf("read topic count: %w", err) - } - - topics := make([]ProduceTopicResponse, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var name string - if flexible { - name, err = r.CompactString() - } else { - name, err = r.String() - } - if err != nil { - return nil, fmt.Errorf("read topic name: %w", err) - } - - var partCount int32 - if flexible { - partCount, err = compactArrayLenNonNull(r) - } else { - partCount, err = r.Int32() - } - if err != nil { - return nil, fmt.Errorf("read partition count: %w", err) - } - - partitions := make([]ProducePartitionResponse, 0, partCount) - for j := int32(0); j < partCount; j++ { - partIdx, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read partition index: %w", err) - } - errorCode, err := r.Int16() - if err != nil { - return nil, fmt.Errorf("read error code: %w", err) - } - baseOffset, err := r.Int64() - if err != nil { - return nil, fmt.Errorf("read base offset: %w", err) - } - var logAppendTimeMs, logStartOffset int64 - if version >= 3 { - logAppendTimeMs, err = r.Int64() - if err != nil { - return nil, fmt.Errorf("read log append time: %w", err) - } - } - if version >= 5 { - logStartOffset, err = r.Int64() - if err != nil { - return nil, fmt.Errorf("read log start offset: %w", err) - } - } - if version >= 8 { - if flexible { - arrLen, err := r.CompactArrayLen() - if err != nil { - return nil, fmt.Errorf("read error records len: %w", err) - } - for k := int32(0); k < arrLen; k++ { - if _, err := r.CompactBytes(); err != nil { - return nil, fmt.Errorf("skip error record: %w", err) - } - } - if _, err := r.CompactNullableString(); err != nil { - return nil, fmt.Errorf("read error message: %w", err) - } - } else { - arrLen, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read error records len: %w", err) - } - for k := int32(0); k < arrLen; k++ { - if _, err := r.Bytes(); err != nil { - return nil, fmt.Errorf("skip error record: %w", err) - } - } - if _, err := r.NullableString(); err != nil { - return nil, fmt.Errorf("read error message: %w", err) - } - } - } - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip partition tags: %w", err) - } - } - partitions = append(partitions, ProducePartitionResponse{ - Partition: partIdx, - ErrorCode: errorCode, - BaseOffset: baseOffset, - LogAppendTimeMs: logAppendTimeMs, - LogStartOffset: logStartOffset, - }) - } - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip topic tags: %w", err) - } - } - topics = append(topics, ProduceTopicResponse{Name: name, Partitions: partitions}) - } - - var throttleMs int32 - if version >= 1 { - throttleMs, err = r.Int32() - if err != nil { - return nil, fmt.Errorf("read throttle ms: %w", err) - } - } - if flexible { - _ = r.SkipTaggedFields() - } - - return &ProduceResponse{ - CorrelationID: corrID, - Topics: topics, - ThrottleMs: throttleMs, - }, nil -} - -// EncodeFetchResponse renders bytes for fetch responses. -func EncodeFetchResponse(resp *FetchResponse, version int16) ([]byte, error) { - if version < 1 || version > 13 { - return nil, fmt.Errorf("fetch response version %d not supported", version) - } - flexible := version >= 12 - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - w.Int32(resp.ThrottleMs) - if version >= 7 { - w.Int16(resp.ErrorCode) - w.Int32(resp.SessionID) - } else { - if resp.ErrorCode != 0 || resp.SessionID != 0 { - return nil, fmt.Errorf("fetch version %d cannot include session fields", version) - } - } - if flexible { - w.CompactArrayLen(len(resp.Topics)) - } else { - w.Int32(int32(len(resp.Topics))) - } - for _, topic := range resp.Topics { - if flexible { - w.UUID(topic.TopicID) - } else { - w.String(topic.Name) - } - if flexible { - w.CompactArrayLen(len(topic.Partitions)) - } else { - w.Int32(int32(len(topic.Partitions))) - } - for _, part := range topic.Partitions { - w.Int32(part.Partition) - w.Int16(part.ErrorCode) - w.Int64(part.HighWatermark) - if version >= 4 { - w.Int64(part.LastStableOffset) - } - if version >= 5 { - w.Int64(part.LogStartOffset) - } - if version >= 4 { - if flexible { - w.CompactArrayLen(len(part.AbortedTransactions)) - } else { - w.Int32(int32(len(part.AbortedTransactions))) - } - for _, aborted := range part.AbortedTransactions { - w.Int64(aborted.ProducerID) - w.Int64(aborted.FirstOffset) - } - } - if version >= 11 { - w.Int32(part.PreferredReadReplica) - } - if flexible { - if part.RecordSet == nil { - w.CompactBytes([]byte{}) - } else { - w.CompactBytes(part.RecordSet) - } - w.WriteTaggedFields(0) - } else { - if part.RecordSet == nil { - w.Int32(0) - } else { - w.Int32(int32(len(part.RecordSet))) - w.write(part.RecordSet) - } - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// ParseFetchResponse decodes a fetch response. Inverse of EncodeFetchResponse. -func ParseFetchResponse(payload []byte, version int16) (*FetchResponse, error) { - if version < 1 || version > 13 { - return nil, fmt.Errorf("fetch response version %d not supported", version) - } - r := newByteReader(payload) - flexible := version >= 12 - - corrID, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read correlation id: %w", err) - } - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip response header tags: %w", err) - } - } - - throttleMs, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read throttle ms: %w", err) - } - - var errorCode int16 - var sessionID int32 - if version >= 7 { - errorCode, err = r.Int16() - if err != nil { - return nil, fmt.Errorf("read error code: %w", err) - } - sessionID, err = r.Int32() - if err != nil { - return nil, fmt.Errorf("read session id: %w", err) - } - } - - var topicCount int32 - if flexible { - topicCount, err = compactArrayLenNonNull(r) - } else { - topicCount, err = r.Int32() - } - if err != nil { - return nil, fmt.Errorf("read topic count: %w", err) - } - - topics := make([]FetchTopicResponse, 0, topicCount) - for i := int32(0); i < topicCount; i++ { - var ( - name string - topicID [16]byte - ) - if flexible { - topicID, err = r.UUID() - if err != nil { - return nil, fmt.Errorf("read topic id: %w", err) - } - } else { - name, err = r.String() - if err != nil { - return nil, fmt.Errorf("read topic name: %w", err) - } - } - - var partCount int32 - if flexible { - partCount, err = compactArrayLenNonNull(r) - } else { - partCount, err = r.Int32() - } - if err != nil { - return nil, fmt.Errorf("read partition count: %w", err) - } - - partitions := make([]FetchPartitionResponse, 0, partCount) - for j := int32(0); j < partCount; j++ { - partIdx, err := r.Int32() - if err != nil { - return nil, fmt.Errorf("read partition index: %w", err) - } - ec, err := r.Int16() - if err != nil { - return nil, fmt.Errorf("read partition error code: %w", err) - } - highWatermark, err := r.Int64() - if err != nil { - return nil, fmt.Errorf("read high watermark: %w", err) - } - var lastStableOffset, logStartOffset int64 - if version >= 4 { - lastStableOffset, err = r.Int64() - if err != nil { - return nil, fmt.Errorf("read last stable offset: %w", err) - } - } - if version >= 5 { - logStartOffset, err = r.Int64() - if err != nil { - return nil, fmt.Errorf("read log start offset: %w", err) - } - } - - var abortedTransactions []FetchAbortedTransaction - if version >= 4 { - var abortedCount int32 - // Nullable: brokers may return null (no aborted transactions). - if flexible { - abortedCount, err = r.CompactArrayLen() - } else { - abortedCount, err = r.Int32() - } - if err != nil { - return nil, fmt.Errorf("read aborted count: %w", err) - } - if abortedCount > 0 { - abortedTransactions = make([]FetchAbortedTransaction, 0, abortedCount) - for k := int32(0); k < abortedCount; k++ { - producerID, err := r.Int64() - if err != nil { - return nil, fmt.Errorf("read aborted producer id: %w", err) - } - firstOffset, err := r.Int64() - if err != nil { - return nil, fmt.Errorf("read aborted first offset: %w", err) - } - abortedTransactions = append(abortedTransactions, FetchAbortedTransaction{ - ProducerID: producerID, - FirstOffset: firstOffset, - }) - } - } - } - - var preferredReadReplica int32 - if version >= 11 { - preferredReadReplica, err = r.Int32() - if err != nil { - return nil, fmt.Errorf("read preferred read replica: %w", err) - } - } - - var recordSet []byte - if flexible { - recordSet, err = r.CompactBytes() - } else { - recordSet, err = r.Bytes() - } - if err != nil { - return nil, fmt.Errorf("read record set: %w", err) - } - - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip partition tags: %w", err) - } - } - - partitions = append(partitions, FetchPartitionResponse{ - Partition: partIdx, - ErrorCode: ec, - HighWatermark: highWatermark, - LastStableOffset: lastStableOffset, - LogStartOffset: logStartOffset, - PreferredReadReplica: preferredReadReplica, - RecordSet: recordSet, - AbortedTransactions: abortedTransactions, - }) - } - - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return nil, fmt.Errorf("skip topic tags: %w", err) - } - } - - topics = append(topics, FetchTopicResponse{ - Name: name, - TopicID: topicID, - Partitions: partitions, - }) - } - - if flexible { - _ = r.SkipTaggedFields() - } - - return &FetchResponse{ - CorrelationID: corrID, - ThrottleMs: throttleMs, - ErrorCode: errorCode, - SessionID: sessionID, - Topics: topics, - }, nil -} - -func EncodeCreateTopicsResponse(resp *CreateTopicsResponse, version int16) ([]byte, error) { - if version < 0 || version > 2 { - return nil, fmt.Errorf("create topics response version %d not supported", version) - } - w := newByteWriter(128) - w.Int32(resp.CorrelationID) - if version >= 2 { - w.Int32(resp.ThrottleMs) - } - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int16(topic.ErrorCode) - if version >= 1 { - w.String(topic.ErrorMessage) - } - } - return w.Bytes(), nil -} - -func EncodeDeleteTopicsResponse(resp *DeleteTopicsResponse, version int16) ([]byte, error) { - if version < 0 || version > 2 { - return nil, fmt.Errorf("delete topics response version %d not supported", version) - } - w := newByteWriter(128) - w.Int32(resp.CorrelationID) - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int16(topic.ErrorCode) - } - return w.Bytes(), nil -} - -func EncodeListOffsetsResponse(version int16, resp *ListOffsetsResponse) ([]byte, error) { - if version < 0 || version > 4 { - return nil, fmt.Errorf("list offsets response version %d not supported", version) - } - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if version >= 2 { - w.Int32(resp.ThrottleMs) - } - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int32(int32(len(topic.Partitions))) - for _, part := range topic.Partitions { - w.Int32(part.Partition) - w.Int16(part.ErrorCode) - if version == 0 { - offsets := part.OldStyleOffsets - if offsets == nil { - offsets = []int64{} - } - w.Int32(int32(len(offsets))) - for _, off := range offsets { - w.Int64(off) - } - continue - } - w.Int64(part.Timestamp) - w.Int64(part.Offset) - if version >= 4 { - w.Int32(part.LeaderEpoch) - } - } - } - return w.Bytes(), nil -} - -func EncodeFindCoordinatorResponse(resp *FindCoordinatorResponse, version int16) ([]byte, error) { - if version >= 4 { - return nil, fmt.Errorf("find coordinator version %d not supported", version) - } - w := newByteWriter(64) - flexible := version >= 3 - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - w.Int16(resp.ErrorCode) - if version >= 1 { - if flexible { - w.CompactNullableString(resp.ErrorMessage) - } else { - w.NullableString(resp.ErrorMessage) - } - } - w.Int32(resp.NodeID) - if flexible { - w.CompactString(resp.Host) - } else { - w.String(resp.Host) - } - w.Int32(resp.Port) - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -func EncodeJoinGroupResponse(resp *JoinGroupResponse, version int16) ([]byte, error) { - if version >= 6 { - return nil, fmt.Errorf("join group response version %d not supported", version) - } - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if version >= 2 { - w.Int32(resp.ThrottleMs) - } - w.Int16(resp.ErrorCode) - w.Int32(resp.GenerationID) - w.String(resp.ProtocolName) - w.String(resp.LeaderID) - w.String(resp.MemberID) - w.Int32(int32(len(resp.Members))) - for _, member := range resp.Members { - w.String(member.MemberID) - if version >= 5 { - if member.InstanceID == nil { - w.Int16(-1) - } else { - w.String(*member.InstanceID) - } - } - w.BytesWithLength(member.Metadata) - } - return w.Bytes(), nil -} - -func EncodeSyncGroupResponse(resp *SyncGroupResponse, version int16) ([]byte, error) { - if version > 5 { - return nil, fmt.Errorf("sync group response version %d not supported", version) - } - flexible := version >= 4 - w := newByteWriter(192) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - w.Int16(resp.ErrorCode) - if version >= 5 { - if flexible { - w.CompactNullableString(resp.ProtocolType) - w.CompactNullableString(resp.ProtocolName) - } else { - w.NullableString(resp.ProtocolType) - w.NullableString(resp.ProtocolName) - } - } - if flexible { - w.CompactBytes(resp.Assignment) - w.WriteTaggedFields(0) - } else { - w.BytesWithLength(resp.Assignment) - } - return w.Bytes(), nil -} - -func EncodeHeartbeatResponse(resp *HeartbeatResponse, version int16) ([]byte, error) { - if version > 4 { - return nil, fmt.Errorf("heartbeat response version %d not supported", version) - } - flexible := version >= 4 - w := newByteWriter(64) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - w.Int16(resp.ErrorCode) - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -func EncodeLeaveGroupResponse(resp *LeaveGroupResponse) ([]byte, error) { - w := newByteWriter(32) - w.Int32(resp.CorrelationID) - w.Int16(resp.ErrorCode) - return w.Bytes(), nil -} - -func EncodeOffsetCommitResponse(resp *OffsetCommitResponse) ([]byte, error) { - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - w.Int32(resp.ThrottleMs) - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int32(int32(len(topic.Partitions))) - for _, part := range topic.Partitions { - w.Int32(part.Partition) - w.Int16(part.ErrorCode) - } - } - return w.Bytes(), nil -} - -func EncodeOffsetFetchResponse(resp *OffsetFetchResponse, version int16) ([]byte, error) { - if version < 3 || version > 5 { - return nil, fmt.Errorf("offset fetch response version %d not supported", version) - } - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if version >= 3 { - w.Int32(resp.ThrottleMs) - } - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int32(int32(len(topic.Partitions))) - for _, part := range topic.Partitions { - w.Int32(part.Partition) - w.Int64(part.Offset) - if version >= 5 { - w.Int32(part.LeaderEpoch) - } - w.NullableString(part.Metadata) - w.Int16(part.ErrorCode) - } - } - if version >= 2 { - w.Int16(resp.ErrorCode) - } - return w.Bytes(), nil -} - -// EncodeOffsetForLeaderEpochResponse renders bytes for offset for leader epoch responses. -func EncodeOffsetForLeaderEpochResponse(resp *OffsetForLeaderEpochResponse, version int16) ([]byte, error) { - if version != 3 { - return nil, fmt.Errorf("offset for leader epoch response version %d not supported", version) - } - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if version >= 2 { - w.Int32(resp.ThrottleMs) - } - w.Int32(int32(len(resp.Topics))) - for _, topic := range resp.Topics { - w.String(topic.Name) - w.Int32(int32(len(topic.Partitions))) - for _, part := range topic.Partitions { - w.Int32(part.Partition) - w.Int16(part.ErrorCode) - w.Int32(part.LeaderEpoch) - w.Int64(part.EndOffset) - } - } - return w.Bytes(), nil -} - -// EncodeDescribeGroupsResponse renders bytes for describe groups responses. -func EncodeDescribeGroupsResponse(resp *DescribeGroupsResponse, version int16) ([]byte, error) { - if version != 5 { - return nil, fmt.Errorf("describe groups response version %d not supported", version) - } - flexible := version >= 5 - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - if flexible { - w.CompactArrayLen(len(resp.Groups)) - } else { - w.Int32(int32(len(resp.Groups))) - } - for _, group := range resp.Groups { - w.Int16(group.ErrorCode) - if flexible { - w.CompactString(group.GroupID) - w.CompactString(group.State) - w.CompactString(group.ProtocolType) - w.CompactString(group.Protocol) - } else { - w.String(group.GroupID) - w.String(group.State) - w.String(group.ProtocolType) - w.String(group.Protocol) - } - if flexible { - w.CompactArrayLen(len(group.Members)) - } else { - w.Int32(int32(len(group.Members))) - } - for _, member := range group.Members { - if flexible { - w.CompactString(member.MemberID) - w.CompactNullableString(member.InstanceID) - w.CompactString(member.ClientID) - w.CompactString(member.ClientHost) - w.CompactBytes(member.ProtocolMetadata) - w.CompactBytes(member.MemberAssignment) - w.WriteTaggedFields(0) - } else { - w.String(member.MemberID) - w.NullableString(member.InstanceID) - w.String(member.ClientID) - w.String(member.ClientHost) - w.BytesWithLength(member.ProtocolMetadata) - w.BytesWithLength(member.MemberAssignment) - } - } - if version >= 3 { - w.Int32(group.AuthorizedOperations) - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeListGroupsResponse renders bytes for list groups responses. -func EncodeListGroupsResponse(resp *ListGroupsResponse, version int16) ([]byte, error) { - if version != 5 { - return nil, fmt.Errorf("list groups response version %d not supported", version) - } - flexible := version >= 3 - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - if version >= 1 { - w.Int32(resp.ThrottleMs) - } - w.Int16(resp.ErrorCode) - if flexible { - w.CompactArrayLen(len(resp.Groups)) - } else { - w.Int32(int32(len(resp.Groups))) - } - for _, group := range resp.Groups { - if flexible { - w.CompactString(group.GroupID) - w.CompactString(group.ProtocolType) - w.CompactString(group.GroupState) - w.CompactString(group.GroupType) - w.WriteTaggedFields(0) - } else { - w.String(group.GroupID) - w.String(group.ProtocolType) - if version >= 4 { - w.String(group.GroupState) - } - if version >= 5 { - w.String(group.GroupType) - } - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeDescribeConfigsResponse renders bytes for describe configs responses. -func EncodeDescribeConfigsResponse(resp *DescribeConfigsResponse, version int16) ([]byte, error) { - if version != 4 { - return nil, fmt.Errorf("describe configs response version %d not supported", version) - } - flexible := version >= 4 - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - w.Int32(resp.ThrottleMs) - if flexible { - w.CompactArrayLen(len(resp.Resources)) - } else { - w.Int32(int32(len(resp.Resources))) - } - for _, resource := range resp.Resources { - w.Int16(resource.ErrorCode) - if flexible { - w.CompactNullableString(resource.ErrorMessage) - } else { - w.NullableString(resource.ErrorMessage) - } - w.Int8(resource.ResourceType) - if flexible { - w.CompactString(resource.ResourceName) - } else { - w.String(resource.ResourceName) - } - if flexible { - w.CompactArrayLen(len(resource.Configs)) - } else { - w.Int32(int32(len(resource.Configs))) - } - for _, cfg := range resource.Configs { - if flexible { - w.CompactString(cfg.Name) - w.CompactNullableString(cfg.Value) - } else { - w.String(cfg.Name) - w.NullableString(cfg.Value) - } - w.Bool(cfg.ReadOnly) - w.Int8(cfg.Source) - w.Bool(cfg.IsSensitive) - if flexible { - w.CompactArrayLen(len(cfg.Synonyms)) - } else { - w.Int32(int32(len(cfg.Synonyms))) - } - for _, synonym := range cfg.Synonyms { - if flexible { - w.CompactString(synonym.Name) - w.CompactNullableString(synonym.Value) - } else { - w.String(synonym.Name) - w.NullableString(synonym.Value) - } - w.Int8(synonym.Source) - if flexible { - w.WriteTaggedFields(0) - } - } - w.Int8(cfg.ConfigType) - if flexible { - w.CompactNullableString(cfg.Documentation) - w.WriteTaggedFields(0) - } else { - w.NullableString(cfg.Documentation) - } - } - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeAlterConfigsResponse renders bytes for alter configs responses. -func EncodeAlterConfigsResponse(resp *AlterConfigsResponse, version int16) ([]byte, error) { - if version != 1 { - return nil, fmt.Errorf("alter configs response version %d not supported", version) - } - w := newByteWriter(256) - w.Int32(resp.CorrelationID) - w.Int32(resp.ThrottleMs) - w.Int32(int32(len(resp.Resources))) - for _, resource := range resp.Resources { - w.Int16(resource.ErrorCode) - w.NullableString(resource.ErrorMessage) - w.Int8(resource.ResourceType) - w.String(resource.ResourceName) - } - return w.Bytes(), nil -} - -// EncodeCreatePartitionsResponse renders bytes for create partitions responses. -func EncodeCreatePartitionsResponse(resp *CreatePartitionsResponse, version int16) ([]byte, error) { - if version < 0 || version > 3 { - return nil, fmt.Errorf("create partitions response version %d not supported", version) - } - flexible := version >= 2 - w := newByteWriter(128) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - w.Int32(resp.ThrottleMs) - if flexible { - w.CompactArrayLen(len(resp.Topics)) - } else { - w.Int32(int32(len(resp.Topics))) - } - for _, topic := range resp.Topics { - if flexible { - w.CompactString(topic.Name) - } else { - w.String(topic.Name) - } - w.Int16(topic.ErrorCode) - if flexible { - w.CompactNullableString(topic.ErrorMessage) - w.WriteTaggedFields(0) - } else { - w.NullableString(topic.ErrorMessage) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeDeleteGroupsResponse renders bytes for delete groups responses. -func EncodeDeleteGroupsResponse(resp *DeleteGroupsResponse, version int16) ([]byte, error) { - if version < 0 || version > 2 { - return nil, fmt.Errorf("delete groups response version %d not supported", version) - } - flexible := version >= 2 - w := newByteWriter(128) - w.Int32(resp.CorrelationID) - if flexible { - w.WriteTaggedFields(0) - } - w.Int32(resp.ThrottleMs) - if flexible { - w.CompactArrayLen(len(resp.Groups)) - } else { - w.Int32(int32(len(resp.Groups))) - } - for _, group := range resp.Groups { - if flexible { - w.CompactString(group.Group) - } else { - w.String(group.Group) - } - w.Int16(group.ErrorCode) - if flexible { - w.WriteTaggedFields(0) - } - } - if flexible { - w.WriteTaggedFields(0) - } - return w.Bytes(), nil -} - -// EncodeResponse wraps a response payload into a Kafka frame. -func EncodeResponse(payload []byte) ([]byte, error) { - if len(payload) > int(^uint32(0)>>1) { - return nil, fmt.Errorf("response too large: %d", len(payload)) - } - w := newByteWriter(len(payload) + 4) - w.Int32(int32(len(payload))) - w.write(payload) - return w.Bytes(), nil -} - -// GroupResponseErrorCode extracts the top-level error code from a group-related -// response. It accounts for version-dependent encoding (flexible headers, -// optional throttle_ms). Returns (errorCode, true) on success, or (0, false) -// if the response cannot be parsed. +import "github.com/twmb/franz-go/pkg/kmsg" + +// encodeResponseHeader builds the response header bytes that must be prepended +// to the kmsg response body. For non-flexible versions this is just the +// correlation ID (4 bytes). For flexible versions it also includes an empty +// tagged field section (1 byte: varint 0). +func encodeResponseHeader(correlationID int32, flexible bool) []byte { + if flexible { + buf := make([]byte, 5) + buf[0] = byte(correlationID >> 24) + buf[1] = byte(correlationID >> 16) + buf[2] = byte(correlationID >> 8) + buf[3] = byte(correlationID) + buf[4] = 0 // empty tagged fields + return buf + } + buf := make([]byte, 4) + buf[0] = byte(correlationID >> 24) + buf[1] = byte(correlationID >> 16) + buf[2] = byte(correlationID >> 8) + buf[3] = byte(correlationID) + return buf +} + +// EncodeResponse serializes a kmsg response with the proper header prepended. +// The API key and flexibility are derived from the response itself. // -// For DescribeGroups, which can contain multiple groups, this returns the error -// code of the first group in the response. +// ApiVersions (key 18) is a special case per KIP-511: the response body uses +// flexible encoding starting at version 3, but the response *header* is always +// the non-flexible v0 header (no tagged-fields byte). Clients rely on this to +// implement the v0 fallback during version negotiation. +func EncodeResponse(correlationID int32, apiVersion int16, resp kmsg.Response) []byte { + resp.SetVersion(apiVersion) + flexibleHeader := resp.IsFlexible() && resp.Key() != APIKeyApiVersion + header := encodeResponseHeader(correlationID, flexibleHeader) + return resp.AppendTo(header) +} + +// GroupResponseErrorCode extracts the top-level error code from already- +// serialized group response bytes, accounting for flexible headers and +// version-dependent layout. func GroupResponseErrorCode(apiKey int16, apiVersion int16, resp []byte) (int16, bool) { - r := newByteReader(resp) - - // All responses start with correlation_id. - if _, err := r.Int32(); err != nil { + body, ok := SkipResponseHeader(apiKey, apiVersion, resp) + if !ok { return 0, false } switch apiKey { case APIKeyJoinGroup: - return readGroupErrorCode(r, apiVersion >= 6, apiVersion >= 2) + return decodeJoinGroupErrorCode(apiVersion, body) case APIKeySyncGroup: - return readGroupErrorCode(r, apiVersion >= 4, apiVersion >= 1) + return decodeSyncGroupErrorCode(apiVersion, body) case APIKeyHeartbeat: - return readGroupErrorCode(r, apiVersion >= 4, apiVersion >= 1) + return decodeHeartbeatErrorCode(apiVersion, body) case APIKeyLeaveGroup: - // LeaveGroupResponse: correlation_id + error_code (no throttle_ms, no flex header in supported versions). - ec, err := r.Int16() - if err != nil { - return 0, false - } - return ec, true + return decodeLeaveGroupErrorCode(apiVersion, body) case APIKeyOffsetCommit: - return readOffsetCommitErrorCode(r) + return decodeOffsetCommitErrorCode(apiVersion, body) case APIKeyOffsetFetch: - return readOffsetFetchErrorCode(r, apiVersion) + return decodeOffsetFetchErrorCode(apiVersion, body) case APIKeyDescribeGroups: - return readDescribeGroupsFirstErrorCode(r, apiVersion) + return decodeDescribeGroupsErrorCode(apiVersion, body) default: return 0, false } } -// readGroupErrorCode reads the error code from a response with the layout: -// [tagged_fields if flexible] [throttle_ms(4) if hasThrottle] error_code(2) -func readGroupErrorCode(r *byteReader, flexible bool, hasThrottle bool) (int16, bool) { - if flexible { +// SkipResponseHeader skips the correlation ID and (for flexible versions) the +// tagged fields section. Returns the remaining body bytes. +func SkipResponseHeader(apiKey, apiVersion int16, data []byte) ([]byte, bool) { + if len(data) < 4 { + return nil, false + } + pos := 4 // skip correlation_id + + kResp := kmsg.ResponseForKey(apiKey) + if kResp == nil { + return nil, false + } + kResp.SetVersion(apiVersion) + if kResp.IsFlexible() { + if pos >= len(data) { + return nil, false + } + // Tagged fields length is a varint; always 0 for our responses. + r := newByteReader(data[pos:]) if err := r.SkipTaggedFields(); err != nil { - return 0, false + return nil, false } + pos += r.pos } - if hasThrottle { - if _, err := r.Int32(); err != nil { - return 0, false - } + return data[pos:], true +} + +func decodeJoinGroupErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrJoinGroupResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { + return 0, false } - ec, err := r.Int16() - if err != nil { + return resp.ErrorCode, true +} + +func decodeSyncGroupErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrSyncGroupResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { return 0, false } - return ec, true + return resp.ErrorCode, true } -// readOffsetCommitErrorCode reads the first partition error code from an OffsetCommitResponse. -// Layout: correlation_id(4) + throttle_ms(4) + topics... -func readOffsetCommitErrorCode(r *byteReader) (int16, bool) { - // throttle_ms - if _, err := r.Int32(); err != nil { +func decodeHeartbeatErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrHeartbeatResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { return 0, false } - // topic count - topicCount, err := r.Int32() - if err != nil { + return resp.ErrorCode, true +} + +func decodeLeaveGroupErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrLeaveGroupResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { return 0, false } - for i := int32(0); i < topicCount; i++ { - // topic name - if _, err := r.String(); err != nil { - return 0, false - } - // partition count - partCount, err := r.Int32() - if err != nil { - return 0, false - } - for j := int32(0); j < partCount; j++ { - // partition index - if _, err := r.Int32(); err != nil { - return 0, false - } - // error code - ec, err := r.Int16() - if err != nil { - return 0, false - } - if ec != 0 { - return ec, true + return resp.ErrorCode, true +} + +func decodeOffsetCommitErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrOffsetCommitResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { + return 0, false + } + for _, topic := range resp.Topics { + for _, part := range topic.Partitions { + if part.ErrorCode != 0 { + return part.ErrorCode, true } } } return 0, true } -// readOffsetFetchErrorCode reads the top-level error code from an OffsetFetchResponse. -// The top-level error code exists at version >= 2 and is at the end of the response. -// For simplicity, we read through the structure to find it. -func readOffsetFetchErrorCode(r *byteReader, version int16) (int16, bool) { - if version >= 3 { - // throttle_ms - if _, err := r.Int32(); err != nil { - return 0, false - } - } - // topic count - topicCount, err := r.Int32() - if err != nil { +func decodeOffsetFetchErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrOffsetFetchResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { return 0, false } - for i := int32(0); i < topicCount; i++ { - if _, err := r.String(); err != nil { - return 0, false - } - partCount, err := r.Int32() - if err != nil { - return 0, false - } - for j := int32(0); j < partCount; j++ { - // partition - if _, err := r.Int32(); err != nil { - return 0, false - } - // offset - if _, err := r.Int64(); err != nil { - return 0, false - } - // leader_epoch (version >= 5) - if version >= 5 { - if _, err := r.Int32(); err != nil { - return 0, false - } - } - // metadata (nullable string) - if _, err := r.NullableString(); err != nil { - return 0, false - } - // partition error code - ec, err := r.Int16() - if err != nil { - return 0, false - } - if ec != 0 { - return ec, true + for _, topic := range resp.Topics { + for _, part := range topic.Partitions { + if part.ErrorCode != 0 { + return part.ErrorCode, true } } } - // top-level error code (version >= 2) + // v2+ has a top-level error code. if version >= 2 { - ec, err := r.Int16() - if err != nil { - return 0, false - } - return ec, true + return resp.ErrorCode, true } return 0, true } -// readDescribeGroupsFirstErrorCode reads the error code of the first group -// in a DescribeGroupsResponse. -func readDescribeGroupsFirstErrorCode(r *byteReader, version int16) (int16, bool) { - flexible := version >= 5 - if flexible { - if err := r.SkipTaggedFields(); err != nil { - return 0, false - } - } - // throttle_ms (version >= 1) - if version >= 1 { - if _, err := r.Int32(); err != nil { - return 0, false - } - } - // group count - var groupCount int32 - if flexible { - gc, err := r.CompactArrayLen() - if err != nil { - return 0, false - } - groupCount = gc - } else { - gc, err := r.Int32() - if err != nil { - return 0, false - } - groupCount = gc +func decodeDescribeGroupsErrorCode(version int16, body []byte) (int16, bool) { + resp := kmsg.NewPtrDescribeGroupsResponse() + resp.SetVersion(version) + if err := resp.ReadFrom(body); err != nil { + return 0, false } - if groupCount == 0 { + if len(resp.Groups) == 0 { return 0, true } - // first group error code - ec, err := r.Int16() - if err != nil { - return 0, false - } - return ec, true + return resp.Groups[0].ErrorCode, true } diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 89a63ff3..dbcfdd46 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -16,1617 +16,41 @@ package protocol import ( - "encoding/binary" - "fmt" "testing" "github.com/twmb/franz-go/pkg/kmsg" ) -func strPtr(s string) *string { - return &s -} - -func TestEncodeApiVersionsResponseV0(t *testing.T) { - payload, err := EncodeApiVersionsResponse(&ApiVersionsResponse{ - CorrelationID: 99, - ErrorCode: 0, - Versions: []ApiVersion{ - {APIKey: APIKeyMetadata, MinVersion: 0, MaxVersion: 1}, - }, - }, 0) - if err != nil { - t.Fatalf("EncodeApiVersionsResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 99 { - t.Fatalf("unexpected correlation id %d", corr) - } -} - -func TestEncodeApiVersionsResponseV3(t *testing.T) { - resp := &ApiVersionsResponse{ - CorrelationID: 101, - ErrorCode: 0, - Versions: []ApiVersion{ - {APIKey: APIKeyMetadata, MinVersion: 0, MaxVersion: 12}, - }, - } - payload, err := EncodeApiVersionsResponse(resp, 3) - if err != nil { - t.Fatalf("EncodeApiVersionsResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 101 { - t.Fatalf("unexpected correlation id %d", corr) - } - body := payload[4:] - kmsgResp := kmsg.NewPtrApiVersionsResponse() - kmsgResp.Version = 3 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("decode api versions response: %v", err) - } - if len(kmsgResp.ApiKeys) != 1 || kmsgResp.ApiKeys[0].ApiKey != APIKeyMetadata { - t.Fatalf("unexpected api versions response: %#v", kmsgResp.ApiKeys) - } -} - -func TestEncodeApiVersionsResponseV4(t *testing.T) { - resp := &ApiVersionsResponse{ - CorrelationID: 102, - ErrorCode: 0, - Versions: []ApiVersion{ - {APIKey: APIKeyMetadata, MinVersion: 0, MaxVersion: 12}, - }, - } - payload, err := EncodeApiVersionsResponse(resp, 4) - if err != nil { - t.Fatalf("EncodeApiVersionsResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 102 { - t.Fatalf("unexpected correlation id %d", corr) - } - body := payload[4:] - kmsgResp := kmsg.NewPtrApiVersionsResponse() - kmsgResp.Version = 4 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("decode api versions response: %v", err) - } - if len(kmsgResp.ApiKeys) != 1 || kmsgResp.ApiKeys[0].ApiKey != APIKeyMetadata { - t.Fatalf("unexpected api versions response: %#v", kmsgResp.ApiKeys) - } -} - -func TestEncodeMetadataResponse(t *testing.T) { - clusterID := "cluster-1" - payload, err := EncodeMetadataResponse(&MetadataResponse{ - CorrelationID: 5, - ThrottleMs: 0, - Brokers: []MetadataBroker{ - {NodeID: 1, Host: "localhost", Port: 9092}, - }, - ClusterID: &clusterID, - ControllerID: 1, - Topics: []MetadataTopic{ - { - ErrorCode: 0, - Name: "orders", - Partitions: []MetadataPartition{ - { - ErrorCode: 0, - PartitionIndex: 0, - LeaderID: 1, - ReplicaNodes: []int32{1}, - ISRNodes: []int32{1}, - }, - }, - }, - }, - }, 0) - if err != nil { - t.Fatalf("EncodeMetadataResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 5 { - t.Fatalf("unexpected correlation id %d", corr) - } -} - -func TestEncodeCreateTopicsResponseV2(t *testing.T) { - resp := &CreateTopicsResponse{ - CorrelationID: 31, - ThrottleMs: 0, - Topics: []CreateTopicResult{ - {Name: "orders", ErrorCode: NONE}, - }, - } - payload, err := EncodeCreateTopicsResponse(resp, 2) - if err != nil { - t.Fatalf("EncodeCreateTopicsResponse: %v", err) - } - body := payload[4:] - kmsgResp := kmsg.NewPtrCreateTopicsResponse() - kmsgResp.Version = 2 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("decode create topics response: %v", err) - } - if len(kmsgResp.Topics) != 1 || kmsgResp.Topics[0].Topic != "orders" { - t.Fatalf("unexpected create topics response: %#v", kmsgResp.Topics) - } -} - -func TestEncodeDeleteTopicsResponseV1(t *testing.T) { - resp := &DeleteTopicsResponse{ - CorrelationID: 41, - ThrottleMs: 0, - Topics: []DeleteTopicResult{ - {Name: "orders", ErrorCode: NONE}, - }, - } - payload, err := EncodeDeleteTopicsResponse(resp, 1) - if err != nil { - t.Fatalf("EncodeDeleteTopicsResponse: %v", err) - } - body := payload[4:] - kmsgResp := kmsg.NewPtrDeleteTopicsResponse() - kmsgResp.Version = 1 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("decode delete topics response: %v", err) - } - if len(kmsgResp.Topics) != 1 || kmsgResp.Topics[0].Topic == nil || *kmsgResp.Topics[0].Topic != "orders" { - t.Fatalf("unexpected delete topics response: %#v", kmsgResp.Topics) - } -} - -func TestEncodeMetadataResponseV10IncludesTopicID(t *testing.T) { - clusterID := "cluster-1" - var topicID [16]byte - for i := range topicID { - topicID[i] = byte(i + 1) - } - payload, err := EncodeMetadataResponse(&MetadataResponse{ - CorrelationID: 7, - ThrottleMs: 0, - Brokers: []MetadataBroker{ - {NodeID: 1, Host: "localhost", Port: 9092}, - }, - ClusterID: &clusterID, - ControllerID: 1, - Topics: []MetadataTopic{ - { - ErrorCode: 0, - Name: "orders", - TopicID: topicID, - IsInternal: false, - Partitions: []MetadataPartition{ - { - ErrorCode: 0, - PartitionIndex: 0, - LeaderID: 1, - ReplicaNodes: []int32{1}, - ISRNodes: []int32{1}, - }, - }, - }, - }, - }, 10) - if err != nil { - t.Fatalf("EncodeMetadataResponse v10: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 7 { - t.Fatalf("unexpected correlation id %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - if _, err := reader.Int32(); err != nil { // throttle - t.Fatalf("read throttle: %v", err) - } - if brokers, _ := reader.CompactArrayLen(); brokers != 1 { - t.Fatalf("expected 1 broker got %d", brokers) - } - if _, err := reader.Int32(); err != nil { - t.Fatalf("read broker id: %v", err) - } - if host, _ := reader.CompactString(); host != "localhost" { - t.Fatalf("unexpected broker host %q", host) - } - reader.Int32() // port - if _, err := reader.CompactNullableString(); err != nil { - t.Fatalf("read rack: %v", err) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero broker tags got %d", tags) - } - if _, err := reader.CompactNullableString(); err != nil { - t.Fatalf("read cluster id: %v", err) - } - reader.Int32() // controller id - if topics, _ := reader.CompactArrayLen(); topics != 1 { - t.Fatalf("expected 1 topic got %d", topics) - } - reader.Int16() // error code - if name, _ := reader.CompactNullableString(); name == nil || *name != "orders" { - t.Fatalf("unexpected topic name %v", name) - } - id, err := reader.UUID() - if err != nil { - t.Fatalf("read topic id: %v", err) - } - if id != topicID { - t.Fatalf("unexpected topic id %v", id) - } - if internal, _ := reader.Bool(); internal { - t.Fatalf("expected non-internal topic") - } - if parts, _ := reader.CompactArrayLen(); parts != 1 { - t.Fatalf("expected 1 partition got %d", parts) - } - reader.Int16() // partition error - reader.Int32() // partition index - reader.Int32() // leader - reader.Int32() // leader epoch - if replicas, _ := reader.CompactArrayLen(); replicas != 1 { - t.Fatalf("expected 1 replica got %d", replicas) - } - reader.Int32() - if isr, _ := reader.CompactArrayLen(); isr != 1 { - t.Fatalf("expected 1 isr got %d", isr) - } - reader.Int32() - if offline, _ := reader.CompactArrayLen(); offline != 0 { - t.Fatalf("expected 0 offline replicas got %d", offline) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero partition tags got %d", tags) - } - reader.Int32() // authorized ops - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero topic tags got %d", tags) - } - reader.Int32() // cluster authorized ops - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes: %d", reader.remaining()) - } -} - -func TestEncodeProduceResponse(t *testing.T) { - payload, err := EncodeProduceResponse(&ProduceResponse{ - CorrelationID: 7, - Topics: []ProduceTopicResponse{ - { - Name: "orders", - Partitions: []ProducePartitionResponse{ - {Partition: 0, ErrorCode: 0, BaseOffset: 10, LogAppendTimeMs: 1234, LogStartOffset: 10}, - }, - }, - }, - ThrottleMs: 5, - }, 8) - if err != nil { - t.Fatalf("EncodeProduceResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 7 { - t.Fatalf("unexpected correlation id %d", corr) - } - topicCount, _ := reader.Int32() - if topicCount != 1 { - t.Fatalf("expected 1 topic got %d", topicCount) - } - if name, _ := reader.String(); name != "orders" { - t.Fatalf("unexpected topic %q", name) - } - partCount, _ := reader.Int32() - if partCount != 1 { - t.Fatalf("expected 1 partition got %d", partCount) - } - reader.Int32() // partition - reader.Int16() // error code - reader.Int64() // base offset - reader.Int64() // log append time - reader.Int64() // log start offset - if errCount, _ := reader.Int32(); errCount != 0 { - t.Fatalf("expected 0 record errors got %d", errCount) - } - if msg, _ := reader.NullableString(); msg != nil { - t.Fatalf("expected nil record error message got %v", msg) - } -} - -func TestEncodeProduceResponseFlexible(t *testing.T) { - payload, err := EncodeProduceResponse(&ProduceResponse{ - CorrelationID: 9, - Topics: []ProduceTopicResponse{ - { - Name: "orders", - Partitions: []ProducePartitionResponse{ - {Partition: 0, ErrorCode: 0, BaseOffset: 42, LogAppendTimeMs: 11, LogStartOffset: 5}, - }, - }, - }, - ThrottleMs: 3, - }, 9) - if err != nil { - t.Fatalf("EncodeProduceResponse flexible: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 9 { - t.Fatalf("unexpected correlation id %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - topicCount, _ := reader.CompactArrayLen() - if topicCount != 1 { - t.Fatalf("expected 1 topic got %d", topicCount) - } - name, _ := reader.CompactString() - if name != "orders" { - t.Fatalf("unexpected topic %q", name) - } - partCount, _ := reader.CompactArrayLen() - if partCount != 1 { - t.Fatalf("expected 1 partition got %d", partCount) - } - if partition, _ := reader.Int32(); partition != 0 { - t.Fatalf("unexpected partition %d", partition) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if base, _ := reader.Int64(); base != 42 { - t.Fatalf("unexpected base offset %d", base) - } - reader.Int64() // log append time - reader.Int64() // log start offset - if errCount, _ := reader.CompactArrayLen(); errCount != 0 { - t.Fatalf("expected 0 record errors got %d", errCount) - } - if msg, _ := reader.CompactNullableString(); msg != nil { - t.Fatalf("expected nil record error message got %v", msg) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero partition tags got %d", tags) - } - if topicTags, _ := reader.UVarint(); topicTags != 0 { - t.Fatalf("expected zero topic tags got %d", topicTags) - } - if throttle, _ := reader.Int32(); throttle != 3 { - t.Fatalf("unexpected throttle %d", throttle) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes: %d", reader.remaining()) - } -} - -func TestEncodeProduceResponseLegacyVersions(t *testing.T) { - resp := &ProduceResponse{ - CorrelationID: 7, - Topics: []ProduceTopicResponse{ - { - Name: "orders", - Partitions: []ProducePartitionResponse{ - {Partition: 0, ErrorCode: 0, BaseOffset: 10, LogAppendTimeMs: 123, LogStartOffset: 5}, - }, - }, - }, - ThrottleMs: 0, - } - - tests := []struct { - name string - version int16 - }{ - {name: "v0", version: 0}, - {name: "v7", version: 7}, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - payload, err := EncodeProduceResponse(resp, tc.version) - if err != nil { - t.Fatalf("EncodeProduceResponse v%d: %v", tc.version, err) - } - reader := newByteReader(payload) - if _, err := reader.Int32(); err != nil { - t.Fatalf("read correlation: %v", err) - } - topicCount, err := reader.Int32() - if err != nil { - t.Fatalf("read topic count: %v", err) - } - for i := int32(0); i < topicCount; i++ { - if _, err := reader.String(); err != nil { - t.Fatalf("read topic name: %v", err) - } - partCount, err := reader.Int32() - if err != nil { - t.Fatalf("read partition count: %v", err) - } - for j := int32(0); j < partCount; j++ { - if _, err := reader.Int32(); err != nil { - t.Fatalf("read partition id: %v", err) - } - if _, err := reader.Int16(); err != nil { - t.Fatalf("read error code: %v", err) - } - if _, err := reader.Int64(); err != nil { - t.Fatalf("read base offset: %v", err) - } - if tc.version >= 3 { - if _, err := reader.Int64(); err != nil { - t.Fatalf("read log append time: %v", err) - } - } - if tc.version >= 5 { - if _, err := reader.Int64(); err != nil { - t.Fatalf("read log start offset: %v", err) - } - } - if tc.version >= 8 { - if _, err := reader.Int32(); err != nil { - t.Fatalf("read log offset delta: %v", err) - } - } - } - } - if tc.version >= 1 { - if _, err := reader.Int32(); err != nil { - t.Fatalf("read throttle ms: %v", err) - } - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes: %d", reader.remaining()) - } - }) - } -} - -func TestEncodeListOffsetsResponseV0(t *testing.T) { - payload, err := EncodeListOffsetsResponse(0, &ListOffsetsResponse{ - CorrelationID: 15, - Topics: []ListOffsetsTopicResponse{ - { - Name: "orders", - Partitions: []ListOffsetsPartitionResponse{ - {Partition: 0, ErrorCode: 0, OldStyleOffsets: []int64{42}}, - }, - }, - }, - }) - if err != nil { - t.Fatalf("EncodeListOffsetsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 15 { - t.Fatalf("unexpected correlation id %d", corr) - } - if topics, _ := reader.Int32(); topics != 1 { - t.Fatalf("unexpected topic count %d", topics) - } - if name, _ := reader.String(); name != "orders" { - t.Fatalf("unexpected topic name %q", name) - } - if parts, _ := reader.Int32(); parts != 1 { - t.Fatalf("unexpected partition count %d", parts) - } - if part, _ := reader.Int32(); part != 0 { - t.Fatalf("unexpected partition %d", part) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if count, _ := reader.Int32(); count != 1 { - t.Fatalf("unexpected offset count %d", count) - } - if offset, _ := reader.Int64(); offset != 42 { - t.Fatalf("unexpected offset %d", offset) - } - if reader.remaining() != 0 { - t.Fatalf("expected no remaining bytes, got %d", reader.remaining()) - } -} - -func TestEncodeFetchResponse(t *testing.T) { - payload, err := EncodeFetchResponse(&FetchResponse{ - CorrelationID: 3, - ThrottleMs: 9, - ErrorCode: NONE, - SessionID: 7, - Topics: []FetchTopicResponse{ - { - Name: "orders", - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 10, - LastStableOffset: 10, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: []byte("records"), - }, - }, - }, - }, - }, 11) - if err != nil { - t.Fatalf("EncodeFetchResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 3 { - t.Fatalf("unexpected correlation id %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 9 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if session, _ := reader.Int32(); session != 7 { - t.Fatalf("unexpected session id %d", session) - } - if topicCount, _ := reader.Int32(); topicCount != 1 { - t.Fatalf("unexpected topic count %d", topicCount) - } - name, _ := reader.String() - if name != "orders" { - t.Fatalf("unexpected topic %q", name) - } - if partCount, _ := reader.Int32(); partCount != 1 { - t.Fatalf("unexpected partition count %d", partCount) - } - if partition, _ := reader.Int32(); partition != 0 { - t.Fatalf("unexpected partition %d", partition) - } - if perr, _ := reader.Int16(); perr != 0 { - t.Fatalf("unexpected partition error %d", perr) - } - if hw, _ := reader.Int64(); hw != 10 { - t.Fatalf("unexpected high watermark %d", hw) - } - if lso, _ := reader.Int64(); lso != 10 { - t.Fatalf("unexpected lso %d", lso) - } - if lsoff, _ := reader.Int64(); lsoff != 0 { - t.Fatalf("unexpected log start offset %d", lsoff) - } - if abortedCount, _ := reader.Int32(); abortedCount != 0 { - t.Fatalf("unexpected aborted txns %d", abortedCount) - } - if pref, _ := reader.Int32(); pref != -1 { - t.Fatalf("unexpected preferred replica %d", pref) - } - recordLen, _ := reader.Int32() - if recordLen != int32(len("records")) { - t.Fatalf("unexpected record set length %d", recordLen) - } - if _, err := reader.read(int(recordLen)); err != nil { - t.Fatalf("read record set: %v", err) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeFetchResponseV13(t *testing.T) { - var topicID [16]byte - for i := range topicID { - topicID[i] = byte(i + 1) - } - payload, err := EncodeFetchResponse(&FetchResponse{ - CorrelationID: 11, - ThrottleMs: 1, - ErrorCode: NONE, - SessionID: 2, - Topics: []FetchTopicResponse{ - { - TopicID: topicID, - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 5, - LastStableOffset: 5, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: []byte("records"), - }, - }, - }, - }, - }, 13) - if err != nil { - t.Fatalf("EncodeFetchResponse v13: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 11 { - t.Fatalf("unexpected correlation id %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - if throttle, _ := reader.Int32(); throttle != 1 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if session, _ := reader.Int32(); session != 2 { - t.Fatalf("unexpected session id %d", session) - } - if topicCount, _ := reader.CompactArrayLen(); topicCount != 1 { - t.Fatalf("unexpected topic count %d", topicCount) - } - gotID, err := reader.UUID() - if err != nil { - t.Fatalf("read topic id: %v", err) - } - if gotID != topicID { - t.Fatalf("unexpected topic id %v", gotID) - } - if partCount, _ := reader.CompactArrayLen(); partCount != 1 { - t.Fatalf("unexpected partition count %d", partCount) - } - if partition, _ := reader.Int32(); partition != 0 { - t.Fatalf("unexpected partition %d", partition) - } - if perr, _ := reader.Int16(); perr != 0 { - t.Fatalf("unexpected partition error %d", perr) - } - if hw, _ := reader.Int64(); hw != 5 { - t.Fatalf("unexpected high watermark %d", hw) - } - if lso, _ := reader.Int64(); lso != 5 { - t.Fatalf("unexpected lso %d", lso) - } - if lsoff, _ := reader.Int64(); lsoff != 0 { - t.Fatalf("unexpected log start offset %d", lsoff) - } - if abortedCount, _ := reader.CompactArrayLen(); abortedCount != 0 { - t.Fatalf("unexpected aborted txns %d", abortedCount) - } - if pref, _ := reader.Int32(); pref != -1 { - t.Fatalf("unexpected preferred replica %d", pref) - } - recordSet, err := reader.CompactBytes() - if err != nil { - t.Fatalf("read record set: %v", err) - } - if string(recordSet) != "records" { - t.Fatalf("unexpected record set %q", recordSet) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero partition tags got %d", tags) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero topic tags got %d", tags) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeFetchResponseV13EmptyRecordSet(t *testing.T) { - var topicID [16]byte - for i := range topicID { - topicID[i] = byte(i + 1) - } - payload, err := EncodeFetchResponse(&FetchResponse{ - CorrelationID: 12, - ThrottleMs: 0, - ErrorCode: NONE, - SessionID: 0, - Topics: []FetchTopicResponse{ - { - TopicID: topicID, - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 5, - LastStableOffset: 5, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: nil, - }, - }, - }, - }, - }, 13) - if err != nil { - t.Fatalf("EncodeFetchResponse v13 empty: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 12 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - if throttle, _ := reader.Int32(); throttle != 0 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if session, _ := reader.Int32(); session != 0 { - t.Fatalf("unexpected session id %d", session) - } - if topicCount, _ := reader.CompactArrayLen(); topicCount != 1 { - t.Fatalf("unexpected topic count %d", topicCount) - } - gotID, err := reader.UUID() - if err != nil { - t.Fatalf("read topic id: %v", err) - } - if gotID != topicID { - t.Fatalf("unexpected topic id %v", gotID) - } - if partCount, _ := reader.CompactArrayLen(); partCount != 1 { - t.Fatalf("unexpected partition count %d", partCount) - } - if partition, _ := reader.Int32(); partition != 0 { - t.Fatalf("unexpected partition %d", partition) - } - if perr, _ := reader.Int16(); perr != 0 { - t.Fatalf("unexpected partition error %d", perr) - } - if hw, _ := reader.Int64(); hw != 5 { - t.Fatalf("unexpected high watermark %d", hw) - } - if lso, _ := reader.Int64(); lso != 5 { - t.Fatalf("unexpected lso %d", lso) - } - if lsoff, _ := reader.Int64(); lsoff != 0 { - t.Fatalf("unexpected log start offset %d", lsoff) - } - if abortedCount, _ := reader.CompactArrayLen(); abortedCount != 0 { - t.Fatalf("unexpected aborted txns %d", abortedCount) - } - if pref, _ := reader.Int32(); pref != -1 { - t.Fatalf("unexpected preferred replica %d", pref) - } - recordSet, err := reader.CompactBytes() - if err != nil { - t.Fatalf("read record set: %v", err) - } - if recordSet == nil || len(recordSet) != 0 { - t.Fatalf("expected empty record set, got %#v", recordSet) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero partition tags got %d", tags) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero topic tags got %d", tags) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeFetchResponseV13KmsgRoundTrip(t *testing.T) { - var topicID [16]byte - for i := range topicID { - topicID[i] = byte(i + 1) - } - recordSet := makeTestRecordBatch(2, 0) - payload, err := EncodeFetchResponse(&FetchResponse{ - CorrelationID: 21, - ThrottleMs: 0, - ErrorCode: NONE, - SessionID: 0, - Topics: []FetchTopicResponse{ - { - TopicID: topicID, - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 2, - LastStableOffset: 2, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: recordSet, - }, - }, - }, - }, - }, 13) - if err != nil { - t.Fatalf("EncodeFetchResponse v13: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 21 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrFetchResponse() - kmsgResp.Version = 13 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Topics) != 1 || len(kmsgResp.Topics[0].Partitions) != 1 { - t.Fatalf("unexpected topic/partition counts: %+v", kmsgResp.Topics) - } - if kmsgResp.Topics[0].TopicID != topicID { - t.Fatalf("unexpected topic id %v", kmsgResp.Topics[0].TopicID) - } - part := kmsgResp.Topics[0].Partitions[0] - if part.ErrorCode != 0 { - t.Fatalf("unexpected partition error %d", part.ErrorCode) - } - if len(part.RecordBatches) != len(recordSet) { - t.Fatalf("unexpected record batch length %d", len(part.RecordBatches)) - } -} - -func TestEncodeFindCoordinatorResponseFlexible(t *testing.T) { - payload, err := EncodeFindCoordinatorResponse(&FindCoordinatorResponse{ - CorrelationID: 4, - ThrottleMs: 7, - ErrorCode: 0, - NodeID: 1, - Host: "127.0.0.1", - Port: 39092, - }, 3) - if err != nil { - t.Fatalf("EncodeFindCoordinatorResponse: %v", err) - } - reader := newByteReader(payload) - corr, _ := reader.Int32() - if corr != 4 { - t.Fatalf("unexpected correlation id %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - if throttle, _ := reader.Int32(); throttle != 7 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if errMsg, _ := reader.CompactNullableString(); errMsg != nil { - t.Fatalf("expected nil error message got %q", *errMsg) - } - if nodeID, _ := reader.Int32(); nodeID != 1 { - t.Fatalf("unexpected node id %d", nodeID) - } - host, _ := reader.CompactString() - if host != "127.0.0.1" { - t.Fatalf("unexpected host %q", host) - } - if port, _ := reader.Int32(); port != 39092 { - t.Fatalf("unexpected port %d", port) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeDescribeGroupsResponseV5KmsgRoundTrip(t *testing.T) { - payload, err := EncodeDescribeGroupsResponse(&DescribeGroupsResponse{ - CorrelationID: 55, - ThrottleMs: 0, - Groups: []DescribeGroupsResponseGroup{ - { - ErrorCode: NONE, - GroupID: "group-1", - State: "Stable", - ProtocolType: "consumer", - Protocol: "range", - AuthorizedOperations: 0, - Members: []DescribeGroupsResponseGroupMember{ - { - MemberID: "member-1", - ClientID: "client-1", - ClientHost: "127.0.0.1", - ProtocolMetadata: []byte{0x01}, - MemberAssignment: []byte{0x02}, - }, - }, - }, - }, - }, 5) - if err != nil { - t.Fatalf("EncodeDescribeGroupsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 55 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrDescribeGroupsResponse() - kmsgResp.Version = 5 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Groups) != 1 { - t.Fatalf("unexpected groups: %+v", kmsgResp.Groups) - } - group := kmsgResp.Groups[0] - if group.Group != "group-1" || group.State != "Stable" { - t.Fatalf("unexpected group data: %+v", group) - } - if len(group.Members) != 1 || group.Members[0].MemberID != "member-1" { - t.Fatalf("unexpected member data: %+v", group.Members) - } -} - -func TestEncodeListGroupsResponseV5KmsgRoundTrip(t *testing.T) { - payload, err := EncodeListGroupsResponse(&ListGroupsResponse{ - CorrelationID: 77, - ThrottleMs: 0, - ErrorCode: NONE, - Groups: []ListGroupsResponseGroup{ - { - GroupID: "group-1", - ProtocolType: "consumer", - GroupState: "Stable", - GroupType: "classic", - }, - }, - }, 5) - if err != nil { - t.Fatalf("EncodeListGroupsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 77 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrListGroupsResponse() - kmsgResp.Version = 5 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Groups) != 1 || kmsgResp.Groups[0].Group != "group-1" { - t.Fatalf("unexpected list groups: %+v", kmsgResp.Groups) - } -} - -func TestEncodeOffsetForLeaderEpochResponseV3KmsgRoundTrip(t *testing.T) { - payload, err := EncodeOffsetForLeaderEpochResponse(&OffsetForLeaderEpochResponse{ - CorrelationID: 13, - ThrottleMs: 0, - Topics: []OffsetForLeaderEpochTopicResponse{ - { - Name: "orders", - Partitions: []OffsetForLeaderEpochPartitionResponse{ - {Partition: 0, ErrorCode: NONE, LeaderEpoch: 1, EndOffset: 12}, - }, - }, - }, - }, 3) - if err != nil { - t.Fatalf("EncodeOffsetForLeaderEpochResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 13 { - t.Fatalf("unexpected correlation id %d", corr) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrOffsetForLeaderEpochResponse() - kmsgResp.Version = 3 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Topics) != 1 || kmsgResp.Topics[0].Topic != "orders" { - t.Fatalf("unexpected response: %+v", kmsgResp.Topics) - } -} - -func TestEncodeDescribeConfigsResponseV4KmsgRoundTrip(t *testing.T) { - payload, err := EncodeDescribeConfigsResponse(&DescribeConfigsResponse{ - CorrelationID: 19, - ThrottleMs: 0, - Resources: []DescribeConfigsResponseResource{ - { - ErrorCode: NONE, - ResourceType: ConfigResourceTopic, - ResourceName: "orders", - Configs: []DescribeConfigsResponseConfig{ - { - Name: "retention.ms", - Value: strPtr("1000"), - ReadOnly: false, - IsDefault: false, - Source: ConfigSourceDynamicTopic, - IsSensitive: false, - ConfigType: ConfigTypeLong, - }, - }, - }, - }, - }, 4) - if err != nil { - t.Fatalf("EncodeDescribeConfigsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 19 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrDescribeConfigsResponse() - kmsgResp.Version = 4 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Resources) != 1 || kmsgResp.Resources[0].ResourceName != "orders" { - t.Fatalf("unexpected resources: %+v", kmsgResp.Resources) - } -} - -func TestEncodeAlterConfigsResponseV1KmsgRoundTrip(t *testing.T) { - payload, err := EncodeAlterConfigsResponse(&AlterConfigsResponse{ - CorrelationID: 27, - ThrottleMs: 0, - Resources: []AlterConfigsResponseResource{ - { - ErrorCode: NONE, - ResourceType: ConfigResourceTopic, - ResourceName: "orders", - }, - }, - }, 1) - if err != nil { - t.Fatalf("EncodeAlterConfigsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 27 { - t.Fatalf("unexpected correlation id %d", corr) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrAlterConfigsResponse() - kmsgResp.Version = 1 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Resources) != 1 || kmsgResp.Resources[0].ResourceName != "orders" { - t.Fatalf("unexpected response: %+v", kmsgResp.Resources) - } -} - -func TestEncodeCreatePartitionsResponseV3KmsgRoundTrip(t *testing.T) { - payload, err := EncodeCreatePartitionsResponse(&CreatePartitionsResponse{ - CorrelationID: 33, - ThrottleMs: 0, - Topics: []CreatePartitionsResponseTopic{ - {Name: "orders", ErrorCode: NONE}, - }, - }, 3) - if err != nil { - t.Fatalf("EncodeCreatePartitionsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 33 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrCreatePartitionsResponse() - kmsgResp.Version = 3 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Topics) != 1 || kmsgResp.Topics[0].Topic != "orders" { - t.Fatalf("unexpected response: %+v", kmsgResp.Topics) - } -} - -func TestEncodeDeleteGroupsResponseV2KmsgRoundTrip(t *testing.T) { - payload, err := EncodeDeleteGroupsResponse(&DeleteGroupsResponse{ - CorrelationID: 35, - ThrottleMs: 0, - Groups: []DeleteGroupsResponseGroup{ - {Group: "group-1", ErrorCode: NONE}, - }, - }, 2) - if err != nil { - t.Fatalf("EncodeDeleteGroupsResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 35 { - t.Fatalf("unexpected correlation id %d", corr) - } - if err := reader.SkipTaggedFields(); err != nil { - t.Fatalf("skip response header tags: %v", err) - } - body := payload[reader.pos:] - kmsgResp := kmsg.NewPtrDeleteGroupsResponse() - kmsgResp.Version = 2 - if err := kmsgResp.ReadFrom(body); err != nil { - t.Fatalf("kmsg decode: %v", err) - } - if len(kmsgResp.Groups) != 1 || kmsgResp.Groups[0].Group != "group-1" { - t.Fatalf("unexpected response: %+v", kmsgResp.Groups) - } -} - -func TestEncodeFindCoordinatorResponseLegacy(t *testing.T) { - errMsg := "ok" - payload, err := EncodeFindCoordinatorResponse(&FindCoordinatorResponse{ - CorrelationID: 2, - ThrottleMs: 9, - ErrorCode: 1, - ErrorMessage: &errMsg, - NodeID: 5, - Host: "node-1", - Port: 9092, - }, 2) - if err != nil { - t.Fatalf("EncodeFindCoordinatorResponse legacy: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 2 { - t.Fatalf("unexpected correlation id %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 9 { - t.Fatalf("unexpected throttle %d", throttle) - } - if code, _ := reader.Int16(); code != 1 { - t.Fatalf("unexpected error code %d", code) - } - msg, _ := reader.NullableString() - if msg == nil || *msg != "ok" { - t.Fatalf("unexpected error message %v", msg) - } - if node, _ := reader.Int32(); node != 5 { - t.Fatalf("unexpected node %d", node) - } - host, _ := reader.String() - if host != "node-1" { - t.Fatalf("unexpected host %q", host) - } - if port, _ := reader.Int32(); port != 9092 { - t.Fatalf("unexpected port %d", port) - } -} - -func TestEncodeJoinGroupResponseV4(t *testing.T) { - payload, err := EncodeJoinGroupResponse(&JoinGroupResponse{ - CorrelationID: 5, - ThrottleMs: 7, - ErrorCode: 0, - GenerationID: 3, - ProtocolName: "range", - LeaderID: "member-1", - MemberID: "member-2", - Members: []JoinGroupMember{ - {MemberID: "member-1", Metadata: []byte{0x01}}, - {MemberID: "member-2", Metadata: []byte{0x02}}, - }, - }, 4) - if err != nil { - t.Fatalf("EncodeJoinGroupResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 5 { - t.Fatalf("unexpected correlation id %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 7 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if gen, _ := reader.Int32(); gen != 3 { - t.Fatalf("unexpected generation %d", gen) - } - if proto, _ := reader.String(); proto != "range" { - t.Fatalf("unexpected protocol %q", proto) - } - if leader, _ := reader.String(); leader != "member-1" { - t.Fatalf("unexpected leader %q", leader) - } - if member, _ := reader.String(); member != "member-2" { - t.Fatalf("unexpected member %q", member) - } - if count, _ := reader.Int32(); count != 2 { - t.Fatalf("unexpected member count %d", count) - } - for i := 0; i < 2; i++ { - id, _ := reader.String() - if id != fmt.Sprintf("member-%d", i+1) { - t.Fatalf("unexpected member id %q", id) - } - length, _ := reader.Int32() - if length != 1 { - t.Fatalf("unexpected metadata length %d", length) - } - reader.read(int(length)) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeSyncGroupResponseV2(t *testing.T) { - payload, err := EncodeSyncGroupResponse(&SyncGroupResponse{ - CorrelationID: 11, - ThrottleMs: 8, - ErrorCode: NONE, - Assignment: []byte{0x01, 0x02}, - }, 2) - if err != nil { - t.Fatalf("EncodeSyncGroupResponse v2: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 11 { - t.Fatalf("unexpected correlation %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 8 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - length, _ := reader.Int32() - if length != 2 { - t.Fatalf("unexpected assignment length %d", length) - } - if data, _ := reader.read(int(length)); len(data) != 2 || data[0] != 0x01 || data[1] != 0x02 { - t.Fatalf("unexpected assignment payload %v", data) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} -func TestEncodeSyncGroupResponseFlexibleV4(t *testing.T) { - payload, err := EncodeSyncGroupResponse(&SyncGroupResponse{ - CorrelationID: 13, - ThrottleMs: 4, - ErrorCode: NONE, - Assignment: []byte{0xaa}, - }, 4) - if err != nil { - t.Fatalf("EncodeSyncGroupResponse flexible: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 13 { - t.Fatalf("unexpected correlation %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - if throttle, _ := reader.Int32(); throttle != 4 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if b, _ := reader.CompactBytes(); len(b) != 1 || b[0] != 0xaa { - t.Fatalf("unexpected assignment %v", b) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeHeartbeatResponseV2(t *testing.T) { - payload, err := EncodeHeartbeatResponse(&HeartbeatResponse{ - CorrelationID: 21, - ThrottleMs: 9, - ErrorCode: NONE, - }, 2) - if err != nil { - t.Fatalf("EncodeHeartbeatResponse v2: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 21 { - t.Fatalf("unexpected correlation %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 9 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeHeartbeatResponseFlexibleV4(t *testing.T) { - payload, err := EncodeHeartbeatResponse(&HeartbeatResponse{ - CorrelationID: 22, - ThrottleMs: 3, - ErrorCode: NONE, - }, 4) - if err != nil { - t.Fatalf("EncodeHeartbeatResponse flexible: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 22 { - t.Fatalf("unexpected correlation %d", corr) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero header tags got %d", tags) - } - if throttle, _ := reader.Int32(); throttle != 3 { - t.Fatalf("unexpected throttle %d", throttle) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected error code %d", errCode) - } - if tags, _ := reader.UVarint(); tags != 0 { - t.Fatalf("expected zero response tags got %d", tags) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestEncodeOffsetFetchResponse(t *testing.T) { - resp := &OffsetFetchResponse{ - CorrelationID: 31, - ThrottleMs: 12, - Topics: []OffsetFetchTopicResponse{ - { - Name: "orders", - Partitions: []OffsetFetchPartitionResponse{ - {Partition: 0, Offset: 42, LeaderEpoch: -1, Metadata: strPtr("meta"), ErrorCode: NONE}, - }, - }, - }, - ErrorCode: NONE, - } - payload, err := EncodeOffsetFetchResponse(resp, 5) - if err != nil { - t.Fatalf("EncodeOffsetFetchResponse: %v", err) - } - reader := newByteReader(payload) - if corr, _ := reader.Int32(); corr != 31 { - t.Fatalf("unexpected correlation %d", corr) - } - if throttle, _ := reader.Int32(); throttle != 12 { - t.Fatalf("unexpected throttle %d", throttle) - } - if topics, _ := reader.Int32(); topics != 1 { - t.Fatalf("unexpected topic count %d", topics) - } - name, _ := reader.String() - if name != "orders" { - t.Fatalf("unexpected topic %q", name) - } - if partitions, _ := reader.Int32(); partitions != 1 { - t.Fatalf("unexpected partition count %d", partitions) - } - if part, _ := reader.Int32(); part != 0 { - t.Fatalf("unexpected partition %d", part) - } - if offset, _ := reader.Int64(); offset != 42 { - t.Fatalf("unexpected offset %d", offset) - } - if leader, _ := reader.Int32(); leader != -1 { - t.Fatalf("unexpected leader epoch %d", leader) - } - metaStr, _ := reader.NullableString() - if metaStr == nil || *metaStr != "meta" { - t.Fatalf("unexpected metadata %v", metaStr) - } - if perr, _ := reader.Int16(); perr != 0 { - t.Fatalf("unexpected partition error %d", perr) - } - if errCode, _ := reader.Int16(); errCode != 0 { - t.Fatalf("unexpected response error %d", errCode) - } - if reader.remaining() != 0 { - t.Fatalf("unexpected trailing bytes %d", reader.remaining()) - } -} - -func TestParseProduceResponseRoundTrip(t *testing.T) { - resp := &ProduceResponse{ - CorrelationID: 99, - Topics: []ProduceTopicResponse{ - { - Name: "orders", - Partitions: []ProducePartitionResponse{ - {Partition: 0, ErrorCode: NONE, BaseOffset: 42, LogAppendTimeMs: 1000, LogStartOffset: 5}, - {Partition: 1, ErrorCode: NOT_LEADER_OR_FOLLOWER, BaseOffset: -1}, - }, - }, - { - Name: "events", - Partitions: []ProducePartitionResponse{ - {Partition: 0, ErrorCode: NONE, BaseOffset: 100, LogAppendTimeMs: 2000, LogStartOffset: 10}, - }, - }, - }, - ThrottleMs: 7, - } - for _, version := range []int16{3, 5, 7, 8, 9, 10} { - t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) { - encoded, err := EncodeProduceResponse(resp, version) - if err != nil { - t.Fatalf("encode: %v", err) - } - parsed, err := ParseProduceResponse(encoded, version) - if err != nil { - t.Fatalf("parse: %v", err) - } - if parsed.CorrelationID != resp.CorrelationID { - t.Fatalf("correlation id: got %d want %d", parsed.CorrelationID, resp.CorrelationID) - } - if len(parsed.Topics) != len(resp.Topics) { - t.Fatalf("topic count: got %d want %d", len(parsed.Topics), len(resp.Topics)) - } - for ti, topic := range parsed.Topics { - if topic.Name != resp.Topics[ti].Name { - t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, resp.Topics[ti].Name) - } - if len(topic.Partitions) != len(resp.Topics[ti].Partitions) { - t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(resp.Topics[ti].Partitions)) - } - for pi, part := range topic.Partitions { - want := resp.Topics[ti].Partitions[pi] - if part.Partition != want.Partition || part.ErrorCode != want.ErrorCode || part.BaseOffset != want.BaseOffset { - t.Fatalf("topic[%d].part[%d]: got %+v want %+v", ti, pi, part, want) - } - } - } - if version >= 1 && parsed.ThrottleMs != resp.ThrottleMs { - t.Fatalf("throttle: got %d want %d", parsed.ThrottleMs, resp.ThrottleMs) - } - }) - } -} - -func TestEncodeProduceRequestRoundTrip(t *testing.T) { - req := &ProduceRequest{ - Acks: -1, - TimeoutMs: 5000, - Topics: []ProduceTopic{ - { - Name: "orders", - Partitions: []ProducePartition{ - {Partition: 0, Records: []byte{1, 2, 3}}, - {Partition: 1, Records: []byte{4, 5}}, - {Partition: 2, Records: []byte{6, 7, 8, 9}}, - }, - }, - { - Name: "events", - Partitions: []ProducePartition{ - {Partition: 0, Records: []byte{10}}, - {Partition: 3, Records: []byte{11, 12}}, - }, - }, - }, - } - for _, version := range []int16{3, 5, 7, 8, 9, 10} { - t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) { - header := &RequestHeader{ - APIKey: APIKeyProduce, - APIVersion: version, - CorrelationID: 77, - ClientID: strPtr("test-client"), - } - encoded, err := EncodeProduceRequest(header, req, version) - if err != nil { - t.Fatalf("encode: %v", err) - } - _, parsedReq, err := ParseRequest(encoded) - if err != nil { - t.Fatalf("parse: %v", err) - } - got, ok := parsedReq.(*ProduceRequest) - if !ok { - t.Fatalf("expected *ProduceRequest, got %T", parsedReq) - } - if len(got.Topics) != len(req.Topics) { - t.Fatalf("topic count: got %d want %d", len(got.Topics), len(req.Topics)) - } - for ti, topic := range got.Topics { - want := req.Topics[ti] - if topic.Name != want.Name { - t.Fatalf("topic[%d] name: got %q want %q", ti, topic.Name, want.Name) - } - if len(topic.Partitions) != len(want.Partitions) { - t.Fatalf("topic[%d] partition count: got %d want %d", ti, len(topic.Partitions), len(want.Partitions)) - } - for pi, part := range topic.Partitions { - wantPart := want.Partitions[pi] - if part.Partition != wantPart.Partition { - t.Fatalf("topic[%d].part[%d] index: got %d want %d", ti, pi, part.Partition, wantPart.Partition) - } - if string(part.Records) != string(wantPart.Records) { - t.Fatalf("topic[%d].part[%d] records: got %x want %x", ti, pi, part.Records, wantPart.Records) - } - } - } - }) - } -} - -func makeTestRecordBatch(count int32, baseOffset int64) []byte { - const size = 90 - data := make([]byte, size) - binary.BigEndian.PutUint64(data[0:8], uint64(baseOffset)) - binary.BigEndian.PutUint32(data[8:12], uint32(size-12)) - binary.BigEndian.PutUint32(data[23:27], uint32(count-1)) - binary.BigEndian.PutUint32(data[57:61], uint32(count)) - return data -} - -// TestGroupResponseErrorCode_RoundTrip encodes known responses via the standard -// Encode* functions, then verifies GroupResponseErrorCode extracts the error. +// TestGroupResponseErrorCode_RoundTrip encodes known responses via EncodeResponse, +// then verifies GroupResponseErrorCode extracts the correct error code. func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { tests := []struct { name string apiKey int16 apiVersion int16 - encode func() ([]byte, error) + encode func() []byte wantCode int16 }{ { name: "JoinGroup v2 NOT_COORDINATOR", apiKey: APIKeyJoinGroup, apiVersion: 2, - encode: func() ([]byte, error) { - return EncodeJoinGroupResponse(&JoinGroupResponse{ - CorrelationID: 1, - ThrottleMs: 0, - ErrorCode: NOT_COORDINATOR, - }, 2) + encode: func() []byte { + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(1, 2, resp) }, wantCode: NOT_COORDINATOR, }, { - name: "JoinGroup v5 NOT_COORDINATOR", + name: "JoinGroup v5 NOT_COORDINATOR (flexible)", apiKey: APIKeyJoinGroup, apiVersion: 5, - encode: func() ([]byte, error) { - return EncodeJoinGroupResponse(&JoinGroupResponse{ - CorrelationID: 2, - ThrottleMs: 0, - ErrorCode: NOT_COORDINATOR, - }, 5) + encode: func() []byte { + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(2, 5, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1634,27 +58,24 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "JoinGroup v2 success", apiKey: APIKeyJoinGroup, apiVersion: 2, - encode: func() ([]byte, error) { - return EncodeJoinGroupResponse(&JoinGroupResponse{ - CorrelationID: 3, - ThrottleMs: 0, - ErrorCode: 0, - ProtocolName: "range", - LeaderID: "member-1", - MemberID: "member-1", - }, 2) + encode: func() []byte { + resp := kmsg.NewPtrJoinGroupResponse() + resp.ErrorCode = NONE + resp.Protocol = kmsg.StringPtr("range") + resp.LeaderID = "member-1" + resp.MemberID = "member-1" + return EncodeResponse(3, 2, resp) }, - wantCode: 0, + wantCode: NONE, }, { name: "SyncGroup v1 NOT_COORDINATOR", apiKey: APIKeySyncGroup, apiVersion: 1, - encode: func() ([]byte, error) { - return EncodeSyncGroupResponse(&SyncGroupResponse{ - CorrelationID: 4, - ErrorCode: NOT_COORDINATOR, - }, 1) + encode: func() []byte { + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(4, 1, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1662,11 +83,10 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "SyncGroup v4 NOT_COORDINATOR (flexible)", apiKey: APIKeySyncGroup, apiVersion: 4, - encode: func() ([]byte, error) { - return EncodeSyncGroupResponse(&SyncGroupResponse{ - CorrelationID: 5, - ErrorCode: NOT_COORDINATOR, - }, 4) + encode: func() []byte { + resp := kmsg.NewPtrSyncGroupResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(5, 4, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1674,11 +94,10 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "Heartbeat v1 NOT_COORDINATOR", apiKey: APIKeyHeartbeat, apiVersion: 1, - encode: func() ([]byte, error) { - return EncodeHeartbeatResponse(&HeartbeatResponse{ - CorrelationID: 6, - ErrorCode: NOT_COORDINATOR, - }, 1) + encode: func() []byte { + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(6, 1, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1686,81 +105,74 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "Heartbeat v4 NOT_COORDINATOR (flexible)", apiKey: APIKeyHeartbeat, apiVersion: 4, - encode: func() ([]byte, error) { - return EncodeHeartbeatResponse(&HeartbeatResponse{ - CorrelationID: 7, - ErrorCode: NOT_COORDINATOR, - }, 4) + encode: func() []byte { + resp := kmsg.NewPtrHeartbeatResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(7, 4, resp) }, wantCode: NOT_COORDINATOR, }, { - name: "LeaveGroup NOT_COORDINATOR", + name: "LeaveGroup v0 NOT_COORDINATOR", apiKey: APIKeyLeaveGroup, apiVersion: 0, - encode: func() ([]byte, error) { - return EncodeLeaveGroupResponse(&LeaveGroupResponse{ - CorrelationID: 8, - ErrorCode: NOT_COORDINATOR, - }) + encode: func() []byte { + resp := kmsg.NewPtrLeaveGroupResponse() + resp.ErrorCode = NOT_COORDINATOR + return EncodeResponse(8, 0, resp) }, wantCode: NOT_COORDINATOR, }, { - name: "OffsetCommit with NOT_COORDINATOR on partition", + name: "OffsetCommit v3 partition error", apiKey: APIKeyOffsetCommit, apiVersion: 3, - encode: func() ([]byte, error) { - return EncodeOffsetCommitResponse(&OffsetCommitResponse{ - CorrelationID: 9, - Topics: []OffsetCommitTopicResponse{ - { - Name: "test-topic", - Partitions: []OffsetCommitPartitionResponse{ - {Partition: 0, ErrorCode: NOT_COORDINATOR}, - }, - }, - }, - }) + encode: func() []byte { + resp := kmsg.NewPtrOffsetCommitResponse() + topic := kmsg.NewOffsetCommitResponseTopic() + topic.Topic = "test-topic" + part := kmsg.NewOffsetCommitResponseTopicPartition() + part.Partition = 0 + part.ErrorCode = NOT_COORDINATOR + topic.Partitions = append(topic.Partitions, part) + resp.Topics = append(resp.Topics, topic) + return EncodeResponse(9, 3, resp) }, wantCode: NOT_COORDINATOR, }, { - name: "OffsetCommit success (no false positive for partition 16)", + name: "OffsetCommit v3 success (partition 16, no false positive)", apiKey: APIKeyOffsetCommit, apiVersion: 3, - encode: func() ([]byte, error) { - return EncodeOffsetCommitResponse(&OffsetCommitResponse{ - CorrelationID: 10, - Topics: []OffsetCommitTopicResponse{ - { - Name: "test-topic", - Partitions: []OffsetCommitPartitionResponse{ - {Partition: 16, ErrorCode: 0}, - }, - }, - }, - }) - }, - wantCode: 0, + encode: func() []byte { + resp := kmsg.NewPtrOffsetCommitResponse() + topic := kmsg.NewOffsetCommitResponseTopic() + topic.Topic = "test-topic" + part := kmsg.NewOffsetCommitResponseTopicPartition() + part.Partition = 16 + part.ErrorCode = NONE + topic.Partitions = append(topic.Partitions, part) + resp.Topics = append(resp.Topics, topic) + return EncodeResponse(10, 3, resp) + }, + wantCode: NONE, }, { name: "OffsetFetch v3 NOT_COORDINATOR (top-level)", apiKey: APIKeyOffsetFetch, apiVersion: 3, - encode: func() ([]byte, error) { - return EncodeOffsetFetchResponse(&OffsetFetchResponse{ - CorrelationID: 11, - Topics: []OffsetFetchTopicResponse{ - { - Name: "test-topic", - Partitions: []OffsetFetchPartitionResponse{ - {Partition: 0, Offset: -1, LeaderEpoch: -1, ErrorCode: NOT_COORDINATOR}, - }, - }, - }, - ErrorCode: NOT_COORDINATOR, - }, 3) + encode: func() []byte { + resp := kmsg.NewPtrOffsetFetchResponse() + resp.ErrorCode = NOT_COORDINATOR + topic := kmsg.NewOffsetFetchResponseTopic() + topic.Topic = "test-topic" + part := kmsg.NewOffsetFetchResponseTopicPartition() + part.Partition = 0 + part.Offset = -1 + part.ErrorCode = NOT_COORDINATOR + topic.Partitions = append(topic.Partitions, part) + resp.Topics = append(resp.Topics, topic) + return EncodeResponse(11, 3, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1768,33 +180,32 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "OffsetFetch v5 success (offset 16, no false positive)", apiKey: APIKeyOffsetFetch, apiVersion: 5, - encode: func() ([]byte, error) { - return EncodeOffsetFetchResponse(&OffsetFetchResponse{ - CorrelationID: 12, - Topics: []OffsetFetchTopicResponse{ - { - Name: "test-topic", - Partitions: []OffsetFetchPartitionResponse{ - {Partition: 0, Offset: 16, LeaderEpoch: 0, ErrorCode: 0}, - }, - }, - }, - ErrorCode: 0, - }, 5) - }, - wantCode: 0, + encode: func() []byte { + resp := kmsg.NewPtrOffsetFetchResponse() + resp.ErrorCode = NONE + topic := kmsg.NewOffsetFetchResponseTopic() + topic.Topic = "test-topic" + part := kmsg.NewOffsetFetchResponseTopicPartition() + part.Partition = 0 + part.Offset = 16 + part.ErrorCode = NONE + topic.Partitions = append(topic.Partitions, part) + resp.Topics = append(resp.Topics, topic) + return EncodeResponse(12, 5, resp) + }, + wantCode: NONE, }, { name: "DescribeGroups v5 NOT_COORDINATOR", apiKey: APIKeyDescribeGroups, apiVersion: 5, - encode: func() ([]byte, error) { - return EncodeDescribeGroupsResponse(&DescribeGroupsResponse{ - CorrelationID: 13, - Groups: []DescribeGroupsResponseGroup{ - {ErrorCode: NOT_COORDINATOR, GroupID: "my-group"}, - }, - }, 5) + encode: func() []byte { + resp := kmsg.NewPtrDescribeGroupsResponse() + group := kmsg.NewDescribeGroupsResponseGroup() + group.ErrorCode = NOT_COORDINATOR + group.Group = "my-group" + resp.Groups = append(resp.Groups, group) + return EncodeResponse(13, 5, resp) }, wantCode: NOT_COORDINATOR, }, @@ -1802,24 +213,22 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { name: "DescribeGroups v5 success", apiKey: APIKeyDescribeGroups, apiVersion: 5, - encode: func() ([]byte, error) { - return EncodeDescribeGroupsResponse(&DescribeGroupsResponse{ - CorrelationID: 14, - Groups: []DescribeGroupsResponseGroup{ - {ErrorCode: 0, GroupID: "my-group", State: "Stable"}, - }, - }, 5) + encode: func() []byte { + resp := kmsg.NewPtrDescribeGroupsResponse() + group := kmsg.NewDescribeGroupsResponseGroup() + group.ErrorCode = NONE + group.Group = "my-group" + group.State = "Stable" + resp.Groups = append(resp.Groups, group) + return EncodeResponse(14, 5, resp) }, - wantCode: 0, + wantCode: NONE, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - resp, err := tc.encode() - if err != nil { - t.Fatalf("encode: %v", err) - } + resp := tc.encode() gotCode, ok := GroupResponseErrorCode(tc.apiKey, tc.apiVersion, resp) if !ok { t.Fatalf("GroupResponseErrorCode returned ok=false for valid response") @@ -1831,162 +240,7 @@ func TestGroupResponseErrorCode_RoundTrip(t *testing.T) { } } -func TestParseFetchResponse_RoundTrip(t *testing.T) { - tests := []struct { - name string - version int16 - resp *FetchResponse - }{ - { - name: "v11 name-based", - version: 11, - resp: &FetchResponse{ - CorrelationID: 7, - ThrottleMs: 0, - ErrorCode: NONE, - SessionID: 0, - Topics: []FetchTopicResponse{ - { - Name: "orders", - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 100, - LastStableOffset: 100, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: []byte("test-records"), - }, - { - Partition: 1, - ErrorCode: NOT_LEADER_OR_FOLLOWER, - HighWatermark: 0, - LastStableOffset: 0, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: []byte{}, - }, - }, - }, - }, - }, - }, - { - name: "v13 topic-id-based", - version: 13, - resp: &FetchResponse{ - CorrelationID: 11, - ThrottleMs: 5, - ErrorCode: NONE, - SessionID: 42, - Topics: []FetchTopicResponse{ - { - TopicID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - Partitions: []FetchPartitionResponse{ - { - Partition: 0, - ErrorCode: NONE, - HighWatermark: 50, - LastStableOffset: 50, - LogStartOffset: 0, - PreferredReadReplica: -1, - RecordSet: []byte("hello"), - }, - }, - }, - }, - }, - }, - { - name: "v11 multiple topics", - version: 11, - resp: &FetchResponse{ - CorrelationID: 99, - ThrottleMs: 0, - ErrorCode: NONE, - SessionID: 0, - Topics: []FetchTopicResponse{ - { - Name: "orders", - Partitions: []FetchPartitionResponse{ - {Partition: 0, ErrorCode: NONE, HighWatermark: 10, LastStableOffset: 10, PreferredReadReplica: -1, RecordSet: []byte("a")}, - }, - }, - { - Name: "events", - Partitions: []FetchPartitionResponse{ - {Partition: 0, ErrorCode: NONE, HighWatermark: 20, LastStableOffset: 20, PreferredReadReplica: -1, RecordSet: []byte("b")}, - }, - }, - }, - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - encoded, err := EncodeFetchResponse(tc.resp, tc.version) - if err != nil { - t.Fatalf("EncodeFetchResponse: %v", err) - } - - parsed, err := ParseFetchResponse(encoded, tc.version) - if err != nil { - t.Fatalf("ParseFetchResponse: %v", err) - } - - if parsed.CorrelationID != tc.resp.CorrelationID { - t.Fatalf("CorrelationID: got %d, want %d", parsed.CorrelationID, tc.resp.CorrelationID) - } - if parsed.ThrottleMs != tc.resp.ThrottleMs { - t.Fatalf("ThrottleMs: got %d, want %d", parsed.ThrottleMs, tc.resp.ThrottleMs) - } - if parsed.ErrorCode != tc.resp.ErrorCode { - t.Fatalf("ErrorCode: got %d, want %d", parsed.ErrorCode, tc.resp.ErrorCode) - } - if parsed.SessionID != tc.resp.SessionID { - t.Fatalf("SessionID: got %d, want %d", parsed.SessionID, tc.resp.SessionID) - } - if len(parsed.Topics) != len(tc.resp.Topics) { - t.Fatalf("topic count: got %d, want %d", len(parsed.Topics), len(tc.resp.Topics)) - } - for ti, topic := range parsed.Topics { - wantTopic := tc.resp.Topics[ti] - if tc.version >= 12 { - if topic.TopicID != wantTopic.TopicID { - t.Fatalf("topic[%d] ID mismatch", ti) - } - } else { - if topic.Name != wantTopic.Name { - t.Fatalf("topic[%d] name: got %q, want %q", ti, topic.Name, wantTopic.Name) - } - } - if len(topic.Partitions) != len(wantTopic.Partitions) { - t.Fatalf("topic[%d] partition count: got %d, want %d", ti, len(topic.Partitions), len(wantTopic.Partitions)) - } - for pi, part := range topic.Partitions { - wantPart := wantTopic.Partitions[pi] - if part.Partition != wantPart.Partition { - t.Fatalf("topic[%d] part[%d]: got %d, want %d", ti, pi, part.Partition, wantPart.Partition) - } - if part.ErrorCode != wantPart.ErrorCode { - t.Fatalf("topic[%d] part[%d] error: got %d, want %d", ti, pi, part.ErrorCode, wantPart.ErrorCode) - } - if part.HighWatermark != wantPart.HighWatermark { - t.Fatalf("topic[%d] part[%d] HW: got %d, want %d", ti, pi, part.HighWatermark, wantPart.HighWatermark) - } - if string(part.RecordSet) != string(wantPart.RecordSet) { - t.Fatalf("topic[%d] part[%d] records: got %q, want %q", ti, pi, part.RecordSet, wantPart.RecordSet) - } - } - } - }) - } -} - func TestGroupResponseErrorCode_Truncated(t *testing.T) { - // A truncated response should return ok=false. _, ok := GroupResponseErrorCode(APIKeyJoinGroup, 2, []byte{0, 0, 0, 1}) if ok { t.Fatalf("expected ok=false for truncated JoinGroup response") @@ -1997,3 +251,10 @@ func TestGroupResponseErrorCode_Truncated(t *testing.T) { t.Fatalf("expected ok=false for truncated LeaveGroup response") } } + +func TestGroupResponseErrorCode_UnsupportedKey(t *testing.T) { + _, ok := GroupResponseErrorCode(APIKeyProduce, 9, []byte{0, 0, 0, 1, 0, 0, 0, 0}) + if ok { + t.Fatalf("expected ok=false for unsupported api key") + } +} diff --git a/pkg/protocol/config.go b/pkg/protocol/types.go similarity index 54% rename from pkg/protocol/config.go rename to pkg/protocol/types.go index 053de7a8..ff530981 100644 --- a/pkg/protocol/config.go +++ b/pkg/protocol/types.go @@ -15,31 +15,10 @@ package protocol -// Config resource types. -const ( - ConfigResourceTopic int8 = 2 - ConfigResourceBroker int8 = 4 -) - -// Config sources. -const ( - ConfigSourceUnknown int8 = -1 - ConfigSourceDynamicTopic int8 = 1 - ConfigSourceDynamicBroker int8 = 2 - ConfigSourceStaticBroker int8 = 4 - ConfigSourceDefaultConfig int8 = 5 - ConfigSourceGroupConfig int8 = 8 -) +import "github.com/twmb/franz-go/pkg/kmsg" -// Config types. -const ( - ConfigTypeBoolean int8 = 1 - ConfigTypeString int8 = 2 - ConfigTypeInt int8 = 3 - ConfigTypeShort int8 = 4 - ConfigTypeLong int8 = 5 - ConfigTypeDouble int8 = 6 - ConfigTypeList int8 = 7 - ConfigTypeClass int8 = 8 - ConfigTypePassword int8 = 9 +type ( + MetadataBroker = kmsg.MetadataResponseBroker + MetadataTopic = kmsg.MetadataResponseTopic + MetadataPartition = kmsg.MetadataResponseTopicPartition ) diff --git a/test/e2e/franz_test.go b/test/e2e/franz_test.go index 283545af..d71df91d 100644 --- a/test/e2e/franz_test.go +++ b/test/e2e/franz_test.go @@ -53,6 +53,13 @@ func TestFranzGoProduceConsume(t *testing.T) { fmt.Sprintf("KAFSCALE_BROKER_ADDR=%s", brokerAddr), fmt.Sprintf("KAFSCALE_METRICS_ADDR=%s", metricsAddr), fmt.Sprintf("KAFSCALE_CONTROL_ADDR=%s", controlAddr), + "KAFSCALE_S3_BUCKET="+envOrDefault("KAFSCALE_S3_BUCKET", "kafscale"), + "KAFSCALE_S3_REGION="+envOrDefault("KAFSCALE_S3_REGION", "us-east-1"), + "KAFSCALE_S3_NAMESPACE="+envOrDefault("KAFSCALE_S3_NAMESPACE", "default"), + "KAFSCALE_S3_ENDPOINT="+envOrDefault("KAFSCALE_S3_ENDPOINT", "http://127.0.0.1:9000"), + "KAFSCALE_S3_PATH_STYLE="+envOrDefault("KAFSCALE_S3_PATH_STYLE", "true"), + "KAFSCALE_S3_ACCESS_KEY="+envOrDefault("KAFSCALE_S3_ACCESS_KEY", "minioadmin"), + "KAFSCALE_S3_SECRET_KEY="+envOrDefault("KAFSCALE_S3_SECRET_KEY", "minioadmin"), ) var brokerLogs bytes.Buffer var franzLogs bytes.Buffer diff --git a/test/e2e/mcp_test.go b/test/e2e/mcp_test.go index 8e2c9e18..ecd7c3ed 100644 --- a/test/e2e/mcp_test.go +++ b/test/e2e/mcp_test.go @@ -28,11 +28,12 @@ import ( "testing" "time" - "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/KafScale/platform/internal/mcpserver" metadatapb "github.com/KafScale/platform/pkg/gen/metadata" "github.com/KafScale/platform/pkg/metadata" "github.com/KafScale/platform/pkg/protocol" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/twmb/franz-go/pkg/kmsg" ) func TestMCPServer(t *testing.T) { @@ -51,16 +52,16 @@ func TestMCPServer(t *testing.T) { ControllerID: 1, Topics: []protocol.MetadataTopic{ { - Name: "orders", + Topic: kmsg.StringPtr("orders"), Partitions: []protocol.MetadataPartition{ - {PartitionIndex: 0, LeaderID: 1}, - {PartitionIndex: 1, LeaderID: 1}, + {Partition: 0, Leader: 1}, + {Partition: 1, Leader: 1}, }, }, { - Name: "payments", + Topic: kmsg.StringPtr("payments"), Partitions: []protocol.MetadataPartition{ - {PartitionIndex: 0, LeaderID: 1}, + {Partition: 0, Leader: 1}, }, }, }, diff --git a/test/e2e/snapshot_test.go b/test/e2e/snapshot_test.go index 944756bd..146aebb8 100644 --- a/test/e2e/snapshot_test.go +++ b/test/e2e/snapshot_test.go @@ -91,7 +91,7 @@ func TestSnapshotPublishAndBrokerConsumption(t *testing.T) { if len(meta.Brokers) == 0 { t.Fatalf("expected brokers in snapshot, got none") } - if len(meta.Topics) != 1 || meta.Topics[0].Name != "orders" { + if len(meta.Topics) != 1 || *meta.Topics[0].Topic != "orders" { t.Fatalf("snapshot missing topic: %+v", meta.Topics) } if meta.ClusterID == nil || *meta.ClusterID != "cluster-uid" {