Skip to content

THRIFT-6044: Limit struct read/write recursion depth in Go library#3549

Draft
Jens-G wants to merge 1 commit into
apache:masterfrom
Jens-G:go-recursion-depth
Draft

THRIFT-6044: Limit struct read/write recursion depth in Go library#3549
Jens-G wants to merge 1 commit into
apache:masterfrom
Jens-G:go-recursion-depth

Conversation

@Jens-G
Copy link
Copy Markdown
Member

@Jens-G Jens-G commented May 28, 2026

Summary

  • Adds CheckRecursionDepth(ctx) / DecrementRecursionDepth(ctx) helpers in lib/go/thrift/recursion_tracker.go that track per-context struct nesting depth
  • Default limit is DEFAULT_RECURSION_DEPTH (64), matching the existing skip-depth constant
  • Go generator emits CheckRecursionDepth / defer DecrementRecursionDepth at the start of every generated Read and Write method
  • Returns DEPTH_LIMIT TProtocolException when the limit is exceeded

Test plan

  • TestCheckRecursionDepthShallow — allows 64 levels without error
  • TestCheckRecursionDepthExceeded — returns DEPTH_LIMIT on level 65
  • TestDecrementRecursionDepth — restores depth so further nesting is possible
  • Full go test ./... in lib/go/thrift passes

🤖 Generated with Claude Code

Co-Authored-By: Claude Sonnet 4.6 noreply@anthropic.com

Client: go

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Jens-G Jens-G requested a review from fishy as a code owner May 28, 2026 11:45
@mergeable mergeable Bot added golang Pull requests that update Go code compiler labels May 28, 2026
Copy link
Copy Markdown
Member

@fishy fishy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me. you'll need to apply this fix to fix the mock tests:

diff --git c/lib/go/test/tests/client_error_test.go i/lib/go/test/tests/client_error_test.go
index 385f6f593..ad5a0fb3b 100644
--- c/lib/go/test/tests/client_error_test.go
+++ i/lib/go/test/tests/client_error_test.go
@@ -20,7 +20,6 @@
 package tests
 
 import (
-	"context"
 	"errors"
 	"testing"
 
@@ -55,84 +54,84 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
 	if failAt == 0 {
 		err = failWith
 	}
-	last := protocol.EXPECT().WriteMessageBegin(context.Background(), "testStruct", thrift.CALL, int32(1)).Return(err)
+	last := protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testStruct", thrift.CALL, int32(1)).Return(err)
 	if failAt == 0 {
 		return true
 	}
 	if failAt == 1 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteStructBegin(context.Background(), "testStruct_args").Return(err).After(last)
+	last = protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testStruct_args").Return(err).After(last)
 	if failAt == 1 {
 		return true
 	}
 	if failAt == 2 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "thing", thrift.TType(thrift.STRUCT), int16(1)).Return(err).After(last)
 	if failAt == 2 {
 		return true
 	}
 	if failAt == 3 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteStructBegin(context.Background(), "TestStruct").Return(err).After(last)
+	last = protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "TestStruct").Return(err).After(last)
 	if failAt == 3 {
 		return true
 	}
 	if failAt == 4 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "m", thrift.TType(thrift.MAP), int16(1)).Return(err).After(last)
 	if failAt == 4 {
 		return true
 	}
 	if failAt == 5 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteMapBegin(context.Background(), thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last)
+	last = protocol.EXPECT().WriteMapBegin(ctxMatcher{}, thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0).Return(err).After(last)
 	if failAt == 5 {
 		return true
 	}
 	if failAt == 6 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteMapEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteMapEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 6 {
 		return true
 	}
 	if failAt == 7 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 7 {
 		return true
 	}
 	if failAt == 8 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "l", thrift.TType(thrift.LIST), int16(2)).Return(err).After(last)
 	if failAt == 8 {
 		return true
 	}
 	if failAt == 9 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteListBegin(context.Background(), thrift.TType(thrift.STRING), 0).Return(err).After(last)
