diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index 8491704..9fddc70 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -184,14 +184,13 @@ cc_test( srcs = ["client_test.cc"], deps = [ ":client_cc", - ":client_cxx", "@googletest//:gtest_main", "@abseil-cpp//absl/status", - "@cxx.rs//:core", "//ffi_utils:status_matchers", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:input_spec_cc_proto", "//willow/src/input_encoding:codec", "//willow/src/testing_utils:shell_testing_decryptor_cc", + "//willow/src/testing_utils:testing_utils_cc", ], ) diff --git a/willow/src/api/client_test.cc b/willow/src/api/client_test.cc index d23af77..4e4c7da 100644 --- a/willow/src/api/client_test.cc +++ b/willow/src/api/client_test.cc @@ -18,7 +18,6 @@ #include #include -#include #include "absl/status/status.h" #include "ffi_utils/status_matchers.h" @@ -29,6 +28,7 @@ #include "willow/src/input_encoding/codec.h" #include "willow/src/input_encoding/codec_factory.h" #include "willow/src/testing_utils/shell_testing_decryptor.h" +#include "willow/src/testing_utils/testing_utils.h" namespace secure_aggregation { namespace willow { @@ -41,49 +41,13 @@ using ::testing::ElementsAreArray; using ::testing::Pair; using ::testing::UnorderedElementsAre; -AggregationConfigProto CreateTestConfig() { - AggregationConfigProto config; - VectorConfig vector_config; - vector_config.set_length(8); // 4 countries x 2 languages - vector_config.set_bound(100); - (*config.mutable_vector_configs())["metric1"] = vector_config; - config.set_max_number_of_decryptors(1); - config.set_max_number_of_clients(10); - config.set_key_id("test"); - return config; -} - TEST(WillowShellClientTest, InitializeAndGenerateContribution) { - AggregationConfigProto config = CreateTestConfig(); + AggregationConfigProto config = CreateTestAggregationConfigProto(); // Create and encode input. - MetricData metric_data; - metric_data["metric1"] = {10, 20, 5}; - GroupData group_by_data; - group_by_data["country"] = {"US", "CA", "US"}; - group_by_data["lang"] = {"en", "es", "es"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); - group_by_spec1->set_vector_name("country"); - group_by_spec1->set_data_type(InputSpec::STRING); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "GB"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "MX"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); - auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); - group_by_spec2->set_vector_name("lang"); - group_by_spec2->set_data_type(InputSpec::STRING); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "en"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "es"); + MetricData metric_data = CreateTestMetricData(); + GroupData group_by_data = CreateTestGroupData(); + InputSpec input_spec = CreateTestInputSpecProto(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, CodecFactory::CreateExplicitCodec(input_spec)); SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data, @@ -129,7 +93,7 @@ TEST(WillowShellClientTest, InitializeAndGenerateContribution) { } TEST(WillowShellClientTest, EmptyEncodedData) { - AggregationConfigProto config = CreateTestConfig(); + AggregationConfigProto config = CreateTestAggregationConfigProto(); // Create empty encoded data. EncodedData encoded_data; @@ -151,7 +115,7 @@ TEST(WillowShellClientTest, EmptyEncodedData) { TEST(WillowShellClientTest, InvalidAggregationConfig) { // Originally valid config. - AggregationConfigProto config = CreateTestConfig(); + AggregationConfigProto config = CreateTestAggregationConfigProto(); // Create encoded data directly. EncodedData encoded_data = {{"metric1", {0, 20, 0, 0, 0, 0, 10, 5}}}; diff --git a/willow/src/input_encoding/BUILD b/willow/src/input_encoding/BUILD index e3cf5ac..38d366c 100644 --- a/willow/src/input_encoding/BUILD +++ b/willow/src/input_encoding/BUILD @@ -51,5 +51,6 @@ cc_test( "@abseil-cpp//absl/status", "//ffi_utils:status_matchers", "//willow/proto/willow:input_spec_cc_proto", + "//willow/src/testing_utils:testing_utils_cc", ], ) diff --git a/willow/src/input_encoding/explicit_codec_test.cc b/willow/src/input_encoding/explicit_codec_test.cc index 9c8a341..f5ae9af 100644 --- a/willow/src/input_encoding/explicit_codec_test.cc +++ b/willow/src/input_encoding/explicit_codec_test.cc @@ -24,6 +24,7 @@ #include "willow/proto/willow/input_spec.pb.h" #include "willow/src/input_encoding/codec.h" #include "willow/src/input_encoding/codec_factory.h" +#include "willow/src/testing_utils/testing_utils.h" namespace secure_aggregation { namespace willow { @@ -298,33 +299,9 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { } TEST(CodecFactoryTest, EncodeSimpleGroupBy) { - MetricData metric_data; - metric_data["metric1"] = {10, 20, 5}; - GroupData group_by_data; - group_by_data["country"] = {"US", "CA", "US"}; - group_by_data["lang"] = {"en", "es", "es"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); - group_by_spec1->set_vector_name("country"); - group_by_spec1->set_data_type(InputSpec::STRING); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "GB"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "MX"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); - auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); - group_by_spec2->set_vector_name("lang"); - group_by_spec2->set_data_type(InputSpec::STRING); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "en"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "es"); + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data = CreateTestMetricData(); + GroupData group_by_data = CreateTestGroupData(); // group_by keys are sorted: "country", "lang" // value_to_index_maps["country"]: {"CA":0, "GB":1, "MX":2, "US":3} @@ -402,33 +379,9 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) { } TEST(CodecFactoryTest, EncodeThenDecode) { - MetricData metric_data; - metric_data["metric1"] = {10, 20, 5}; - GroupData group_by_data; - group_by_data["country"] = {"US", "CA", "US"}; - group_by_data["lang"] = {"en", "es", "es"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); - group_by_spec1->set_vector_name("country"); - group_by_spec1->set_data_type(InputSpec::STRING); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "GB"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "MX"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); - auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); - group_by_spec2->set_vector_name("lang"); - group_by_spec2->set_data_type(InputSpec::STRING); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "en"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "es"); + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data = CreateTestMetricData(); + GroupData group_by_data = CreateTestGroupData(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder, CodecFactory::CreateExplicitCodec(input_spec)); @@ -457,33 +410,9 @@ TEST(CodecFactoryTest, EncodeThenDecode) { } TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) { - MetricData metric_data1; - metric_data1["metric1"] = {10, 20, 5}; - GroupData group_by_data1; - group_by_data1["lang"] = {"en", "es", "es"}; - group_by_data1["country"] = {"US", "CA", "US"}; - InputSpec input_spec; - auto* metric_spec = input_spec.add_metric_vector_specs(); - metric_spec->set_vector_name("metric1"); - metric_spec->set_data_type(InputSpec::INT64); - auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); - group_by_spec1->set_vector_name("lang"); - group_by_spec1->set_data_type(InputSpec::STRING); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "en"); - group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( - "es"); - auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); - group_by_spec2->set_vector_name("country"); - group_by_spec2->set_data_type(InputSpec::STRING); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "CA"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "GB"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "MX"); - group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( - "US"); + InputSpec input_spec = CreateTestInputSpecProto(); + MetricData metric_data1 = CreateTestMetricData(); + GroupData group_by_data1 = CreateTestGroupData(); SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr encoder1, CodecFactory::CreateExplicitCodec(input_spec)); diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index f639f7d..f0fa03c 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -131,6 +131,7 @@ cc_library( name = "shell_testing_decryptor_cc", srcs = ["shell_testing_decryptor.cc"], hdrs = ["shell_testing_decryptor.h"], + visibility = ["//visibility:public"], deps = [ ":shell_testing_decryptor_cxx", "@abseil-cpp//absl/memory", @@ -157,3 +158,15 @@ cc_test( "//willow/proto/willow:aggregation_config_cc_proto", ], ) + +cc_library( + name = "testing_utils_cc", + srcs = ["testing_utils.cc"], + hdrs = ["testing_utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:input_spec_cc_proto", + "//willow/src/input_encoding:codec", + ], +) diff --git a/willow/src/testing_utils/testing_utils.cc b/willow/src/testing_utils/testing_utils.cc new file mode 100644 index 0000000..8241c32 --- /dev/null +++ b/willow/src/testing_utils/testing_utils.cc @@ -0,0 +1,74 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/src/testing_utils/testing_utils.h" + +#include "willow/src/input_encoding/codec.h" + +namespace secure_aggregation::willow { + +AggregationConfigProto CreateTestAggregationConfigProto() { + AggregationConfigProto config; + VectorConfig vector_config; + vector_config.set_length(8); // 4 countries x 2 languages + vector_config.set_bound(100); + (*config.mutable_vector_configs())["metric1"] = vector_config; + config.set_max_number_of_decryptors(1); + config.set_max_number_of_clients(10); + config.set_key_id("test"); + return config; +} + +InputSpec CreateTestInputSpecProto() { + InputSpec input_spec; + auto* metric_spec = input_spec.add_metric_vector_specs(); + metric_spec->set_vector_name("metric1"); + metric_spec->set_data_type(InputSpec::INT64); + metric_spec->mutable_domain_spec()->mutable_interval()->set_min(0); + metric_spec->mutable_domain_spec()->mutable_interval()->set_max(100); + auto* group_by_spec1 = input_spec.add_group_by_vector_specs(); + group_by_spec1->set_vector_name("country"); + group_by_spec1->set_data_type(InputSpec::STRING); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "CA"); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "GB"); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "MX"); + group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values( + "US"); + auto* group_by_spec2 = input_spec.add_group_by_vector_specs(); + group_by_spec2->set_vector_name("lang"); + group_by_spec2->set_data_type(InputSpec::STRING); + group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( + "en"); + group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values( + "es"); + return input_spec; +} + +MetricData CreateTestMetricData() { + MetricData metric_data; + metric_data["metric1"] = {10, 20, 5}; + return metric_data; +} + +GroupData CreateTestGroupData() { + GroupData group_by_data; + group_by_data["country"] = {"US", "CA", "US"}; + group_by_data["lang"] = {"en", "es", "es"}; + return group_by_data; +} + +} // namespace secure_aggregation::willow diff --git a/willow/src/testing_utils/testing_utils.h b/willow/src/testing_utils/testing_utils.h new file mode 100644 index 0000000..2933628 --- /dev/null +++ b/willow/src/testing_utils/testing_utils.h @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_ + +#include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/input_spec.pb.h" +#include "willow/src/input_encoding/codec.h" + +namespace secure_aggregation::willow { + +// Returns a test AggregationConfigProto. +AggregationConfigProto CreateTestAggregationConfigProto(); + +// Returns a test InputSpec proto with one metric and two group-by vectors. +InputSpec CreateTestInputSpecProto(); + +// Returns a test MetricData with one metric. +MetricData CreateTestMetricData(); + +// Returns a test GroupData with two group-by vectors. +GroupData CreateTestGroupData(); + +} // namespace secure_aggregation::willow + +#endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_