Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions common/testing/protorequire/require.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package protorequire

import (
"fmt"

"github.com/stretchr/testify/require"
"go.temporal.io/server/common/testing/protoassert"
"google.golang.org/protobuf/proto"
Expand All @@ -22,7 +24,33 @@ func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message) {
if th, ok := t.(helper); ok {
th.Helper()
}
if !protoassert.ProtoEqual(t, a, b) {

// Find sentinels on the original message (pointer identity requires the
// original, not a clone).
paths := findSentinelPaths(a.ProtoReflect())
if len(paths) == 0 {
if !protoassert.ProtoEqual(t, a, b) {
t.FailNow()
}
return
}

// Validate sentinel fields against the actual message.
for _, p := range paths {
if !validateSentinel(b, p) {
require.Fail(t, fmt.Sprintf("expected field %q to be non-zero", p))
}
}

// Clone both messages, clear sentinel fields, then diff.
cleanA := proto.Clone(a)
cleanB := proto.Clone(b)
for _, p := range paths {
clearField(cleanA, p)
clearField(cleanB, p)
}

if !protoassert.ProtoEqual(t, cleanA, cleanB) {
t.FailNow()
}
}
Expand Down Expand Up @@ -53,9 +81,7 @@ func (x ProtoAssertions) ProtoEqual(a proto.Message, b proto.Message) {
if th, ok := x.t.(helper); ok {
th.Helper()
}
if !protoassert.ProtoEqual(x.t, a, b) {
x.t.FailNow()
}
ProtoEqual(x.t, a, b)
}

func (x ProtoAssertions) NotProtoEqual(a proto.Message, b proto.Message) {
Expand Down
137 changes: 137 additions & 0 deletions common/testing/protorequire/sentinel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package protorequire

import (
"strings"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
)

// Sentinel values — strings use a distinctive value unlikely to appear in real
// data; messages are identified by pointer address; integers use magic values.
var (
nonZeroString = "__protorequire_NonZero_sentinel__"
nonZeroTimestamp = &timestamppb.Timestamp{}
nonZeroDuration = &durationpb.Duration{}
// Arbitrary large negatives — not powers of 2 or boundary values,
// so they won't collide with real test data or edge-case tests.
nonZeroInt32 = int32(-479_001_599)
nonZeroInt64 = int64(-6_279_541_637_813)
)

// NonZero returns a sentinel value that can be used as a field value in an
// expected proto message. When [ProtoEqual] encounters this sentinel, it
// asserts that the corresponding field in the actual message is present and
// non-zero, rather than checking for exact equality.
//
// Example:
//
// expected := &pb.MyMessage{
// Name: "exact match",
// CreatedAt: protorequire.NonZero[*timestamppb.Timestamp](),
// ExternalId: protorequire.NonZero[string](),
// Count: protorequire.NonZero[int64](),
// Duration: protorequire.NonZero[*durationpb.Duration](),
// }
// protorequire.ProtoEqual(t, expected, actual)
func NonZero[T string | *timestamppb.Timestamp | *durationpb.Duration | int32 | int64]() T {
var zero T
switch any(zero).(type) {
case string:
return any(nonZeroString).(T)
case *timestamppb.Timestamp:
return any(nonZeroTimestamp).(T)
case *durationpb.Duration:
return any(nonZeroDuration).(T)
case int32:
return any(nonZeroInt32).(T)
case int64:
return any(nonZeroInt64).(T)
}
panic("unreachable")
}

func isMessageSentinel(m protoreflect.Message) bool {
switch v := m.Interface().(type) {
case *timestamppb.Timestamp:
return v == nonZeroTimestamp
case *durationpb.Duration:
return v == nonZeroDuration
}
return false
}

// fieldPath is a chain of field descriptors from root to the sentinel field.
type fieldPath []protoreflect.FieldDescriptor

func (p fieldPath) String() string {
var b strings.Builder
for i, fd := range p {
if i > 0 {
b.WriteByte('.')
}
b.WriteString(string(fd.Name()))
}
return b.String()
}

// findSentinelPaths walks the original (uncloned) message to find sentinel
// values using pointer identity (for strings and messages) or magic values
// (for integers). Returns the path to each sentinel field.
func findSentinelPaths(msg protoreflect.Message) []fieldPath {
return findSentinelPathsRecurse(msg, nil)
}

func findSentinelPathsRecurse(msg protoreflect.Message, prefix fieldPath) []fieldPath {
var paths []fieldPath
msg.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
current := append(append(fieldPath(nil), prefix...), fd)
switch fd.Kind() {
case protoreflect.StringKind:
if v.String() == nonZeroString {
paths = append(paths, current)
}
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if int32(v.Int()) == nonZeroInt32 {
paths = append(paths, current)
}
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
if v.Int() == nonZeroInt64 {
paths = append(paths, current)
}
case protoreflect.MessageKind:
if isMessageSentinel(v.Message()) {
paths = append(paths, current)
} else {
paths = append(paths, findSentinelPathsRecurse(v.Message(), current)...)
}
}
return true
})
return paths
}