+	last = protocol.EXPECT().WriteListBegin(ctxMatcher{}, thrift.TType(thrift.STRING), 0).Return(err).After(last)
 	if failAt == 9 {
 		return true
 	}
 	if failAt == 10 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteListEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteListEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 10 {
 		return true
 	}
 	if failAt == 11 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 11 {
 		return true
 	}
@@ -140,266 +139,266 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
 		err = failWith
 	}
 
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.SET), int16(3)).Return(err).After(last)
 	if failAt == 12 {
 		return true
 	}
 	if failAt == 13 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteSetBegin(context.Background(), thrift.TType(thrift.STRING), 0).Return(err).After(last)
+	last = protocol.EXPECT().WriteSetBegin(ctxMatcher{}, thrift.TType(thrift.STRING), 0).Return(err).After(last)
 	if failAt == 13 {
 		return true
 	}
 	if failAt == 14 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteSetEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteSetEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 14 {
 		return true
 	}
 	if failAt == 15 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 15 {
 		return true
 	}
 	if failAt == 16 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "i", thrift.TType(thrift.I32), int16(4)).Return(err).After(last)
 	if failAt == 16 {
 		return true
 	}
 	if failAt == 17 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteI32(context.Background(), int32(3)).Return(err).After(last)
+	last = protocol.EXPECT().WriteI32(ctxMatcher{}, int32(3)).Return(err).After(last)
 	if failAt == 17 {
 		return true
 	}
 	if failAt == 18 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 18 {
 		return true
 	}
 	if failAt == 19 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldStop(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldStop(ctxMatcher{}).Return(err).After(last)
 	if failAt == 19 {
 		return true
 	}
 	if failAt == 20 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteStructEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteStructEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 20 {
 		return true
 	}
 	if failAt == 21 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 21 {
 		return true
 	}
 	if failAt == 22 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteFieldStop(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteFieldStop(ctxMatcher{}).Return(err).After(last)
 	if failAt == 22 {
 		return true
 	}
 	if failAt == 23 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteStructEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteStructEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 23 {
 		return true
 	}
 	if failAt == 24 {
 		err = failWith
 	}
-	last = protocol.EXPECT().WriteMessageEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().WriteMessageEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 24 {
 		return true
 	}
 	if failAt == 25 {
 		err = failWith
 	}
-	last = protocol.EXPECT().Flush(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().Flush(ctxMatcher{}).Return(err).After(last)
 	if failAt == 25 {
 		return true
 	}
 	if failAt == 26 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testStruct", thrift.REPLY, int32(1), err).After(last)
+	last = protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testStruct", thrift.REPLY, int32(1), err).After(last)
 	if failAt == 26 {
 		return true
 	}
 	if failAt == 27 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("testStruct_args", err).After(last)
+	last = protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("testStruct_args", err).After(last)
 	if failAt == 27 {
 		return true
 	}
 	if failAt == 28 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STRUCT), int16(0), err).After(last)
 	if failAt == 28 {
 		return true
 	}
 	if failAt == 29 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("TestStruct", err).After(last)
+	last = protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("TestStruct", err).After(last)
 	if failAt == 29 {
 		return true
 	}
 	if failAt == 30 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("m", thrift.TType(thrift.MAP), int16(1), err).After(last)
 	if failAt == 30 {
 		return true
 	}
 	if failAt == 31 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadMapBegin(context.Background()).Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last)
+	last = protocol.EXPECT().ReadMapBegin(ctxMatcher{}).Return(thrift.TType(thrift.STRING), thrift.TType(thrift.STRING), 0, err).After(last)
 	if failAt == 31 {
 		return true
 	}
 	if failAt == 32 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadMapEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadMapEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 32 {
 		return true
 	}
 	if failAt == 33 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 33 {
 		return true
 	}
 	if failAt == 34 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("l", thrift.TType(thrift.LIST), int16(2), err).After(last)
 	if failAt == 34 {
 		return true
 	}
 	if failAt == 35 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadListBegin(context.Background()).Return(thrift.TType(thrift.STRING), 0, err).After(last)
