From f55c8aadd87e1350a517a3dcef0e2a08c61dc060 Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Mon, 13 Apr 2026 18:55:54 -0700 Subject: [PATCH] add protorequire.NonZero --- common/testing/protorequire/require.go | 34 ++++- common/testing/protorequire/sentinel.go | 137 ++++++++++++++++++ common/testing/protorequire/sentinel_test.go | 139 +++++++++++++++++++ tests/standalone_activity_test.go | 45 ++---- 4 files changed, 316 insertions(+), 39 deletions(-) create mode 100644 common/testing/protorequire/sentinel.go create mode 100644 common/testing/protorequire/sentinel_test.go diff --git a/common/testing/protorequire/require.go b/common/testing/protorequire/require.go index 5ad0286a1f..d59d3656ea 100644 --- a/common/testing/protorequire/require.go +++ b/common/testing/protorequire/require.go @@ -1,6 +1,8 @@ package protorequire import ( + "fmt" + "github.com/stretchr/testify/require" "go.temporal.io/server/common/testing/protoassert" "google.golang.org/protobuf/proto" @@ -22,7 +24,33 @@ func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message) { if th, ok := t.(helper); ok { th.Helper() } - if !protoassert.ProtoEqual(t, a, b) { + + // Find sentinels on the original message (pointer identity requires the + // original, not a clone). + paths := findSentinelPaths(a.ProtoReflect()) + if len(paths) == 0 { + if !protoassert.ProtoEqual(t, a, b) { + t.FailNow() + } + return + } + + // Validate sentinel fields against the actual message. + for _, p := range paths { + if !validateSentinel(b, p) { + require.Fail(t, fmt.Sprintf("expected field %q to be non-zero", p)) + } + } + + // Clone both messages, clear sentinel fields, then diff. + cleanA := proto.Clone(a) + cleanB := proto.Clone(b) + for _, p := range paths { + clearField(cleanA, p) + clearField(cleanB, p) + } + + if !protoassert.ProtoEqual(t, cleanA, cleanB) { t.FailNow() } } @@ -53,9 +81,7 @@ func (x ProtoAssertions) ProtoEqual(a proto.Message, b proto.Message) { if th, ok := x.t.(helper); ok { th.Helper() } - if !protoassert.ProtoEqual(x.t, a, b) { - x.t.FailNow() - } + ProtoEqual(x.t, a, b) } func (x ProtoAssertions) NotProtoEqual(a proto.Message, b proto.Message) { diff --git a/common/testing/protorequire/sentinel.go b/common/testing/protorequire/sentinel.go new file mode 100644 index 0000000000..900af48732 --- /dev/null +++ b/common/testing/protorequire/sentinel.go @@ -0,0 +1,137 @@ +package protorequire + +import ( + "strings" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// Sentinel values — strings use a distinctive value unlikely to appear in real +// data; messages are identified by pointer address; integers use magic values. +var ( + nonZeroString = "__protorequire_NonZero_sentinel__" + nonZeroTimestamp = ×tamppb.Timestamp{} + nonZeroDuration = &durationpb.Duration{} + // Arbitrary large negatives — not powers of 2 or boundary values, + // so they won't collide with real test data or edge-case tests. + nonZeroInt32 = int32(-479_001_599) + nonZeroInt64 = int64(-6_279_541_637_813) +) + +// NonZero returns a sentinel value that can be used as a field value in an +// expected proto message. When [ProtoEqual] encounters this sentinel, it +// asserts that the corresponding field in the actual message is present and +// non-zero, rather than checking for exact equality. +// +// Example: +// +// expected := &pb.MyMessage{ +// Name: "exact match", +// CreatedAt: protorequire.NonZero[*timestamppb.Timestamp](), +// ExternalId: protorequire.NonZero[string](), +// Count: protorequire.NonZero[int64](), +// Duration: protorequire.NonZero[*durationpb.Duration](), +// } +// protorequire.ProtoEqual(t, expected, actual) +func NonZero[T string | *timestamppb.Timestamp | *durationpb.Duration | int32 | int64]() T { + var zero T + switch any(zero).(type) { + case string: + return any(nonZeroString).(T) + case *timestamppb.Timestamp: + return any(nonZeroTimestamp).(T) + case *durationpb.Duration: + return any(nonZeroDuration).(T) + case int32: + return any(nonZeroInt32).(T) + case int64: + return any(nonZeroInt64).(T) + } + panic("unreachable") +} + +func isMessageSentinel(m protoreflect.Message) bool { + switch v := m.Interface().(type) { + case *timestamppb.Timestamp: + return v == nonZeroTimestamp + case *durationpb.Duration: + return v == nonZeroDuration + } + return false +} + +// fieldPath is a chain of field descriptors from root to the sentinel field. +type fieldPath []protoreflect.FieldDescriptor + +func (p fieldPath) String() string { + var b strings.Builder + for i, fd := range p { + if i > 0 { + b.WriteByte('.') + } + b.WriteString(string(fd.Name())) + } + return b.String() +} + +// findSentinelPaths walks the original (uncloned) message to find sentinel +// values using pointer identity (for strings and messages) or magic values +// (for integers). Returns the path to each sentinel field. +func findSentinelPaths(msg protoreflect.Message) []fieldPath { + return findSentinelPathsRecurse(msg, nil) +} + +func findSentinelPathsRecurse(msg protoreflect.Message, prefix fieldPath) []fieldPath { + var paths []fieldPath + msg.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + current := append(append(fieldPath(nil), prefix...), fd) + switch fd.Kind() { + case protoreflect.StringKind: + if v.String() == nonZeroString { + paths = append(paths, current) + } + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + if int32(v.Int()) == nonZeroInt32 { + paths = append(paths, current) + } + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + if v.Int() == nonZeroInt64 { + paths = append(paths, current) + } + case protoreflect.MessageKind: + if isMessageSentinel(v.Message()) { + paths = append(paths, current) + } else { + paths = append(paths, findSentinelPathsRecurse(v.Message(), current)...) + } + } + return true + }) + return paths +} + +// validateSentinel checks that the actual message has the field at path set +// to a non-zero value. Uses read-only traversal to avoid mutating actual. +func validateSentinel(actual proto.Message, path fieldPath) bool { + msg := actual.ProtoReflect() + for _, fd := range path[:len(path)-1] { + if !msg.Has(fd) { + return false + } + msg = msg.Get(fd).Message() + } + return msg.Has(path[len(path)-1]) +} + +// clearField clears the leaf field at path in the given message. +// Uses Mutable to traverse — only safe on cloned messages. +func clearField(msg proto.Message, path fieldPath) { + m := msg.ProtoReflect() + for _, fd := range path[:len(path)-1] { + m = m.Mutable(fd).Message() + } + m.Clear(path[len(path)-1]) +} diff --git a/common/testing/protorequire/sentinel_test.go b/common/testing/protorequire/sentinel_test.go new file mode 100644 index 0000000000..8d2628b2b7 --- /dev/null +++ b/common/testing/protorequire/sentinel_test.go @@ -0,0 +1,139 @@ +package protorequire_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + workflowpb "go.temporal.io/api/workflow/v1" + "go.temporal.io/server/common/testing/protorequire" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestNonZero_String(t *testing.T) { + expected := &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: protorequire.NonZero[string](), + } + + t.Run("passes when set", func(t *testing.T) { + actual := &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: "some-run-id", + } + protorequire.ProtoEqual(t, expected, actual) + }) + + t.Run("fails when empty", func(t *testing.T) { + mockT := &mockTestingT{} + actual := &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: "", + } + protorequire.ProtoEqual(mockT, expected, actual) + require.True(t, mockT.failed, "expected ProtoEqual to fail for empty string field") + }) +} + +func TestNonZero_Timestamp(t *testing.T) { + expected := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + StartTime: protorequire.NonZero[*timestamppb.Timestamp](), + } + + t.Run("passes when set", func(t *testing.T) { + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + StartTime: timestamppb.Now(), + } + protorequire.ProtoEqual(t, expected, actual) + }) + + t.Run("fails when nil", func(t *testing.T) { + mockT := &mockTestingT{} + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + } + protorequire.ProtoEqual(mockT, expected, actual) + require.True(t, mockT.failed, "expected ProtoEqual to fail for nil timestamp field") + }) +} + +func TestNonZero_Duration(t *testing.T) { + expected := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + ExecutionDuration: protorequire.NonZero[*durationpb.Duration](), + } + + t.Run("passes when set", func(t *testing.T) { + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + ExecutionDuration: durationpb.New(5000), + } + protorequire.ProtoEqual(t, expected, actual) + }) + + t.Run("fails when nil", func(t *testing.T) { + mockT := &mockTestingT{} + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + } + protorequire.ProtoEqual(mockT, expected, actual) + require.True(t, mockT.failed, "expected ProtoEqual to fail for nil duration field") + }) +} + +func TestNonZero_Int64(t *testing.T) { + expected := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + StateTransitionCount: protorequire.NonZero[int64](), + } + + t.Run("passes when non-zero", func(t *testing.T) { + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + StateTransitionCount: 42, + } + protorequire.ProtoEqual(t, expected, actual) + }) + + t.Run("fails when zero", func(t *testing.T) { + mockT := &mockTestingT{} + actual := &workflowpb.WorkflowExecutionInfo{ + TaskQueue: "queue-a", + } + protorequire.ProtoEqual(mockT, expected, actual) + require.True(t, mockT.failed, "expected ProtoEqual to fail for zero int64 field") + }) +} + +func TestNonZero_Mixed(t *testing.T) { + expected := &workflowpb.WorkflowExecutionInfo{ + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: protorequire.NonZero[string](), + }, + TaskQueue: protorequire.NonZero[string](), + StartTime: protorequire.NonZero[*timestamppb.Timestamp](), + StateTransitionCount: protorequire.NonZero[int64](), + } + actual := &workflowpb.WorkflowExecutionInfo{ + Execution: &commonpb.WorkflowExecution{ + WorkflowId: "wf-1", + RunId: "run-abc", + }, + TaskQueue: "queue-a", + StartTime: timestamppb.Now(), + StateTransitionCount: 3, + } + protorequire.ProtoEqual(t, expected, actual) +} + +// mockTestingT captures failures without stopping the test. +type mockTestingT struct { + failed bool +} + +func (m *mockTestingT) Errorf(string, ...interface{}) { m.failed = true } +func (m *mockTestingT) FailNow() { m.failed = true } diff --git a/tests/standalone_activity_test.go b/tests/standalone_activity_test.go index e04cf84f26..61d57dcb41 100644 --- a/tests/standalone_activity_test.go +++ b/tests/standalone_activity_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" activitypb "go.temporal.io/api/activity/v1" @@ -29,7 +28,6 @@ import ( "go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/tests/testcore" "google.golang.org/grpc/codes" - "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -2857,29 +2855,20 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_NoWait() { RunId: startResp.RunId, RunState: enumspb.PENDING_ACTIVITY_STATE_SCHEDULED, Priority: startReq.GetPriority(), + ScheduleTime: protorequire.NonZero[*timestamppb.Timestamp](), ScheduleToCloseTimeout: startReq.GetScheduleToCloseTimeout(), ScheduleToStartTimeout: startReq.GetScheduleToStartTimeout(), StartToCloseTimeout: startReq.GetStartToCloseTimeout(), + StateTransitionCount: protorequire.NonZero[int64](), Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING, SearchAttributes: defaultSearchAttributes, TaskQueue: taskQueue.Name, UserMetadata: defaultUserMetadata, } - diff := cmp.Diff(expected, respInfo, - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "schedule_time", - "state_transition_count", - ), - ) - require.Empty(t, diff) + protorequire.ProtoEqual(t, expected, respInfo) require.Equal(t, respInfo.GetExecutionDuration().AsDuration(), time.Duration(0)) // Never completed, so expect 0 require.Nil(t, describeResp.GetInfo().GetCloseTime()) - require.Positive(t, respInfo.GetScheduleTime().AsTime().Unix()) - require.Positive(t, respInfo.GetStateTransitionCount()) protorequire.ProtoEqual(t, defaultInput, describeResp.Input) @@ -2920,23 +2909,16 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState RetryPolicy: defaultRetryPolicy, RunId: startResp.RunId, RunState: enumspb.PENDING_ACTIVITY_STATE_SCHEDULED, + ScheduleTime: protorequire.NonZero[*timestamppb.Timestamp](), SearchAttributes: &commonpb.SearchAttributes{}, ScheduleToCloseTimeout: durationpb.New(0), ScheduleToStartTimeout: durationpb.New(0), StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + StateTransitionCount: protorequire.NonZero[int64](), Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING, TaskQueue: taskQueue.Name, } - diff := cmp.Diff(expected, firstDescribeResp.GetInfo(), - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "schedule_time", - "state_transition_count", - ), - ) - require.Empty(t, diff) + protorequire.ProtoEqual(t, expected, firstDescribeResp.GetInfo()) taskQueuePollErr := make(chan error, 1) activityPollDone := make(chan struct{}) @@ -2974,28 +2956,21 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState ActivityType: s.tv.ActivityType(), Attempt: 1, HeartbeatTimeout: durationpb.New(0), + LastStartedTime: protorequire.NonZero[*timestamppb.Timestamp](), LastWorkerIdentity: defaultIdentity, RetryPolicy: defaultRetryPolicy, RunId: startResp.RunId, RunState: enumspb.PENDING_ACTIVITY_STATE_STARTED, + ScheduleTime: protorequire.NonZero[*timestamppb.Timestamp](), ScheduleToCloseTimeout: durationpb.New(0), ScheduleToStartTimeout: durationpb.New(0), SearchAttributes: &commonpb.SearchAttributes{}, StartToCloseTimeout: durationpb.New(defaultStartToCloseTimeout), + StateTransitionCount: protorequire.NonZero[int64](), Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING, TaskQueue: taskQueue.Name, } - diff := cmp.Diff(expected, describeResp.GetInfo(), - protocmp.Transform(), - // Ignore non-deterministic fields. Validated separately. - protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, - "execution_duration", - "last_started_time", - "schedule_time", - "state_transition_count", - ), - ) - require.Empty(t, diff) + protorequire.ProtoEqual(t, expected, describeResp.GetInfo()) protorequire.ProtoEqual(t, defaultInput, describeResp.Input)