// validateSentinel checks that the actual message has the field at path set
// to a non-zero value. Uses read-only traversal to avoid mutating actual.
func validateSentinel(actual proto.Message, path fieldPath) bool {
msg := actual.ProtoReflect()
for _, fd := range path[:len(path)-1] {
if !msg.Has(fd) {
return false
}
msg = msg.Get(fd).Message()
}
return msg.Has(path[len(path)-1])
}

// clearField clears the leaf field at path in the given message.
// Uses Mutable to traverse — only safe on cloned messages.
func clearField(msg proto.Message, path fieldPath) {
m := msg.ProtoReflect()
for _, fd := range path[:len(path)-1] {
m = m.Mutable(fd).Message()
}
m.Clear(path[len(path)-1])
}
139 changes: 139 additions & 0 deletions common/testing/protorequire/sentinel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package protorequire_test

import (
"testing"

"github.com/stretchr/testify/require"
commonpb "go.temporal.io/api/common/v1"
workflowpb "go.temporal.io/api/workflow/v1"
"go.temporal.io/server/common/testing/protorequire"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
)

func TestNonZero_String(t *testing.T) {
expected := &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: protorequire.NonZero[string](),
}

t.Run("passes when set", func(t *testing.T) {
actual := &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: "some-run-id",
}
protorequire.ProtoEqual(t, expected, actual)
})

t.Run("fails when empty", func(t *testing.T) {
mockT := &mockTestingT{}
actual := &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: "",
}
protorequire.ProtoEqual(mockT, expected, actual)
require.True(t, mockT.failed, "expected ProtoEqual to fail for empty string field")
})
}

func TestNonZero_Timestamp(t *testing.T) {
expected := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
StartTime: protorequire.NonZero[*timestamppb.Timestamp](),
}

t.Run("passes when set", func(t *testing.T) {
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
StartTime: timestamppb.Now(),
}
protorequire.ProtoEqual(t, expected, actual)
})

t.Run("fails when nil", func(t *testing.T) {
mockT := &mockTestingT{}
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
}
protorequire.ProtoEqual(mockT, expected, actual)
require.True(t, mockT.failed, "expected ProtoEqual to fail for nil timestamp field")
})
}

func TestNonZero_Duration(t *testing.T) {
expected := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
ExecutionDuration: protorequire.NonZero[*durationpb.Duration](),
}

t.Run("passes when set", func(t *testing.T) {
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
ExecutionDuration: durationpb.New(5000),
}
protorequire.ProtoEqual(t, expected, actual)
})

t.Run("fails when nil", func(t *testing.T) {
mockT := &mockTestingT{}
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
}
protorequire.ProtoEqual(mockT, expected, actual)
require.True(t, mockT.failed, "expected ProtoEqual to fail for nil duration field")
})
}

func TestNonZero_Int64(t *testing.T) {
expected := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
StateTransitionCount: protorequire.NonZero[int64](),
}

t.Run("passes when non-zero", func(t *testing.T) {
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
StateTransitionCount: 42,
}
protorequire.ProtoEqual(t, expected, actual)
})

t.Run("fails when zero", func(t *testing.T) {
mockT := &mockTestingT{}
actual := &workflowpb.WorkflowExecutionInfo{
TaskQueue: "queue-a",
}
protorequire.ProtoEqual(mockT, expected, actual)
require.True(t, mockT.failed, "expected ProtoEqual to fail for zero int64 field")
})
}

func TestNonZero_Mixed(t *testing.T) {
expected := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: protorequire.NonZero[string](),
},
TaskQueue: protorequire.NonZero[string](),
StartTime: protorequire.NonZero[*timestamppb.Timestamp](),
StateTransitionCount: protorequire.NonZero[int64](),
}
actual := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: "run-abc",
},
TaskQueue: "queue-a",
StartTime: timestamppb.Now(),
StateTransitionCount: 3,
}
protorequire.ProtoEqual(t, expected, actual)
}

// mockTestingT captures failures without stopping the test.
type mockTestingT struct {
failed bool
}

func (m *mockTestingT) Errorf(string, ...interface{}) { m.failed = true }
func (m *mockTestingT) FailNow() { m.failed = true }
Loading
Loading