Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,13 @@ cc_test(
srcs = ["client_test.cc"],
deps = [
":client_cc",
":client_cxx",
"@googletest//:gtest_main",
"@abseil-cpp//absl/status",
"@cxx.rs//:core",
"//ffi_utils:status_matchers",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:input_spec_cc_proto",
"//willow/src/input_encoding:codec",
"//willow/src/testing_utils:shell_testing_decryptor_cc",
"//willow/src/testing_utils:testing_utils_cc",
],
)
50 changes: 7 additions & 43 deletions willow/src/api/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "ffi_utils/status_matchers.h"
Expand All @@ -29,6 +28,7 @@
#include "willow/src/input_encoding/codec.h"
#include "willow/src/input_encoding/codec_factory.h"
#include "willow/src/testing_utils/shell_testing_decryptor.h"
#include "willow/src/testing_utils/testing_utils.h"

namespace secure_aggregation {
namespace willow {
Expand All @@ -41,49 +41,13 @@ using ::testing::ElementsAreArray;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;

AggregationConfigProto CreateTestConfig() {
AggregationConfigProto config;
VectorConfig vector_config;
vector_config.set_length(8); // 4 countries x 2 languages
vector_config.set_bound(100);
(*config.mutable_vector_configs())["metric1"] = vector_config;
config.set_max_number_of_decryptors(1);
config.set_max_number_of_clients(10);
config.set_key_id("test");
return config;
}

TEST(WillowShellClientTest, InitializeAndGenerateContribution) {
AggregationConfigProto config = CreateTestConfig();
AggregationConfigProto config = CreateTestAggregationConfigProto();

// Create and encode input.
MetricData metric_data;
metric_data["metric1"] = {10, 20, 5};
GroupData group_by_data;
group_by_data["country"] = {"US", "CA", "US"};
group_by_data["lang"] = {"en", "es", "es"};
InputSpec input_spec;
auto* metric_spec = input_spec.add_metric_vector_specs();
metric_spec->set_vector_name("metric1");
metric_spec->set_data_type(InputSpec::INT64);
auto* group_by_spec1 = input_spec.add_group_by_vector_specs();
group_by_spec1->set_vector_name("country");
group_by_spec1->set_data_type(InputSpec::STRING);
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"CA");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"GB");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"MX");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"US");
auto* group_by_spec2 = input_spec.add_group_by_vector_specs();
group_by_spec2->set_vector_name("lang");
group_by_spec2->set_data_type(InputSpec::STRING);
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"en");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"es");
MetricData metric_data = CreateTestMetricData();
GroupData group_by_data = CreateTestGroupData();
InputSpec input_spec = CreateTestInputSpecProto();
SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Codec> encoder,
CodecFactory::CreateExplicitCodec(input_spec));
SECAGG_ASSERT_OK_AND_ASSIGN(EncodedData encoded_data,
Expand Down Expand Up @@ -129,7 +93,7 @@ TEST(WillowShellClientTest, InitializeAndGenerateContribution) {
}

TEST(WillowShellClientTest, EmptyEncodedData) {
AggregationConfigProto config = CreateTestConfig();
AggregationConfigProto config = CreateTestAggregationConfigProto();

// Create empty encoded data.
EncodedData encoded_data;
Expand All @@ -151,7 +115,7 @@ TEST(WillowShellClientTest, EmptyEncodedData) {

TEST(WillowShellClientTest, InvalidAggregationConfig) {
// Originally valid config.
AggregationConfigProto config = CreateTestConfig();
AggregationConfigProto config = CreateTestAggregationConfigProto();

// Create encoded data directly.
EncodedData encoded_data = {{"metric1", {0, 20, 0, 0, 0, 0, 10, 5}}};
Expand Down
1 change: 1 addition & 0 deletions willow/src/input_encoding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ cc_test(
"@abseil-cpp//absl/status",
"//ffi_utils:status_matchers",
"//willow/proto/willow:input_spec_cc_proto",
"//willow/src/testing_utils:testing_utils_cc",
],
)
91 changes: 10 additions & 81 deletions willow/src/input_encoding/explicit_codec_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "willow/proto/willow/input_spec.pb.h"
#include "willow/src/input_encoding/codec.h"
#include "willow/src/input_encoding/codec_factory.h"
#include "willow/src/testing_utils/testing_utils.h"

namespace secure_aggregation {
namespace willow {
Expand Down Expand Up @@ -298,33 +299,9 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) {
}

TEST(CodecFactoryTest, EncodeSimpleGroupBy) {
MetricData metric_data;
metric_data["metric1"] = {10, 20, 5};
GroupData group_by_data;
group_by_data["country"] = {"US", "CA", "US"};
group_by_data["lang"] = {"en", "es", "es"};
InputSpec input_spec;
auto* metric_spec = input_spec.add_metric_vector_specs();
metric_spec->set_vector_name("metric1");
metric_spec->set_data_type(InputSpec::INT64);
auto* group_by_spec1 = input_spec.add_group_by_vector_specs();
group_by_spec1->set_vector_name("country");
group_by_spec1->set_data_type(InputSpec::STRING);
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"CA");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"GB");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"MX");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"US");
auto* group_by_spec2 = input_spec.add_group_by_vector_specs();
group_by_spec2->set_vector_name("lang");
group_by_spec2->set_data_type(InputSpec::STRING);
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"en");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"es");
InputSpec input_spec = CreateTestInputSpecProto();
MetricData metric_data = CreateTestMetricData();
GroupData group_by_data = CreateTestGroupData();

