diff --git a/common/testing/protorequire/require.go b/common/testing/protorequire/require.go index 5ad0286a1f6..f973561ea8d 100644 --- a/common/testing/protorequire/require.go +++ b/common/testing/protorequire/require.go @@ -1,9 +1,14 @@ package protorequire import ( + "fmt" + + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "go.temporal.io/server/common/testing/protoassert" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/testing/protocmp" ) type helper interface { @@ -27,6 +32,17 @@ func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message) { } } +// ProtoEqualIgnoreFields compares two proto messages for equality, ignoring the specified fields on the given message +// type. Fields are specified by their proto name (snake_case). +func ProtoEqualIgnoreFields(t require.TestingT, a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) { + if th, ok := t.(helper); ok { + th.Helper() + } + if diff := cmp.Diff(a, b, protocmp.Transform(), protocmp.IgnoreFields(msgType, fields...)); diff != "" { + require.Fail(t, fmt.Sprintf("Proto mismatch (-want +got):\n%v", diff)) + } +} + func NotProtoEqual(t require.TestingT, a proto.Message, b proto.Message) { if th, ok := t.(helper); ok { th.Helper() @@ -83,3 +99,10 @@ func (x ProtoAssertions) ProtoElementsMatch(a any, b any) bool { return protoassert.ProtoElementsMatch(x.t, a, b) } + +func (x ProtoAssertions) ProtoEqualIgnoreFields(a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) { + if th, ok := x.t.(helper); ok { + th.Helper() + } + ProtoEqualIgnoreFields(x.t, a, b, msgType, fields...) +} diff --git a/common/testing/protorequire/require_test.go b/common/testing/protorequire/require_test.go new file mode 100644 index 00000000000..ece41854bb9 --- /dev/null +++ b/common/testing/protorequire/require_test.go @@ -0,0 +1,37 @@ +package protorequire_test + +import ( + "testing" + + commonpb "go.temporal.io/api/common/v1" + workflowpb "go.temporal.io/api/workflow/v1" + "go.temporal.io/server/common/testing/protorequire" +) + +const myUUID = "deb7b204-b384-4fde-85c6-e5a56c42336a" + +func TestProtoEqualIgnoreFields(t *testing.T) { + a := &workflowpb.WorkflowExecutionInfo{ + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: myUUID, + }, + Status: 1, + TaskQueue: "queue-a", + } + b := &workflowpb.WorkflowExecutionInfo{ + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: myUUID, + }, + Status: 2, + TaskQueue: "queue-b", + } + + // Should pass: both differing fields are ignored + protorequire.ProtoEqualIgnoreFields(t, a, b, + &workflowpb.WorkflowExecutionInfo{}, + "status", + "task_queue", + ) +} diff --git a/tests/standalone_activity_test.go b/tests/standalone_activity_test.go index e04cf84f26c..f2ef4203d44 100644 --- a/tests/standalone_activity_test.go +++ b/tests/standalone_activity_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" activitypb "go.temporal.io/api/activity/v1" @@ -29,7 +28,6 @@ import ( "go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/tests/testcore" "google.golang.org/grpc/codes" - "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -2866,16 +2864,13 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_NoWait() { UserMetadata: defaultUserMetadata, } - diff := cmp.Diff(expected, respInfo, - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "schedule_time", - "state_transition_count", - ), + // Ignore non-deterministic fields. Validated separately. + protorequire.ProtoEqualIgnoreFields(t, expected, respInfo, + &activitypb.ActivityExecutionInfo{}, + "execution_duration", + "schedule_time", + "state_transition_count", ) - require.Empty(t, diff) require.Equal(t, respInfo.GetExecutionDuration().AsDuration(), time.Duration(0)) // Never completed, so expect 0 require.Nil(t, describeResp.GetInfo().GetCloseTime()) require.Positive(t, respInfo.GetScheduleTime().AsTime().Unix()) @@ -2927,16 +2922,13 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING, TaskQueue: taskQueue.Name, } - diff := cmp.Diff(expected, firstDescribeResp.GetInfo(), - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "schedule_time", - "state_transition_count", - ), + // Ignore non-deterministic fields. Validated separately. + protorequire.ProtoEqualIgnoreFields(t, expected, firstDescribeResp.GetInfo(), + &activitypb.ActivityExecutionInfo{}, + "execution_duration", + "schedule_time", + "state_transition_count", ) - require.Empty(t, diff) taskQueuePollErr := make(chan error, 1) activityPollDone := make(chan struct{}) @@ -2985,17 +2977,14 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING, TaskQueue: taskQueue.Name, } - diff := cmp.Diff(expected, describeResp.GetInfo(), - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "last_started_time", - "schedule_time", - "state_transition_count", - ), + // Ignore non-deterministic fields. Validated separately. + protorequire.ProtoEqualIgnoreFields(t, expected, describeResp.GetInfo(), + &activitypb.ActivityExecutionInfo{}, + "execution_duration", + "last_started_time", + "schedule_time", + "state_transition_count", ) - require.Empty(t, diff) protorequire.ProtoEqual(t, defaultInput, describeResp.Input)