From c77f15ac7640c11914ef91f0c5097e9c81a35f1f Mon Sep 17 00:00:00 2001 From: Timo Stamm Date: Wed, 11 Jun 2025 13:36:12 +0200 Subject: [PATCH] Raise CompilationError for unknown field names in MessageOneofRule.fields --- builder.go | 12 +++++++++++- message_oneof.go | 11 +++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/builder.go b/builder.go index cb6ed4fc..d6d6a36a 100644 --- a/builder.go +++ b/builder.go @@ -129,8 +129,18 @@ func (bldr *builder) buildMessage( oneofRules := msgRules.GetOneof() for _, rule := range oneofRules { + fdescs := make([]protoreflect.FieldDescriptor, 0, len(rule.GetFields())) + for _, name := range rule.GetFields() { + fdesc := desc.Fields().ByName(protoreflect.Name(name)) + if fdesc == nil { + msgEval.Err = &CompilationError{cause: fmt.Errorf( + "field %q not found in message %s", name, desc.FullName())} + } else { + fdescs = append(fdescs, fdesc) + } + } oneofEval := &oneofEvaluator{ - Fields: rule.GetFields(), + Fields: fdescs, Required: rule.GetRequired(), } msgEval.AppendNested(oneofEval) diff --git a/message_oneof.go b/message_oneof.go index ec748d3f..ba7c3865 100644 --- a/message_oneof.go +++ b/message_oneof.go @@ -27,14 +27,14 @@ import ( // fields, ensuring that only one is set. If `required` is true, it enforces that one of // the fields _must_ be set. type oneofEvaluator struct { - Fields []string + Fields []protoreflect.FieldDescriptor Required bool } func (o oneofEvaluator) formatFields() string { quoted := make([]string, len(o.Fields)) - for idx, val := range o.Fields { - quoted[idx] = fmt.Sprintf("'%s'", val) + for idx, fdesc := range o.Fields { + quoted[idx] = fmt.Sprintf("'%s'", fdesc.Name()) } return fmt.Sprintf("[%s]", strings.Join(quoted, ", ")) } @@ -50,9 +50,8 @@ func (o oneofEvaluator) EvaluateMessage(msg protoreflect.Message, cfg *validatio err := &ValidationError{} if len(o.Fields) > 0 { count := 0 - for _, v := range o.Fields { - fd := msg.Descriptor().Fields().ByName(protoreflect.Name(v)) - if fd != nil && msg.Has(fd) { + for _, fdesc := range o.Fields { + if msg.Has(fdesc) { count++ } }