// group_by keys are sorted: "country", "lang"
// value_to_index_maps["country"]: {"CA":0, "GB":1, "MX":2, "US":3}
Expand Down Expand Up @@ -402,33 +379,9 @@ TEST(CodecFactoryTest, EncodeTwoMetricsOneGroupBy) {
}

TEST(CodecFactoryTest, EncodeThenDecode) {
MetricData metric_data;
metric_data["metric1"] = {10, 20, 5};
GroupData group_by_data;
group_by_data["country"] = {"US", "CA", "US"};
group_by_data["lang"] = {"en", "es", "es"};
InputSpec input_spec;
auto* metric_spec = input_spec.add_metric_vector_specs();
metric_spec->set_vector_name("metric1");
metric_spec->set_data_type(InputSpec::INT64);
auto* group_by_spec1 = input_spec.add_group_by_vector_specs();
group_by_spec1->set_vector_name("country");
group_by_spec1->set_data_type(InputSpec::STRING);
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"CA");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"GB");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"MX");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"US");
auto* group_by_spec2 = input_spec.add_group_by_vector_specs();
group_by_spec2->set_vector_name("lang");
group_by_spec2->set_data_type(InputSpec::STRING);
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"en");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"es");
InputSpec input_spec = CreateTestInputSpecProto();
MetricData metric_data = CreateTestMetricData();
GroupData group_by_data = CreateTestGroupData();

SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Codec> encoder,
CodecFactory::CreateExplicitCodec(input_spec));
Expand Down Expand Up @@ -457,33 +410,9 @@ TEST(CodecFactoryTest, EncodeThenDecode) {
}

TEST(CodecFactoryTest, EncodeThenDecodeDataOrderDoesNotMatter) {
MetricData metric_data1;
metric_data1["metric1"] = {10, 20, 5};
GroupData group_by_data1;
group_by_data1["lang"] = {"en", "es", "es"};
group_by_data1["country"] = {"US", "CA", "US"};
InputSpec input_spec;
auto* metric_spec = input_spec.add_metric_vector_specs();
metric_spec->set_vector_name("metric1");
metric_spec->set_data_type(InputSpec::INT64);
auto* group_by_spec1 = input_spec.add_group_by_vector_specs();
group_by_spec1->set_vector_name("lang");
group_by_spec1->set_data_type(InputSpec::STRING);
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"en");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"es");
auto* group_by_spec2 = input_spec.add_group_by_vector_specs();
group_by_spec2->set_vector_name("country");
group_by_spec2->set_data_type(InputSpec::STRING);
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"CA");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"GB");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"MX");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"US");
InputSpec input_spec = CreateTestInputSpecProto();
MetricData metric_data1 = CreateTestMetricData();
GroupData group_by_data1 = CreateTestGroupData();