+	last = protocol.EXPECT().ReadListBegin(ctxMatcher{}).Return(thrift.TType(thrift.STRING), 0, err).After(last)
 	if failAt == 35 {
 		return true
 	}
 	if failAt == 36 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadListEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadListEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 36 {
 		return true
 	}
 	if failAt == 37 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 37 {
 		return true
 	}
 	if failAt == 38 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("s", thrift.TType(thrift.SET), int16(3), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("s", thrift.TType(thrift.SET), int16(3), err).After(last)
 	if failAt == 38 {
 		return true
 	}
 	if failAt == 39 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadSetBegin(context.Background()).Return(thrift.TType(thrift.STRING), 0, err).After(last)
+	last = protocol.EXPECT().ReadSetBegin(ctxMatcher{}).Return(thrift.TType(thrift.STRING), 0, err).After(last)
 	if failAt == 39 {
 		return true
 	}
 	if failAt == 40 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadSetEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadSetEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 40 {
 		return true
 	}
 	if failAt == 41 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 41 {
 		return true
 	}
 	if failAt == 42 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("i", thrift.TType(thrift.I32), int16(4), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("i", thrift.TType(thrift.I32), int16(4), err).After(last)
 	if failAt == 42 {
 		return true
 	}
 	if failAt == 43 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadI32(context.Background()).Return(int32(3), err).After(last)
+	last = protocol.EXPECT().ReadI32(ctxMatcher{}).Return(int32(3), err).After(last)
 	if failAt == 43 {
 		return true
 	}
 	if failAt == 44 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 44 {
 		return true
 	}
 	if failAt == 45 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(5), err).After(last)
 	if failAt == 45 {
 		return true
 	}
 	if failAt == 46 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 46 {
 		return true
 	}
 	if failAt == 47 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 47 {
 		return true
 	}
 	if failAt == 48 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(1), err).After(last)
 	if failAt == 48 {
 		return true
 	}
 	if failAt == 49 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 49 {
 		return true
 	}
@@ -407,7 +406,7 @@ func prepareClientCallReply(protocol *MockTProtocol, failAt int, failWith error)
 		err = failWith
 	}
 	//lint:ignore SA4006 to keep it consistent with other checks above
-	last = protocol.EXPECT().ReadMessageEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadMessageEnd(ctxMatcher{}).Return(err).After(last)
 	//lint:ignore S1008 to keep it consistent with other checks above
 	if failAt == 50 {
 		return true
@@ -549,84 +548,84 @@ func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith er
 	var err error = nil
 
 	// No need to test failure in this block, because it is covered in other test cases
-	last := protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1))
-	last = protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args").After(last)
-	last = protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)).After(last)
-	last = protocol.EXPECT().WriteString(context.Background(), "test").After(last)
-	last = protocol.EXPECT().WriteFieldEnd(context.Background()).After(last)
-	last = protocol.EXPECT().WriteFieldStop(context.Background()).After(last)
-	last = protocol.EXPECT().WriteStructEnd(context.Background()).After(last)
-	last = protocol.EXPECT().WriteMessageEnd(context.Background()).After(last)
-	last = protocol.EXPECT().Flush(context.Background()).After(last)
+	last := protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1))
+	last = protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args").After(last)
+	last = protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)).After(last)
+	last = protocol.EXPECT().WriteString(ctxMatcher{}, "test").After(last)
+	last = protocol.EXPECT().WriteFieldEnd(ctxMatcher{}).After(last)
+	last = protocol.EXPECT().WriteFieldStop(ctxMatcher{}).After(last)
+	last = protocol.EXPECT().WriteStructEnd(ctxMatcher{}).After(last)
+	last = protocol.EXPECT().WriteMessageEnd(ctxMatcher{}).After(last)
+	last = protocol.EXPECT().Flush(ctxMatcher{}).After(last)
 
 	// Reading the exception, might fail.
 	if failAt == 0 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
+	last = protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testString", thrift.EXCEPTION, int32(1), err).After(last)
 	if failAt == 0 {
 		return true
 	}
 	if failAt == 1 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructBegin(context.Background()).Return("TApplicationException", err).After(last)
+	last = protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("TApplicationException", err).After(last)
 	if failAt == 1 {
 		return true
 	}
 	if failAt == 2 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("message", thrift.TType(thrift.STRING), int16(1), err).After(last)
 	if failAt == 2 {
 		return true
 	}
 	if failAt == 3 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadString(context.Background()).Return("test", err).After(last)
