diff --git a/MODULE.bazel b/MODULE.bazel index f0334b6..d9c2b58 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -33,4 +33,4 @@ bazel_dep(name = "abseil-cpp", version = "20240722.0", repo_name = "com_google_a bazel_dep(name = "cel-cpp", version = "0.11.0", repo_name = "com_google_cel_cpp") -bazel_dep(name = "protovalidate", version = "1.0.0-rc.1", repo_name = "com_github_bufbuild_protovalidate") +bazel_dep(name = "protovalidate", version = "1.0.0-rc.2", repo_name = "com_github_bufbuild_protovalidate") diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 8040281..4d26cab 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -184,8 +184,8 @@ "https://bcr.bazel.build/modules/protoc-gen-validate/1.0.4.bcr.2/MODULE.bazel": "c4bd2c850211ff5b7dadf9d2d0496c1c922fdedc303c775b01dfd3b3efc907ed", "https://bcr.bazel.build/modules/protoc-gen-validate/1.0.4.bcr.2/source.json": "4cc97f70b521890798058600a927ce4b0def8ee84ff2a5aa632aabcb4234aa0b", "https://bcr.bazel.build/modules/protoc-gen-validate/1.0.4/MODULE.bazel": "b8913c154b16177990f6126d2d2477d187f9ddc568e95ee3e2d50fc65d2c494a", - "https://bcr.bazel.build/modules/protovalidate/1.0.0-rc.1/MODULE.bazel": "97ac1e244c63781d40c38ce15ed9d2b1c923833ba2494d92fe80bf247d3fe30a", - "https://bcr.bazel.build/modules/protovalidate/1.0.0-rc.1/source.json": "e3d4894e0a64c5ce5ad38639c5468fb8dae6644a6cff2d77c576d3ad5159017a", + "https://bcr.bazel.build/modules/protovalidate/1.0.0-rc.2/MODULE.bazel": "1e267077e68ed12b555671fe9be7e3a8e78929723a2b91c181095b6d0a5cb7f1", + "https://bcr.bazel.build/modules/protovalidate/1.0.0-rc.2/source.json": "64f215d83a29846d17e1511fb5ed504b8fc5e8a8b61285fecc72abccdaeed08d", "https://bcr.bazel.build/modules/pybind11_bazel/2.11.1/MODULE.bazel": "88af1c246226d87e65be78ed49ecd1e6f5e98648558c14ce99176da041dc378e", "https://bcr.bazel.build/modules/pybind11_bazel/2.12.0/MODULE.bazel": "e6f4c20442eaa7c90d7190d8dc539d0ab422f95c65a57cc59562170c58ae3d34", "https://bcr.bazel.build/modules/pybind11_bazel/2.12.0/source.json": "6900fdc8a9e95866b8c0d4ad4aba4d4236317b5c1cd04c502df3f0d33afed680", @@ -509,7 +509,7 @@ "@@rules_buf+//buf:extensions.bzl%buf": { "general": { "bzlTransitiveDigest": "Cn52bY/1OxGKNv4TI5yExj2vL9hBgfRj4HOOcjWQTdQ=", - "usagesDigest": "LhpNrCpDf26/nlI1amvlckINVn131BBRmXO9ypWRwAE=", + "usagesDigest": "3eheAn/0yukeu67ELwwJYYb+QQDwGTOHqPLdmlfsmQg=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, diff --git a/buf/validate/internal/message_rules.cc b/buf/validate/internal/message_rules.cc index 2f4f4d1..b3eee4b 100644 --- a/buf/validate/internal/message_rules.cc +++ b/buf/validate/internal/message_rules.cc @@ -46,6 +46,19 @@ Rules NewMessageRules( return rules_or.status(); } result.emplace_back(std::move(rules_or).value()); + + // buf.validate.MessageRules.oneof + for (const auto& msgOneof : msgLvl.oneof()) { + std::vector fields; + for (const auto& name : msgOneof.fields()) { + auto fdesc = descriptor->FindFieldByName(name); + if (fdesc == nullptr) { + return absl::FailedPreconditionError(absl::StrCat("field \"", name, "\" not found in message ", descriptor->full_name())); + } + fields.push_back(fdesc); + } + result.emplace_back(std::make_unique(fields, msgOneof.required())); + } } for (int i = 0; i < descriptor->field_count(); i++) { diff --git a/buf/validate/internal/rules.cc b/buf/validate/internal/rules.cc index f12c0e9..818ff80 100644 --- a/buf/validate/internal/rules.cc +++ b/buf/validate/internal/rules.cc @@ -446,4 +446,33 @@ absl::Status OneofValidationRules::Validate( return absl::OkStatus(); } +absl::Status MessageOneofValidationRules::Validate( + RuleContext& ctx, const google::protobuf::Message& message) const { + int has_count = 0; + for (const auto& fdesc : fields_) { + if (message.GetReflection()->HasField(message, fdesc)) { + has_count++; + } + } + if (has_count > 1) { + Violation violation; + *violation.mutable_rule_id() = "message.oneof"; + *violation.mutable_message() = absl::StrCat("only one of ", field_names_(), " can be set"); + ctx.violations.emplace_back(std::move(violation), absl::nullopt, absl::nullopt); + } + if (required_ && has_count == 0) { + Violation violation; + *violation.mutable_rule_id() = "message.oneof"; + *violation.mutable_message() = absl::StrCat("one of ", field_names_(), " must be set"); + ctx.violations.emplace_back(std::move(violation), absl::nullopt, absl::nullopt); + } + return absl::OkStatus(); +} + +std::string MessageOneofValidationRules::field_names_() const { + return absl::StrJoin(fields_, ", ", [](std::string* out, const google::protobuf::FieldDescriptor *fdesc) { + absl::StrAppend(out, fdesc->name()); + }); +} + } // namespace buf::validate::internal diff --git a/buf/validate/internal/rules.h b/buf/validate/internal/rules.h index 926bfed..4100fcf 100644 --- a/buf/validate/internal/rules.h +++ b/buf/validate/internal/rules.h @@ -133,6 +133,21 @@ class OneofValidationRules : public ValidationRules { bool required_ = false; }; +class MessageOneofValidationRules : public ValidationRules { + using Base = ValidationRules; + +public: + MessageOneofValidationRules(const std::vector fields, const bool required) + : fields_(fields), required_(required) {} + + absl::Status Validate(RuleContext& ctx, const google::protobuf::Message& message) const override; + +private: + const std::vector fields_; + bool required_ = false; + std::string field_names_() const; +}; + // Creates a new expression builder suitable for creating rules. absl::StatusOr> NewRuleBuilder( google::protobuf::Arena* arena); diff --git a/deps/shared_deps.json b/deps/shared_deps.json index bd07466..b38aa75 100644 --- a/deps/shared_deps.json +++ b/deps/shared_deps.json @@ -38,10 +38,10 @@ }, "protovalidate": { "meta": { - "version": "1.0.0-rc.1" + "version": "1.0.0-rc.2" }, "source": { - "sha256": "f04ecb455e58172bcd7b011dac4ef35ffbe4f6acf0ec7021433bed4e2289de2b", + "sha256": "b379378e55d27111e5a3712c35aa126e58cbdd974ffcd3982c2d16f41bc01231", "strip_prefix": "protovalidate-{version}", "urls": [ "https://github.com/bufbuild/protovalidate/releases/download/v{version}/protovalidate-{version}.tar.gz"