SECAGG_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Codec> encoder1,
CodecFactory::CreateExplicitCodec(input_spec));
Expand Down
13 changes: 13 additions & 0 deletions willow/src/testing_utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ cc_library(
name = "shell_testing_decryptor_cc",
srcs = ["shell_testing_decryptor.cc"],
hdrs = ["shell_testing_decryptor.h"],
visibility = ["//visibility:public"],
deps = [
":shell_testing_decryptor_cxx",
"@abseil-cpp//absl/memory",
Expand All @@ -157,3 +158,15 @@ cc_test(
"//willow/proto/willow:aggregation_config_cc_proto",
],
)

cc_library(
name = "testing_utils_cc",
srcs = ["testing_utils.cc"],
hdrs = ["testing_utils.h"],
visibility = ["//visibility:public"],
deps = [
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:input_spec_cc_proto",
"//willow/src/input_encoding:codec",
],
)
74 changes: 74 additions & 0 deletions willow/src/testing_utils/testing_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "willow/src/testing_utils/testing_utils.h"

#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation::willow {

AggregationConfigProto CreateTestAggregationConfigProto() {
AggregationConfigProto config;
VectorConfig vector_config;
vector_config.set_length(8); // 4 countries x 2 languages
vector_config.set_bound(100);
(*config.mutable_vector_configs())["metric1"] = vector_config;
config.set_max_number_of_decryptors(1);
config.set_max_number_of_clients(10);
config.set_key_id("test");
return config;
}

InputSpec CreateTestInputSpecProto() {
InputSpec input_spec;
auto* metric_spec = input_spec.add_metric_vector_specs();
metric_spec->set_vector_name("metric1");
metric_spec->set_data_type(InputSpec::INT64);
metric_spec->mutable_domain_spec()->mutable_interval()->set_min(0);
metric_spec->mutable_domain_spec()->mutable_interval()->set_max(100);
auto* group_by_spec1 = input_spec.add_group_by_vector_specs();
group_by_spec1->set_vector_name("country");
group_by_spec1->set_data_type(InputSpec::STRING);
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"CA");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"GB");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"MX");
group_by_spec1->mutable_domain_spec()->mutable_string_values()->add_values(
"US");
auto* group_by_spec2 = input_spec.add_group_by_vector_specs();
group_by_spec2->set_vector_name("lang");
group_by_spec2->set_data_type(InputSpec::STRING);
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"en");
group_by_spec2->mutable_domain_spec()->mutable_string_values()->add_values(
"es");
return input_spec;
}

MetricData CreateTestMetricData() {
MetricData metric_data;
metric_data["metric1"] = {10, 20, 5};
return metric_data;
}

GroupData CreateTestGroupData() {
GroupData group_by_data;
group_by_data["country"] = {"US", "CA", "US"};
group_by_data["lang"] = {"en", "es", "es"};
return group_by_data;
}

} // namespace secure_aggregation::willow
38 changes: 38 additions & 0 deletions willow/src/testing_utils/testing_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_
#define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_

#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/input_spec.pb.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation::willow {

// Returns a test AggregationConfigProto.
AggregationConfigProto CreateTestAggregationConfigProto();

// Returns a test InputSpec proto with one metric and two group-by vectors.
InputSpec CreateTestInputSpecProto();

// Returns a test MetricData with one metric.
MetricData CreateTestMetricData();

// Returns a test GroupData with two group-by vectors.
GroupData CreateTestGroupData();

} // namespace secure_aggregation::willow

#endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_TESTING_UTILS_H_