diff --git a/tools/BUILD b/tools/BUILD
index c4b957fff..26956df59 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -121,3 +121,32 @@ cc_test(
         "@com_google_protobuf//:protobuf",
     ],
 )
+
+cc_library(  # DescriptorPoolBuilder: assembles a DescriptorPool from file descriptors.
+    name = "descriptor_pool_builder",
+    srcs = ["descriptor_pool_builder.cc"],
+    hdrs = ["descriptor_pool_builder.h"],
+    deps = [
+        "//common:minimal_descriptor_database",
+        "//internal:status_macros",
+        "@com_google_absl//absl/base:nullability",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_protobuf//:protobuf",
+    ],
+)
+
+cc_test(  # Unit tests; the proto2 conformance TestAllTypes protos serve as linked-in fixtures.
+    name = "descriptor_pool_builder_test",
+    srcs = ["descriptor_pool_builder_test.cc"],
+    deps = [
+        ":descriptor_pool_builder",
+        "//internal:testing",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:status_matchers",
+        "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto",
+        "@com_google_protobuf//:protobuf",
+    ],
+)
diff --git a/tools/descriptor_pool_builder.cc b/tools/descriptor_pool_builder.cc
new file mode 100644
index 000000000..a0ca44442
--- /dev/null
+++ b/tools/descriptor_pool_builder.cc
@@ -0,0 +1,111 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/descriptor_pool_builder.h"
+
+#include <memory>
+#include <vector>
+
+#include "google/protobuf/descriptor.pb.h"
+#include "absl/base/nullability.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "common/minimal_descriptor_database.h"
+#include "internal/status_macros.h"
+#include "google/protobuf/descriptor.h"
+
+// NOTE(review): the <...> include names and template arguments in this file
+// were lost when the patch was extracted. They have been restored from how
+// each value is used below -- confirm against the upstream file, in particular
+// the constness of the pool type returned by Build().
+
+namespace cel {
+
+namespace {
+
+// Drains `to_resolve`, registering every file it holds -- and, transitively,
+// every file those depend on -- with `builder`.
+//
+// `resolved` records the files already handled so that each one is passed to
+// AddFileDescriptor at most once within a single traversal. An error from
+// AddFileDescriptor is propagated to the caller.
+absl::Status FindDeps(
+    std::vector<const google::protobuf::FileDescriptor*>& to_resolve,
+    absl::flat_hash_set<const google::protobuf::FileDescriptor*>& resolved,
+    DescriptorPoolBuilder& builder) {
+  while (!to_resolve.empty()) {
+    const auto* file = to_resolve.back();
+    to_resolve.pop_back();
+    // insert() reports whether the file was new, so a single hash lookup
+    // covers both the "already seen" check and the bookkeeping.
+    if (!resolved.insert(file).second) {
+      continue;
+    }
+    google::protobuf::FileDescriptorProto file_proto;
+    file->CopyTo(&file_proto);
+    // Note: order doesn't matter here as long as all the cross references are
+    // correct in the final database.
+    CEL_RETURN_IF_ERROR(builder.AddFileDescriptor(file_proto));
+    for (int i = 0; i < file->dependency_count(); ++i) {
+      to_resolve.push_back(file->dependency(i));
+    }
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace
+
+// Wires the databases together: `merged` is built from `base` and
+// `extensions`, and `pool` is constructed on top of `merged`.
+DescriptorPoolBuilder::StateHolder::StateHolder(
+    google::protobuf::DescriptorDatabase* base)
+    : base(base), merged(base, &extensions), pool(&merged) {}
+
+// Starts from the database returned by cel::GetMinimalDescriptorDatabase().
+DescriptorPoolBuilder::DescriptorPoolBuilder()
+    : state_(std::make_shared<StateHolder>(
+          cel::GetMinimalDescriptorDatabase())) {}
+
+// Hands out a pointer to the pool that shares ownership of the whole
+// StateHolder (shared_ptr aliasing constructor), then drops the builder's own
+// reference so the builder can no longer touch the databases.
+std::shared_ptr<const google::protobuf::DescriptorPool>
+DescriptorPoolBuilder::Build() && {
+  auto alias = std::shared_ptr<const google::protobuf::DescriptorPool>(
+      state_, &state_->pool);
+  state_.reset();
+  return alias;
+}
+
+// Adds the file declaring `desc` together with its transitive dependencies.
+absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet(
+    absl::Nonnull<const google::protobuf::Descriptor*> desc) {
+  absl::flat_hash_set<const google::protobuf::FileDescriptor*> resolved;
+  std::vector<const google::protobuf::FileDescriptor*> to_resolve{desc->file()};
+  return FindDeps(to_resolve, resolved, *this);
+}
+
+// Same as above for several messages at once. One `resolved` set is shared by
+// the whole traversal, so files reachable from more than one of `descs` are
+// only added once.
+absl::Status DescriptorPoolBuilder::AddTransitiveDescriptorSet(
+    absl::Span<absl::Nonnull<const google::protobuf::Descriptor*>> descs) {
+  absl::flat_hash_set<const google::protobuf::FileDescriptor*> resolved;
+  std::vector<absl::Nonnull<const google::protobuf::FileDescriptor*>> to_resolve;
+  to_resolve.reserve(descs.size());
+  for (const google::protobuf::Descriptor* desc : descs) {
+    to_resolve.push_back(desc->file());
+  }
+
+  return FindDeps(to_resolve, resolved, *this);
+}
+
+// Registers `file` in the `extensions` database. Returns InvalidArgument,
+// naming the file, when the database's Add() call returns false.
+absl::Status DescriptorPoolBuilder::AddFileDescriptor(
+    const google::protobuf::FileDescriptorProto& file) {
+  if (!state_->extensions.Add(file)) {
+    return absl::InvalidArgumentError(
+        absl::StrCat("proto descriptor conflict: ", file.name()));
+  }
+  return absl::OkStatus();
+}
+
+// Adds every file in `files`, stopping at the first one that fails.
+//
+// The parameter is named `files` (as in the header) so that the loop variable
+// no longer shadows it.
+absl::Status DescriptorPoolBuilder::AddFileDescriptorSet(
+    const google::protobuf::FileDescriptorSet& files) {
+  for (const auto& file : files.file()) {
+    CEL_RETURN_IF_ERROR(AddFileDescriptor(file));
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace cel
diff --git a/tools/descriptor_pool_builder.h b/tools/descriptor_pool_builder.h
new file mode 100644
index 000000000..ad2ec75da
--- /dev/null
+++ b/tools/descriptor_pool_builder.h
@@ -0,0 +1,93 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_
+#define THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_
+
+#include <memory>
+#include <utility>
+
+#include "google/protobuf/descriptor.pb.h"
+#include "absl/base/nullability.h"
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor_database.h"
+
+namespace cel {
+
+// A helper class for building a descriptor pool from a set of proto file
+// descriptors.
+// Manages lifetime for the descriptor databases backing
+// the pool.
+//
+// Client must ensure that types are not added multiple times.
+//
+// Note: in the constructed pool, the definitions for the required types for
+// CEL will shadow any added to the builder. Clients should not modify types
+// from the google.protobuf package in general, but if they do the behavior of
+// the constructed descriptor pool will be inconsistent.
+//
+// Example:
+//
+//   DescriptorPoolBuilder builder;
+//   CEL_RETURN_IF_ERROR(builder.AddFileDescriptorSet(files));
+//   auto pool = std::move(builder).Build();
+//
+// NOTE(review): the template arguments in this header were lost when the patch
+// was extracted and have been restored from how each value is used. The
+// `const` on the pool type returned by Build() is an assumption -- confirm
+// against the upstream header.
+class DescriptorPoolBuilder {
+ public:
+  // Creates a builder whose pool starts from the database returned by
+  // cel::GetMinimalDescriptorDatabase().
+  DescriptorPoolBuilder();
+
+  // Neither copyable nor movable. The move assignment takes a non-const
+  // rvalue reference so that it names the real special member.
+  DescriptorPoolBuilder& operator=(const DescriptorPoolBuilder&) = delete;
+  DescriptorPoolBuilder(const DescriptorPoolBuilder&) = delete;
+  DescriptorPoolBuilder& operator=(DescriptorPoolBuilder&&) = delete;
+  DescriptorPoolBuilder(DescriptorPoolBuilder&&) = delete;
+
+  ~DescriptorPoolBuilder() = default;
+
+  // Returns a shared pointer to the new descriptor pool that manages the
+  // underlying descriptor databases backing the pool.
+  //
+  // Consumes the builder instance. It is unsafe to make any further changes
+  // to the descriptor databases after accessing the pool.
+  std::shared_ptr<const google::protobuf::DescriptorPool> Build() &&;
+
+  // Utility for adding the transitive dependencies of a message with a linked
+  // descriptor.
+  absl::Status AddTransitiveDescriptorSet(
+      absl::Nonnull<const google::protobuf::Descriptor*> desc);
+
+  // As above, for several messages at once. Files reachable from more than one
+  // of `descs` are only added once.
+  absl::Status AddTransitiveDescriptorSet(
+      absl::Span<absl::Nonnull<const google::protobuf::Descriptor*>> descs);
+
+  // Adds a file descriptor set to the pool. Client must ensure that all
+  // dependencies are satisfied and that files are not added multiple times.
+  absl::Status AddFileDescriptorSet(const google::protobuf::FileDescriptorSet& files);
+
+  // Adds a single proto file descriptor to the pool. Client must ensure
+  // that all dependencies are satisfied and that files are not added multiple
+  // times. Returns InvalidArgument if the file conflicts with one that was
+  // already added.
+  absl::Status AddFileDescriptor(const google::protobuf::FileDescriptorProto& file);
+
+ private:
+  // Everything the built pool needs to stay alive. Build() returns a pointer
+  // to `pool` that shares ownership of the whole holder.
+  struct StateHolder {
+    explicit StateHolder(google::protobuf::DescriptorDatabase* base);
+
+    // Database supplied at construction; held as a raw pointer.
+    google::protobuf::DescriptorDatabase* base;
+    // Receives the files added through the builder.
+    google::protobuf::SimpleDescriptorDatabase extensions;
+    // Combines `base` and `extensions`.
+    google::protobuf::MergedDescriptorDatabase merged;
+    // Pool backed by `merged`. Declared last because its constructor takes
+    // the address of `merged`.
+    google::protobuf::DescriptorPool pool;
+  };
+
+  explicit DescriptorPoolBuilder(std::shared_ptr<StateHolder> state)
+      : state_(std::move(state)) {}
+
+  // Reset by Build(); the builder must not be used afterwards.
+  std::shared_ptr<StateHolder> state_;
+};
+
+}  // namespace cel
+
+#endif  // THIRD_PARTY_CEL_CPP_TOOLS_DESCRIPTOR_POOL_BUILDER_H_
diff --git a/tools/descriptor_pool_builder_test.cc b/tools/descriptor_pool_builder_test.cc
new file mode 100644
index 000000000..82fa8f699
--- /dev/null
+++ b/tools/descriptor_pool_builder_test.cc
@@ -0,0 +1,177 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/descriptor_pool_builder.h"
+
+#include <utility>
+
+#include "google/protobuf/descriptor.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/status_matchers.h"
+#include "internal/testing.h"
+#include "cel/expr/conformance/proto2/test_all_types.pb.h"
+#include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h"
+#include "google/protobuf/text_format.h"
+
+// NOTE(review): the standard-library include name above was lost when the
+// patch was extracted; <utility> is restored because the tests call std::move.
+
+using ::absl_testing::IsOk;
+using ::absl_testing::StatusIs;
+using ::testing::IsNull;
+using ::testing::NotNull;
+
+namespace cel {
+namespace {
+
+// A fresh builder only knows the default types: well-known google.protobuf
+// messages resolve, test messages do not.
+TEST(DescriptorPoolBuilderTest, IncludesDefaults) {
+  DescriptorPoolBuilder builder;
+
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(
+      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
+      IsNull());
+
+  EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Timestamp"),
+              NotNull());
+  EXPECT_THAT(pool->FindMessageTypeByName("google.protobuf.Any"), NotNull());
+}
+
+// Only Proto2ExtensionScopedMessage is registered directly; TestAllTypes is
+// expected to become resolvable through that file's dependencies.
+TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSet) {
+  DescriptorPoolBuilder builder;
+  ASSERT_THAT(builder.AddTransitiveDescriptorSet(
+                  cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
+                      descriptor()),
+              IsOk());
+
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(
+      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
+      NotNull());
+}
+
+// The span overload must accept descriptors whose files overlap without
+// reporting a conflict.
+TEST(DescriptorPoolBuilderTest, AddTransitiveDescriptorSetSpan) {
+  DescriptorPoolBuilder builder;
+  const google::protobuf::Descriptor* descs[] = {
+      cel::expr::conformance::proto2::TestAllTypes::descriptor(),
+      cel::expr::conformance::proto2::Proto2ExtensionScopedMessage::
+          descriptor()};
+  ASSERT_THAT(builder.AddTransitiveDescriptorSet(descs), IsOk());
+
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(
+      pool->FindMessageTypeByName("cel.expr.conformance.proto2.TestAllTypes"),
+      NotNull());
+}
+
+// foo.proto is added before the bar.proto it depends on; both messages must
+// still resolve once the set is complete.
+TEST(DescriptorPoolBuilderTest, AddFileDescriptorSet) {
+  DescriptorPoolBuilder builder;
+  google::protobuf::FileDescriptorSet file_set;
+  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+      R"pb(
+        name: "foo.proto"
+        package: "cel.test"
+        dependency: "bar.proto"
+        message_type {
+          name: "Foo"
+          field: {
+            name: "bar"
+            number: 1
+            label: LABEL_OPTIONAL
+            type: TYPE_MESSAGE
+            type_name: ".cel.test.Bar"
+          }
+        }
+      )pb",
+      file_set.add_file()));
+  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+      R"pb(
+        name: "bar.proto"
+        package: "cel.test"
+        message_type {
+          name: "Bar"
+          field: {
+            name: "baz"
+            number: 1
+            label: LABEL_OPTIONAL
+            type: TYPE_STRING
+          }
+        }
+      )pb",
+      file_set.add_file()));
+  ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk());
+
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), NotNull());
+  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull());
+}
+
+TEST(DescriptorPoolBuilderTest, BadRef) {
+  DescriptorPoolBuilder builder;
+  google::protobuf::FileDescriptorSet file_set;
+  // Unfulfilled dependency.
+  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+      R"pb(
+        name: "foo.proto"
+        package: "cel.test"
+        dependency: "bar.proto"
+        message_type {
+          name: "Foo"
+          field: {
+            name: "bar"
+            number: 1
+            label: LABEL_OPTIONAL
+            type: TYPE_MESSAGE
+            type_name: ".cel.test.Bar"
+          }
+        }
+      )pb",
+      file_set.add_file()));
+  // Note: descriptor pool is initialized lazily so this will not lead to an
+  // error now, but looking up the message will fail.
+  ASSERT_THAT(builder.AddFileDescriptorSet(file_set), IsOk());
+
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Foo"), IsNull());
+}
+
+TEST(DescriptorPoolBuilderTest, AddFile) {
+  DescriptorPoolBuilder builder;
+  google::protobuf::FileDescriptorProto file;
+  ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
+      R"pb(
+        name: "bar.proto"
+        package: "cel.test"
+        message_type {
+          name: "Bar"
+          field: {
+            name: "baz"
+            number: 1
+            label: LABEL_OPTIONAL
+            type: TYPE_STRING
+          }
+        }
+      )pb",
+      &file));
+
+  ASSERT_THAT(builder.AddFileDescriptor(file), IsOk());
+  // Duplicate file.
+  ASSERT_THAT(builder.AddFileDescriptor(file),
+              StatusIs(absl::StatusCode::kInvalidArgument));
+
+  // In this specific case, we know that the duplicate is the same so
+  // the pool will still be valid.
+  auto pool = std::move(builder).Build();
+  EXPECT_THAT(pool->FindMessageTypeByName("cel.test.Bar"), NotNull());
+}
+
+}  // namespace
+}  // namespace cel