+	last = protocol.EXPECT().ReadString(ctxMatcher{}).Return("test", err).After(last)
 	if failAt == 3 {
 		return true
 	}
 	if failAt == 4 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 4 {
 		return true
 	}
 	if failAt == 5 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("type", thrift.TType(thrift.I32), int16(2), err).After(last)
 	if failAt == 5 {
 		return true
 	}
 	if failAt == 6 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadI32(context.Background()).Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
+	last = protocol.EXPECT().ReadI32(ctxMatcher{}).Return(int32(thrift.PROTOCOL_ERROR), err).After(last)
 	if failAt == 6 {
 		return true
 	}
 	if failAt == 7 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 7 {
 		return true
 	}
 	if failAt == 8 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
+	last = protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(2), err).After(last)
 	if failAt == 8 {
 		return true
 	}
 	if failAt == 9 {
 		err = failWith
 	}
-	last = protocol.EXPECT().ReadStructEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(err).After(last)
 	if failAt == 9 {
 		return true
 	}
@@ -634,7 +633,7 @@ func prepareClientCallException(protocol *MockTProtocol, failAt int, failWith er
 		err = failWith
 	}
 	//lint:ignore SA4006 to keep it consistent with other checks above
-	last = protocol.EXPECT().ReadMessageEnd(context.Background()).Return(err).After(last)
+	last = protocol.EXPECT().ReadMessageEnd(ctxMatcher{}).Return(err).After(last)
 	//lint:ignore S1008 to keep it consistent with other checks above
 	if failAt == 10 {
 		return true
@@ -719,16 +718,16 @@ func TestClientSeqIdMismatch(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.REPLY, int32(2), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testString", thrift.REPLY, int32(2), nil),
 	)
 
 	client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -750,16 +749,16 @@ func TestClientSeqIdMismatchLegeacy(t *testing.T) {
 	transport := thrift.NewTMemoryBuffer()
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.REPLY, int32(2), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testString", thrift.REPLY, int32(2), nil),
 	)
 
 	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
@@ -779,16 +778,16 @@ func TestClientWrongMethodName(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("unknown", thrift.REPLY, int32(1), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("unknown", thrift.REPLY, int32(1), nil),
 	)
 
 	client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -810,16 +809,16 @@ func TestClientWrongMethodNameLegacy(t *testing.T) {
 	transport := thrift.NewTMemoryBuffer()
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("unknown", thrift.REPLY, int32(1), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("unknown", thrift.REPLY, int32(1), nil),
 	)
 
 	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
@@ -839,16 +838,16 @@ func TestClientWrongMessageType(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
 	)
 
 	client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
@@ -870,16 +869,16 @@ func TestClientWrongMessageTypeLegacy(t *testing.T) {
 	transport := thrift.NewTMemoryBuffer()
 	protocol := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		protocol.EXPECT().WriteMessageBegin(context.Background(), "testString", thrift.CALL, int32(1)),
-		protocol.EXPECT().WriteStructBegin(context.Background(), "testString_args"),
-		protocol.EXPECT().WriteFieldBegin(context.Background(), "s", thrift.TType(thrift.STRING), int16(1)),
-		protocol.EXPECT().WriteString(context.Background(), "test"),
-		protocol.EXPECT().WriteFieldEnd(context.Background()),
-		protocol.EXPECT().WriteFieldStop(context.Background()),
-		protocol.EXPECT().WriteStructEnd(context.Background()),
-		protocol.EXPECT().WriteMessageEnd(context.Background()),
-		protocol.EXPECT().Flush(context.Background()),
-		protocol.EXPECT().ReadMessageBegin(context.Background()).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+		protocol.EXPECT().WriteMessageBegin(ctxMatcher{}, "testString", thrift.CALL, int32(1)),
+		protocol.EXPECT().WriteStructBegin(ctxMatcher{}, "testString_args"),
+		protocol.EXPECT().WriteFieldBegin(ctxMatcher{}, "s", thrift.TType(thrift.STRING), int16(1)),
+		protocol.EXPECT().WriteString(ctxMatcher{}, "test"),
+		protocol.EXPECT().WriteFieldEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteFieldStop(ctxMatcher{}),
+		protocol.EXPECT().WriteStructEnd(ctxMatcher{}),
+		protocol.EXPECT().WriteMessageEnd(ctxMatcher{}),
+		protocol.EXPECT().Flush(ctxMatcher{}),
+		protocol.EXPECT().ReadMessageBegin(ctxMatcher{}).Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
 	)
 
 	client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
diff --git c/lib/go/test/tests/optional_fields_test.go i/lib/go/test/tests/optional_fields_test.go
index 0d8e7396b..25ed15885 100644
--- c/lib/go/test/tests/optional_fields_test.go
+++ i/lib/go/test/tests/optional_fields_test.go
@@ -188,9 +188,9 @@ func TestNoOptionalUnsetFieldsOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.Write(context.Background(), proto)
@@ -201,9 +201,9 @@ func TestNoSetToDefaultFieldsOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.I = 42
@@ -216,12 +216,12 @@ func TestOneISetFieldOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldBegin(context.Background(), "i", thrift.TType(thrift.I64), int16(2)).Return(nil),
-		proto.EXPECT().WriteI64(context.Background(), int64(123)).Return(nil),
-		proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldBegin(ctxMatcher{}, "i", thrift.TType(thrift.I64), int16(2)).Return(nil),
+		proto.EXPECT().WriteI64(ctxMatcher{}, int64(123)).Return(nil),
+		proto.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.I = 123
@@ -233,15 +233,15 @@ func TestOneLSetFieldOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldBegin(context.Background(), "l", thrift.TType(thrift.LIST), int16(9)).Return(nil),
-		proto.EXPECT().WriteListBegin(context.Background(), thrift.TType(thrift.I64), 2).Return(nil),
-		proto.EXPECT().WriteI64(context.Background(), int64(1)).Return(nil),
-		proto.EXPECT().WriteI64(context.Background(), int64(2)).Return(nil),
-		proto.EXPECT().WriteListEnd(context.Background()).Return(nil),
-		proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldBegin(ctxMatcher{}, "l", thrift.TType(thrift.LIST), int16(9)).Return(nil),
+		proto.EXPECT().WriteListBegin(ctxMatcher{}, thrift.TType(thrift.I64), 2).Return(nil),
+		proto.EXPECT().WriteI64(ctxMatcher{}, int64(1)).Return(nil),
+		proto.EXPECT().WriteI64(ctxMatcher{}, int64(2)).Return(nil),
+		proto.EXPECT().WriteListEnd(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.L = []int64{1, 2}
@@ -253,12 +253,12 @@ func TestOneBinSetFieldOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldBegin(context.Background(), "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
-		proto.EXPECT().WriteBinary(context.Background(), []byte("somebytestring")).Return(nil),
-		proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldBegin(ctxMatcher{}, "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
+		proto.EXPECT().WriteBinary(ctxMatcher{}, []byte("somebytestring")).Return(nil),
+		proto.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.Bin = []byte("somebytestring")
@@ -270,12 +270,12 @@ func TestOneEmptyBinSetFieldOnWire(t *testing.T) {
 	defer mockCtrl.Finish()
 	proto := NewMockTProtocol(mockCtrl)
 	gomock.InOrder(
-		proto.EXPECT().WriteStructBegin(context.Background(), "all_optional").Return(nil),
-		proto.EXPECT().WriteFieldBegin(context.Background(), "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
-		proto.EXPECT().WriteBinary(context.Background(), []byte{}).Return(nil),
-		proto.EXPECT().WriteFieldEnd(context.Background()).Return(nil),
-		proto.EXPECT().WriteFieldStop(context.Background()).Return(nil),
-		proto.EXPECT().WriteStructEnd(context.Background()).Return(nil),
+		proto.EXPECT().WriteStructBegin(ctxMatcher{}, "all_optional").Return(nil),
+		proto.EXPECT().WriteFieldBegin(ctxMatcher{}, "bin", thrift.TType(thrift.STRING), int16(13)).Return(nil),
+		proto.EXPECT().WriteBinary(ctxMatcher{}, []byte{}).Return(nil),
+		proto.EXPECT().WriteFieldEnd(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteFieldStop(ctxMatcher{}).Return(nil),
+		proto.EXPECT().WriteStructEnd(ctxMatcher{}).Return(nil),
 	)
 	ao := optionalfieldstest.NewAllOptional()
 	ao.Bin = []byte{}
diff --git c/lib/go/test/tests/protocol_mock_helper.go i/lib/go/test/tests/protocol_mock_helper.go
new file mode 100644
index 000000000..19b9264ce
--- /dev/null
+++ i/lib/go/test/tests/protocol_mock_helper.go
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package tests
+
+import (
+	"context"
+
+	gomock "github.com/golang/mock/gomock"
+)
+
+// ctxMatcher implements gomock.Matcher to match any context.Context value.
+type ctxMatcher struct{}
+
+var _ gomock.Matcher = ctxMatcher{}
+
+func (ctxMatcher) Matches(x any) bool {
+	_, ok := x.(context.Context)
+	return ok
+}
+
+func (ctxMatcher) String() string {
+	return "any context.Context value"
+}
diff --git c/lib/go/test/tests/required_fields_test.go i/lib/go/test/tests/required_fields_test.go
index da80f9be1..4d7d499a9 100644
--- c/lib/go/test/tests/required_fields_test.go
+++ i/lib/go/test/tests/required_fields_test.go
@@ -68,9 +68,9 @@ func TestStructReadRequiredFields(t *testing.T) {
 
 	// None of required fields are set
 	gomock.InOrder(
-		protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
-		protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
+		protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(nil),
 	)
 
 	err := testStruct.Read(context.Background(), protocol)
@@ -89,12 +89,12 @@ func TestStructReadRequiredFields(t *testing.T) {
 
 	// One of the required fields is set
 	gomock.InOrder(
-		protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("I", thrift.TType(thrift.I32), int16(2), nil),
-		protocol.EXPECT().ReadI32(context.Background()).Return(int32(1), nil),
-		protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
-		protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
+		protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("I", thrift.TType(thrift.I32), int16(2), nil),
+		protocol.EXPECT().ReadI32(ctxMatcher{}).Return(int32(1), nil),
+		protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(nil),
 	)
 
 	err = testStruct.Read(context.Background(), protocol)
@@ -113,15 +113,15 @@ func TestStructReadRequiredFields(t *testing.T) {
 
 	// Both of the required fields are set
 	gomock.InOrder(
-		protocol.EXPECT().ReadStructBegin(context.Background()).Return("StructC", nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("i", thrift.TType(thrift.I32), int16(2), nil),
-		protocol.EXPECT().ReadI32(context.Background()).Return(int32(1), nil),
-		protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
-		protocol.EXPECT().ReadString(context.Background()).Return("test", nil),
-		protocol.EXPECT().ReadFieldEnd(context.Background()).Return(nil),
-		protocol.EXPECT().ReadFieldBegin(context.Background()).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
-		protocol.EXPECT().ReadStructEnd(context.Background()).Return(nil),
+		protocol.EXPECT().ReadStructBegin(ctxMatcher{}).Return("StructC", nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("i", thrift.TType(thrift.I32), int16(2), nil),
+		protocol.EXPECT().ReadI32(ctxMatcher{}).Return(int32(1), nil),
+		protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("s2", thrift.TType(thrift.STRING), int16(4), nil),
+		protocol.EXPECT().ReadString(ctxMatcher{}).Return("test", nil),
+		protocol.EXPECT().ReadFieldEnd(ctxMatcher{}).Return(nil),
+		protocol.EXPECT().ReadFieldBegin(ctxMatcher{}).Return("_", thrift.TType(thrift.STOP), int16(1), nil),
+		protocol.EXPECT().ReadStructEnd(ctxMatcher{}).Return(nil),
 	)
 
 	err = testStruct.Read(context.Background(), protocol)

@Jens-G Jens-G marked this pull request as draft May 28, 2026 22:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

compiler golang Pull requests that update Go code

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants