diff --git a/cli/src/tasks.rs b/cli/src/tasks.rs
index 7fdd9467..05c07c9f 100644
--- a/cli/src/tasks.rs
+++ b/cli/src/tasks.rs
@@ -1,6 +1,6 @@
-use crate::{CliResult, DetermineAccountId, Output};
+use crate::{CliResult, DetermineAccountId, Error, Output};
 use clap::Subcommand;
-use divviup_client::{DivviupClient, NewTask, Uuid, Vdaf};
+use divviup_client::{DivviupClient, Histogram, NewTask, Uuid, Vdaf};
 use humantime::{Duration, Timestamp};
 use time::{OffsetDateTime, UtcOffset};
 
@@ -38,8 +38,10 @@ pub enum TaskAction {
         time_precision: Duration,
         #[arg(long)]
         hpke_config_id: Uuid,
-        #[arg(long, required_if_eq("vdaf", "histogram"), value_delimiter = ',')]
-        buckets: Option<Vec<u64>>,
+        #[arg(long, value_delimiter = ',')]
+        categorical_buckets: Option<Vec<String>>,
+        #[arg(long, value_delimiter = ',')]
+        continuous_buckets: Option<Vec<u64>>,
         #[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec")]))]
         length: Option<u64>,
         #[arg(long, required_if_eq_any([("vdaf", "sum"), ("vdaf", "sum_vec")]))]
@@ -73,16 +75,33 @@ impl TaskAction {
                 max_batch_size,
                 expiration,
                 hpke_config_id,
-                buckets,
+                categorical_buckets,
+                continuous_buckets,
                 length,
                 bits,
                 time_precision,
             } => {
                 let vdaf = match vdaf {
                     VdafName::Count => Vdaf::Count,
-                    VdafName::Histogram => Vdaf::Histogram {
-                        buckets: buckets.unwrap(),
-                    },
+                    VdafName::Histogram => {
+                        match (length, categorical_buckets, continuous_buckets) {
+                            (Some(length), None, None) => {
+                                Vdaf::Histogram(Histogram::Length { length })
+                            }
+                            (None, Some(buckets), None) => {
+                                Vdaf::Histogram(Histogram::Categorical { buckets })
+                            }
+                            (None, None, Some(buckets)) => {
+                                Vdaf::Histogram(Histogram::Continuous { buckets })
+                            }
+                            (None, None, None) => {
+                                return Err(Error::Other("one of continuous-buckets, categorical-buckets, or length is required for the histogram vdaf".into()));
+                            }
+                            _ => {
+                                return Err(Error::Other("continuous-buckets, categorical-buckets, and length are mutually exclusive".into()));
+                            }
+                        }
+                    }
                     VdafName::Sum => Vdaf::Sum {
                         bits: bits.unwrap(),
                     },
diff --git a/client/src/aggregator.rs b/client/src/aggregator.rs
index 98cdf9d0..0478311d 100644
--- a/client/src/aggregator.rs
+++ b/client/src/aggregator.rs
@@ -1,3 +1,4 @@
+use crate::Protocol;
 use serde::{Deserialize, Serialize};
 use time::OffsetDateTime;
 use url::Url;
@@ -29,6 +30,7 @@ pub struct Aggregator {
     pub is_first_party: bool,
     pub vdafs: Vec<String>,
     pub query_types: Vec<String>,
+    pub protocol: Protocol,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
diff --git a/client/src/lib.rs b/client/src/lib.rs
index 6217b1b5..2bf2056e 100644
--- a/client/src/lib.rs
+++ b/client/src/lib.rs
@@ -12,6 +12,7 @@ mod aggregator;
 mod api_token;
 mod hpke_configs;
 mod membership;
+mod protocol;
 mod task;
 mod validation_errors;
 
@@ -34,7 +35,8 @@ pub use janus_messages::{
     HpkeConfig as HpkeConfigContents, HpkePublicKey,
 };
 pub use membership::Membership;
-pub use task::{NewTask, Task, Vdaf};
+pub use protocol::Protocol;
+pub use task::{Histogram, NewTask, Task, Vdaf};
 pub use time::OffsetDateTime;
 pub use trillium_client;
 pub use trillium_client::Client;
diff --git a/client/src/protocol.rs b/client/src/protocol.rs
new file mode 100644
index 00000000..1a758f85
--- /dev/null
+++ b/client/src/protocol.rs
@@ -0,0 +1,42 @@
+use serde::{Deserialize, Serialize};
+use std::{
+    error::Error,
+    fmt::{self, Display, Formatter},
+    str::FromStr,
+};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum Protocol {
+    #[serde(rename = "DAP-04")]
+    Dap04,
+    #[serde(rename = "DAP-05")]
+    Dap05,
+}
+
+impl AsRef<str> for Protocol {
+    fn as_ref(&self) -> &str {
+        match self {
+            Self::Dap04 => "DAP-04",
+            Self::Dap05 => "DAP-05",
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct UnrecognizedProtocol(String);
+impl Display for UnrecognizedProtocol {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_fmt(format_args!("{} was not a recognized protocol", self.0))
+    }
+}
+impl Error for UnrecognizedProtocol {}
+impl FromStr for Protocol {
+    type Err = UnrecognizedProtocol;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match &*s.to_lowercase() {
+            "dap-04" => Ok(Self::Dap04),
+            "dap-05" => Ok(Self::Dap05),
+            unrecognized => Err(UnrecognizedProtocol(unrecognized.to_string())),
+        }
+    }
+}
diff --git a/client/src/task.rs b/client/src/task.rs
index f1fef7eb..b2111622 100644
--- a/client/src/task.rs
+++ b/client/src/task.rs
@@ -45,7 +45,7 @@ pub enum Vdaf {
     Count,
 
     #[serde(rename = "histogram")]
-    Histogram { buckets: Vec<u64> },
+    Histogram(Histogram),
 
     #[serde(rename = "sum")]
     Sum { bits: u8 },
@@ -56,3 +56,11 @@ pub enum Vdaf {
     #[serde(rename = "sum_vec")]
     SumVec { bits: u8, length: u64 },
 }
+
+#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum Histogram {
+    Categorical { buckets: Vec<String> },
+    Continuous { buckets: Vec<u64> },
+    Length { length: u64 },
+}
diff --git a/migration/src/lib.rs b/migration/src/lib.rs
index 1652de0b..96095f25 100644
--- a/migration/src/lib.rs
+++ b/migration/src/lib.rs
@@ -17,6 +17,7 @@
 mod m20230703_201332_add_additional_fields_to_api_tokens;
 mod m20230725_220134_add_vdafs_and_query_types_to_aggregators;
 mod m20230731_181722_rename_aggregator_bearer_token;
 mod m20230808_204859_create_hpke_config;
+mod m20230817_192017_add_protocol_to_aggregators;
 
 pub struct Migrator;
 
@@ -41,6 +42,7 @@ impl MigratorTrait for Migrator {
             Box::new(m20230725_220134_add_vdafs_and_query_types_to_aggregators::Migration),
             Box::new(m20230731_181722_rename_aggregator_bearer_token::Migration),
             Box::new(m20230808_204859_create_hpke_config::Migration),
+            Box::new(m20230817_192017_add_protocol_to_aggregators::Migration),
         ]
     }
 }
diff --git a/migration/src/m20230817_192017_add_protocol_to_aggregators.rs b/migration/src/m20230817_192017_add_protocol_to_aggregators.rs
new file mode 100644
index 00000000..47017789
--- /dev/null
+++ b/migration/src/m20230817_192017_add_protocol_to_aggregators.rs
@@ -0,0 +1,52 @@
+use sea_orm_migration::prelude::*;
+
+#[derive(DeriveMigrationName)]
+pub struct Migration;
+
+#[async_trait::async_trait]
+impl MigrationTrait for Migration {
+    async fn up(&self, db: &SchemaManager) -> Result<(), DbErr> {
+        db.alter_table(
+            TableAlterStatement::new()
+                .table(Aggregator::Table)
+                .add_column(ColumnDef::new(Aggregator::Protocol).string().null())
+                .to_owned(),
+        )
+        .await?;
+
+        db.exec_stmt(
+            Query::update()
+                .table(Aggregator::Table)
+                .value(Aggregator::Protocol, "DAP-04")
+                .to_owned(),
+        )
+        .await?;
+
+        db.alter_table(
+            TableAlterStatement::new()
+                .table(Aggregator::Table)
+                .modify_column(ColumnDef::new(Aggregator::Protocol).not_null())
+                .to_owned(),
+        )
+        .await?;
+
+        Ok(())
+    }
+
+    async fn down(&self, db: &SchemaManager) -> Result<(), DbErr> {
+        db.alter_table(
+            TableAlterStatement::new()
+                .table(Aggregator::Table)
+                .drop_column(Aggregator::Protocol)
+                .to_owned(),
+        )
+        .await?;
+        Ok(())
+    }
+}
+
+#[derive(DeriveIden)]
+enum Aggregator {
+    Table,
+    Protocol,
+}
diff --git a/src/api_mocks/aggregator_api.rs b/src/api_mocks/aggregator_api.rs
index 01148528..383099c5 100644
--- a/src/api_mocks/aggregator_api.rs
+++ b/src/api_mocks/aggregator_api.rs
@@ -1,8 +1,8 @@
 use super::random_chars;
 use crate::clients::aggregator_client::api_types::{
-    AggregatorApiConfig, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId,
-    HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, TaskMetrics,
-    TaskResponse, VdafInstance,
+    AggregatorApiConfig, AggregatorVdaf, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId,
+    HpkeKemId, HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds,
+    TaskMetrics, TaskResponse,
 };
 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
 use querystrong::QueryStrong;
@@ -30,6 +30,7 @@ pub fn mock() -> impl Handler {
                 role: random(),
                 vdafs: Default::default(),
                 query_types: Default::default(),
+                protocol: random(),
             }),
         )
         .post("/tasks", api(post_task))
@@ -70,7 +71,7 @@ async fn get_task(conn: &mut Conn, (): ()) -> Json<TaskResponse> {
         task_id: task_id.parse().unwrap(),
         peer_aggregator_endpoint: "https://_".parse().unwrap(),
         query_type: QueryType::TimeInterval,
-        vdaf: VdafInstance::Prio3Count,
+        vdaf: AggregatorVdaf::Prio3Count,
         role: Role::Leader,
         vdaf_verify_key: random_chars(10),
         max_batch_query_count: 100,
diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs
index 1cb79ea6..92d8fc63 100644
--- a/src/clients/aggregator_client/api_types.rs
+++ b/src/clients/aggregator_client/api_types.rs
@@ -1,8 +1,8 @@
 use crate::{
     entity::{
         aggregator::{QueryTypeName, QueryTypeNameSet, Role as AggregatorRole, VdafNameSet},
-        task::vdaf::{CountVec, Histogram, Sum, SumVec, Vdaf},
-        Aggregator, ProvisionableTask,
+        task::vdaf::{BucketLength, ContinuousBuckets, CountVec, Histogram, Sum, SumVec, Vdaf},
+        Aggregator, Protocol, ProvisionableTask,
     },
     handler::Error,
 };
@@ -16,28 +16,80 @@ pub use janus_messages::{
     HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime,
 };
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
 #[non_exhaustive]
-pub enum VdafInstance {
+pub enum AggregatorVdaf {
     Prio3Count,
     Prio3Sum { bits: u8 },
-    Prio3Histogram { buckets: Vec<u64> },
+    Prio3Histogram(HistogramType),
     Prio3CountVec { length: u64 },
     Prio3SumVec { bits: u8, length: u64 },
 }
 
-impl From<VdafInstance> for Vdaf {
-    fn from(value: VdafInstance) -> Self {
+impl PartialEq<Vdaf> for AggregatorVdaf {
+    fn eq(&self, other: &Vdaf) -> bool {
+        other.eq(self)
+    }
+}
+
+impl PartialEq<AggregatorVdaf> for Vdaf {
+    fn eq(&self, other: &AggregatorVdaf) -> bool {
+        match (self, other) {
+            (Vdaf::Count, AggregatorVdaf::Prio3Count) => true,
+            (
+                Vdaf::Histogram(histogram),
+                AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }),
+            ) => histogram.length() == *length,
+            (
+                Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { buckets: Some(lhs) })),
+                AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets: rhs }),
+            ) => lhs == rhs,
+            (Vdaf::Sum(Sum { bits: Some(lhs) }), AggregatorVdaf::Prio3Sum { bits: rhs }) => {
+                lhs == rhs
+            }
+            (
+                Vdaf::CountVec(CountVec { length: Some(lhs) }),
+                AggregatorVdaf::Prio3CountVec { length: rhs },
+            ) => lhs == rhs,
+            (
+                Vdaf::SumVec(SumVec {
+                    bits: Some(lhs_bits),
+                    length: Some(lhs_length),
+                }),
+                AggregatorVdaf::Prio3SumVec {
+                    bits: rhs_bits,
+                    length: rhs_length,
+                },
+            ) => lhs_bits == rhs_bits && lhs_length == rhs_length,
+            _ => false,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum HistogramType {
+    Opaque { length: u64 },
+    Buckets { buckets: Vec<u64> },
+}
+
+impl From<AggregatorVdaf> for Vdaf {
+    fn from(value: AggregatorVdaf) -> Self {
         match value {
-            VdafInstance::Prio3Count => Self::Count,
-            VdafInstance::Prio3Sum { bits } => Self::Sum(Sum { bits: Some(bits) }),
-            VdafInstance::Prio3Histogram { buckets } => Self::Histogram(Histogram {
-                buckets: Some(buckets),
-            }),
-            VdafInstance::Prio3CountVec { length } => Self::CountVec(CountVec {
+            AggregatorVdaf::Prio3Count => Self::Count,
+            AggregatorVdaf::Prio3Sum { bits } => Self::Sum(Sum { bits: Some(bits) }),
+            AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets }) => {
+                Self::Histogram(Histogram::Continuous(ContinuousBuckets {
+                    buckets: Some(buckets),
+                }))
+            }
+            AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }) => {
+                Self::Histogram(Histogram::Opaque(BucketLength { length }))
+            }
+            AggregatorVdaf::Prio3CountVec { length } => Self::CountVec(CountVec {
                 length: Some(length),
             }),
-            VdafInstance::Prio3SumVec { bits, length } => Self::SumVec(SumVec {
+            AggregatorVdaf::Prio3SumVec { bits, length } => Self::SumVec(SumVec {
                 length: Some(length),
                 bits: Some(bits),
             }),
@@ -45,28 +97,6 @@ impl From<VdafInstance> for Vdaf {
     }
 }
 
-impl From<Vdaf> for VdafInstance {
-    fn from(value: Vdaf) -> Self {
-        match value {
-            Vdaf::Count => Self::Prio3Count,
-            Vdaf::Histogram(Histogram { buckets }) => Self::Prio3Histogram {
-                buckets: buckets.unwrap(),
-            },
-            Vdaf::Sum(Sum { bits }) => Self::Prio3Sum {
-                bits: bits.unwrap(),
-            },
-            Vdaf::CountVec(CountVec { length }) => Self::Prio3CountVec {
-                length: length.unwrap(),
-            },
-            Vdaf::SumVec(SumVec { length, bits }) => Self::Prio3SumVec {
-                bits: bits.unwrap(),
-                length: length.unwrap(),
-            },
-            Vdaf::Unrecognized => unreachable!(),
-        }
-    }
-}
-
 #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub enum QueryType {
     TimeInterval,
@@ -142,7 +172,7 @@ pub struct TaskCreate {
     pub collector_auth_token: Option<AuthenticationToken>,
     pub peer_aggregator_endpoint: Url,
     pub query_type: QueryType,
-    pub vdaf: VdafInstance,
+    pub vdaf: AggregatorVdaf,
     pub role: Role,
     pub max_batch_query_count: u64,
     pub task_expiration: Option<JanusTime>,
@@ -169,7 +199,7 @@ impl TaskCreate {
                 new_task.leader_aggregator.dap_url.clone().into()
             },
             query_type: new_task.max_batch_size.into(),
-            vdaf: new_task.vdaf.clone().into(),
+            vdaf: new_task.aggregator_vdaf.clone(),
             role,
             max_batch_query_count: 1,
             task_expiration: new_task.expiration.map(|expiration| {
@@ -193,7 +223,7 @@ pub struct TaskResponse {
     pub task_id: TaskId,
     pub peer_aggregator_endpoint: Url,
     pub query_type: QueryType,
-    pub vdaf: VdafInstance,
+    pub vdaf: AggregatorVdaf,
     pub role: Role,
     pub vdaf_verify_key: String,
     pub max_batch_query_count: u64,
@@ -236,6 +266,8 @@ pub struct AggregatorApiConfig {
     pub role: AggregatorRole,
     pub vdafs: VdafNameSet,
     pub query_types: QueryTypeNameSet,
+    #[serde(default)]
+    pub protocol: Protocol,
 }
 
 #[cfg(test)]
diff --git a/src/entity.rs b/src/entity.rs
index 50734803..8e780793 100644
--- a/src/entity.rs
+++ b/src/entity.rs
@@ -16,7 +16,7 @@ pub use account::{
 };
 pub use aggregator::{
     Column as AggregatorColumn, Entity as Aggregators, Model as Aggregator, NewAggregator,
-    UpdateAggregator,
+    Protocol, Role, UnrecognizedProtocol, UnrecognizedRole, UpdateAggregator,
 };
 pub use api_token::{
     Column as ApiTokenColumn, Entity as ApiTokens, Model as ApiToken, UpdateApiToken,
diff --git a/src/entity/aggregator.rs b/src/entity/aggregator.rs
index c8b7da13..7ac6acbe 100644
--- a/src/entity/aggregator.rs
+++ b/src/entity/aggregator.rs
@@ -1,4 +1,5 @@
 mod new_aggregator;
+mod protocol;
 mod query_type_name;
 mod role;
 mod update_aggregator;
@@ -16,6 +17,7 @@ use time::OffsetDateTime;
 use uuid::Uuid;
 
 pub use new_aggregator::NewAggregator;
+pub use protocol::{Protocol, UnrecognizedProtocol};
 pub use query_type_name::{QueryTypeName, QueryTypeNameSet};
 pub use role::{Role, UnrecognizedRole};
 pub use update_aggregator::UpdateAggregator;
@@ -42,6 +44,7 @@ pub struct Model {
     pub is_first_party: bool,
     pub query_types: Json,
     pub vdafs: Json,
+    pub protocol: Protocol,
     #[serde(skip)]
     pub encrypted_bearer_token: Vec<u8>,
 }
diff --git a/src/entity/aggregator/new_aggregator.rs b/src/entity/aggregator/new_aggregator.rs
index dc748474..1025d697 100644
--- a/src/entity/aggregator/new_aggregator.rs
+++ b/src/entity/aggregator/new_aggregator.rs
@@ -102,6 +102,7 @@ impl NewAggregator {
             is_first_party: account.is_none() && self.is_first_party.unwrap_or(true),
             query_types: aggregator_config.query_types.into(),
             vdafs: aggregator_config.vdafs.into(),
+            protocol: aggregator_config.protocol,
         }
         .into_active_model())
     }
diff --git a/src/entity/aggregator/protocol.rs b/src/entity/aggregator/protocol.rs
new file mode 100644
index 00000000..6ba01b07
--- /dev/null
+++ b/src/entity/aggregator/protocol.rs
@@ -0,0 +1,67 @@
+use rand::{distributions::Standard, prelude::Distribution, Rng};
+use sea_orm::{DeriveActiveEnum, EnumIter};
+use serde::{Deserialize, Serialize};
+use std::{
+    error::Error,
+    fmt::{self, Display, Formatter},
+    str::FromStr,
+};
+
+#[derive(
+    Debug, Clone, Copy, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize, Default,
+)]
+#[sea_orm(rs_type = "String", db_type = "String(None)")]
+pub enum Protocol {
+    #[sea_orm(string_value = "DAP-04")]
+    #[serde(rename = "DAP-04")]
+    #[default]
+    Dap04,
+
+    #[sea_orm(string_value = "DAP-05")]
+    #[serde(rename = "DAP-05")]
+    Dap05,
+}
+
+impl Distribution<Protocol> for Standard {
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Protocol {
+        if rng.gen() {
+            Protocol::Dap04
+        } else {
+            Protocol::Dap05
+        }
+    }
+}
+
+impl Display for Protocol {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_ref())
+    }
+}
+
+impl AsRef<str> for Protocol {
+    fn as_ref(&self) -> &str {
+        match self {
+            Self::Dap04 => "DAP-04",
+            Self::Dap05 => "DAP-05",
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct UnrecognizedProtocol(String);
+impl Display for UnrecognizedProtocol {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.write_fmt(format_args!("{} was not a recognized protocol", self.0))
+    }
+}
+impl Error for UnrecognizedProtocol {}
+impl FromStr for Protocol {
+    type Err = UnrecognizedProtocol;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match &*s.to_lowercase() {
+            "dap-04" => Ok(Self::Dap04),
+            "dap-05" => Ok(Self::Dap05),
+            unrecognized => Err(UnrecognizedProtocol(unrecognized.to_string())),
+        }
+    }
+}
diff --git a/src/entity/task/new_task.rs b/src/entity/task/new_task.rs
index 96f52a66..4288a2e8 100644
--- a/src/entity/task/new_task.rs
+++ b/src/entity/task/new_task.rs
@@ -1,7 +1,9 @@
 use super::*;
 use crate::{
-    clients::aggregator_client::api_types::QueryType,
-    entity::{aggregator::Role, Account, Aggregator, Aggregators, HpkeConfig, HpkeConfigColumn},
+    clients::aggregator_client::api_types::{AggregatorVdaf, QueryType},
+    entity::{
+        aggregator::Role, Account, Aggregator, Aggregators, HpkeConfig, HpkeConfigColumn, Protocol,
+    },
     handler::Error,
 };
 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
@@ -131,7 +133,7 @@ impl NewTask {
         account: &Account,
         db: &impl ConnectionTrait,
         errors: &mut ValidationErrors,
-    ) -> Option<(Aggregator, Aggregator)> {
+    ) -> Option<(Aggregator, Aggregator, Protocol)> {
         let leader = load_aggregator(account, self.leader_aggregator_id.as_deref(), db)
             .await
             .ok()
             .flatten()
@@ -168,6 +170,14 @@
             );
         }
 
+        let resolved_protocol = if leader.protocol == helper.protocol {
+            leader.protocol
+        } else {
+            errors.add("leader_aggregator_id", ValidationError::new("protocol"));
+            errors.add("helper_aggregator_id", ValidationError::new("protocol"));
+            return None;
+        };
+
         if leader.role == Role::Helper {
             errors.add("leader_aggregator_id", ValidationError::new("role"))
         }
@@ -177,7 +187,7 @@
         }
 
         if errors.is_empty() {
-            Some((leader, helper))
+            Some((leader, helper, resolved_protocol))
         } else {
             None
         }
@@ -187,22 +197,48 @@
         &self,
         leader: &Aggregator,
         helper: &Aggregator,
+        protocol: &Protocol,
         errors: &mut ValidationErrors,
-    ) {
+    ) -> Option<AggregatorVdaf> {
         let Some(vdaf) = self.vdaf.as_ref() else {
-            return;
+            return None;
         };
+
         let name = vdaf.name();
-        if leader.vdafs.contains(&name) && helper.vdafs.contains(&name) {
-            return;
-        }
-        if let ValidationErrorsKind::Struct(errors) = errors
-            .errors_mut()
-            .entry("vdaf")
-            .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new())))
-        {
-            errors.add("type", ValidationError::new("not-supported"));
+        let aggregator_vdaf = match vdaf.representation_for_protocol(protocol) {
+            Ok(vdaf) => vdaf,
+            Err(e) => {
+                let errors = errors.errors_mut().entry("vdaf").or_insert_with(|| {
+                    ValidationErrorsKind::Struct(Box::new(ValidationErrors::new()))
+                });
+                match errors {
+                    ValidationErrorsKind::Struct(errors) => {
+                        errors.errors_mut().extend(e.into_errors())
+                    }
+                    other => *other = ValidationErrorsKind::Struct(Box::new(e)),
+                };
+                return None;
+            }
+        };
+
+        if !leader.vdafs.contains(&name) || !helper.vdafs.contains(&name) {
+            let errors = errors
+                .errors_mut()
+                .entry("vdaf")
+                .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new())));
+            match errors {
+                ValidationErrorsKind::Struct(errors) => {
+                    errors.add("type", ValidationError::new("not-supported"));
+                }
+                other => {
+                    let mut e = ValidationErrors::new();
+                    e.add("type", ValidationError::new("not-supported"));
+                    *other = ValidationErrorsKind::Struct(Box::new(e));
+                }
+            };
         }
+
+        Some(aggregator_vdaf)
     }
 
     fn validate_query_type_is_supported(
@@ -226,10 +262,13 @@
         self.validate_min_lte_max(&mut errors);
         let hpke_config = self.validate_hpke_config(&account, db, &mut errors).await;
         let aggregators = self.validate_aggregators(&account, db, &mut errors).await;
-        if let Some((leader, helper)) = aggregators.as_ref() {
-            self.validate_vdaf_is_supported(leader, helper, &mut errors);
+
+        let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() {
             self.validate_query_type_is_supported(leader, helper, &mut errors);
-        }
+            self.validate_vdaf_is_supported(leader, helper, protocol, &mut errors)
+        } else {
+            None
+        };
 
         if errors.is_empty() {
             // Unwrap safety: All of these unwraps below have previously
             // been checked by the validations above. This is a
             // disharmonious combination of Validate and the fact that we
             // need to use options for all fields so serde doesn't bail on
             // the first error.
-            let (leader_aggregator, helper_aggregator) = aggregators.unwrap();
+            let (leader_aggregator, helper_aggregator, protocol) = aggregators.unwrap();
 
             let (vdaf_verify_key, id) = generate_vdaf_verify_key_and_expected_task_id();
 
@@ -250,12 +289,14 @@
                 leader_aggregator,
                 helper_aggregator,
                 vdaf: self.vdaf.clone().unwrap(),
+                aggregator_vdaf: aggregator_vdaf.unwrap(),
                 min_batch_size: self.min_batch_size.unwrap(),
                 max_batch_size: self.max_batch_size,
                 expiration: self.expiration,
                 time_precision_seconds: self.time_precision_seconds.unwrap(),
                 hpke_config: hpke_config.unwrap(),
                 aggregator_auth_token: None,
+                protocol,
             })
         } else {
             Err(errors)
diff --git a/src/entity/task/provisionable_task.rs b/src/entity/task/provisionable_task.rs
index c9584152..301f72bb 100644
--- a/src/entity/task/provisionable_task.rs
+++ b/src/entity/task/provisionable_task.rs
@@ -1,7 +1,7 @@
 use super::{ActiveModel, *};
 use crate::{
-    clients::aggregator_client::api_types::AuthenticationToken,
-    entity::{Account, Aggregator, HpkeConfig, Task},
+    clients::aggregator_client::api_types::{AggregatorVdaf, AuthenticationToken},
+    entity::{Account, Aggregator, HpkeConfig, Protocol, Task},
     handler::Error,
     Crypter,
 };
@@ -24,12 +24,14 @@ pub struct ProvisionableTask {
     pub leader_aggregator: Aggregator,
     pub helper_aggregator: Aggregator,
     pub vdaf: Vdaf,
+    pub aggregator_vdaf: AggregatorVdaf,
     pub min_batch_size: u64,
     pub max_batch_size: Option<u64>,
     pub expiration: Option<OffsetDateTime>,
     pub time_precision_seconds: u64,
     pub hpke_config: HpkeConfig,
     pub aggregator_auth_token: Option<AuthenticationToken>,
+    pub protocol: Protocol,
 }
 
 fn assert_same<T: PartialEq + Debug>(
@@ -56,7 +58,7 @@ impl ProvisionableTask {
             .create_task(self)
             .await?;
 
-        assert_same(&self.vdaf, &response.vdaf.clone().into(), "vdaf")?;
+        assert_same(&self.aggregator_vdaf, &response.vdaf, "vdaf")?;
         assert_same(
             self.min_batch_size,
             response.min_batch_size,
diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs
index 6d884791..a27f1aed 100644
--- a/src/entity/task/vdaf.rs
+++ b/src/entity/task/vdaf.rs
@@ -1,31 +1,106 @@
-use crate::entity::aggregator::VdafName;
+use crate::{
+    clients::aggregator_client::api_types::{AggregatorVdaf, HistogramType},
+    entity::{aggregator::VdafName, Protocol},
+};
 use serde::{Deserialize, Serialize};
+use std::{collections::HashSet, hash::Hash};
 use validator::{Validate, ValidationError, ValidationErrors};
 
-#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
-pub struct Histogram {
-    #[validate(required, custom = "strictly_increasing")]
-    pub buckets: Option<Vec<u64>>,
+#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum Histogram {
+    Opaque(BucketLength),
+    Categorical(CategoricalBuckets),
+    Continuous(ContinuousBuckets),
 }
 
-fn strictly_increasing(buckets: &Vec<u64>) -> Result<(), ValidationError> {
-    let mut last_bucket = None;
-    for bucket in buckets {
-        let bucket = *bucket;
-        match last_bucket {
-            Some(last_bucket) if last_bucket == bucket => {
-                return Err(ValidationError::new("unique"));
+impl Histogram {
+    pub fn length(&self) -> u64 {
+        match self {
+            Histogram::Categorical(CategoricalBuckets {
+                buckets: Some(buckets),
+            }) => buckets.len() as u64,
+            Histogram::Continuous(ContinuousBuckets {
+                buckets: Some(buckets),
+            }) => buckets.len() as u64 + 1,
+            Histogram::Opaque(BucketLength { length }) => *length,
+            _ => 0,
+        }
+    }
+
+    fn representation_for_protocol(
+        &self,
+        protocol: &Protocol,
+    ) -> Result<AggregatorVdaf, ValidationErrors> {
+        match (protocol, self) {
+            (Protocol::Dap05, histogram) => {
+                Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Opaque {
+                    length: histogram.length(),
+                }))
             }
-            Some(last_bucket) if last_bucket > bucket => {
-                return Err(ValidationError::new("sorted"));
+            (
+                Protocol::Dap04,
+                Self::Continuous(ContinuousBuckets {
+                    buckets: Some(buckets),
+                }),
+            ) => Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Buckets {
+                buckets: buckets.clone(),
+            })),
+
+            (Protocol::Dap04, Self::Categorical(_)) => {
+                let mut errors = ValidationErrors::new();
+                errors.add("buckets", ValidationError::new("must-be-numbers"));
+                Err(errors)
             }
-            _ => {
-                last_bucket = Some(bucket);
+            (Protocol::Dap04, _) => {
+                let mut errors = ValidationErrors::new();
+                errors.add("buckets", ValidationError::new("required"));
+                Err(errors)
             }
         }
     }
+}
+
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
+pub struct ContinuousBuckets {
+    #[validate(required, length(min = 1), custom = "increasing", custom = "unique")]
+    pub buckets: Option<Vec<u64>>,
+}
+
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
+pub struct CategoricalBuckets {
+    #[validate(required, length(min = 1), custom = "unique")]
+    pub buckets: Option<Vec<String>>,
+}
+
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq, Copy)]
+pub struct BucketLength {
+    #[validate(range(min = 1))]
+    pub length: u64,
+}
+
+fn unique<T: Hash + Eq>(buckets: &[T]) -> Result<(), ValidationError> {
+    if buckets.len() == buckets.iter().collect::<HashSet<_>>().len() {
+        Ok(())
+    } else {
+        Err(ValidationError::new("unique"))
+    }
+}
+
+fn increasing(buckets: &[u64]) -> Result<(), ValidationError> {
+    let Some(mut last) = buckets.first().copied() else {
+        return Ok(());
+    };
+
+    for bucket in &buckets[1..] {
+        if *bucket >= last {
+            last = *bucket;
+        } else {
+            return Err(ValidationError::new("sorted"));
+        }
+    }
     Ok(())
 }
@@ -83,13 +158,37 @@ impl Vdaf {
             Vdaf::Unrecognized => VdafName::Other("unsupported".into()),
         }
     }
+
+    pub fn representation_for_protocol(
+        &self,
+        protocol: &Protocol,
+    ) -> Result<AggregatorVdaf, ValidationErrors> {
+        match self {
+            Self::Histogram(histogram) => histogram.representation_for_protocol(protocol),
+            Self::Count => Ok(AggregatorVdaf::Prio3Count),
+            Self::Sum(Sum { bits: Some(bits) }) => Ok(AggregatorVdaf::Prio3Sum { bits: *bits }),
+            Self::SumVec(SumVec {
+                length: Some(length),
+                bits: Some(bits),
+            }) => Ok(AggregatorVdaf::Prio3SumVec {
+                bits: *bits,
+                length: *length,
+            }),
+            Self::CountVec(CountVec {
+                length: Some(length),
+            }) => Ok(AggregatorVdaf::Prio3CountVec { length: *length }),
+            _ => Err(ValidationErrors::new()),
+        }
+    }
 }
 
 impl Validate for Vdaf {
     fn validate(&self) -> Result<(), ValidationErrors> {
         match self {
             Vdaf::Count => Ok(()),
-            Vdaf::Histogram(h) => h.validate(),
+            Vdaf::Histogram(Histogram::Continuous(buckets)) => buckets.validate(),
+            Vdaf::Histogram(Histogram::Categorical(buckets)) => buckets.validate(),
+            Vdaf::Histogram(Histogram::Opaque(length)) => length.validate(),
             Vdaf::Sum(s) => s.validate(),
             Vdaf::SumVec(sv) => sv.validate(),
             Vdaf::CountVec(cv) => cv.validate(),
@@ -108,15 +207,15 @@ mod tests {
     use crate::test::assert_errors;
 
     #[test]
-    fn validate_histogram() {
-        assert!(Histogram {
+    fn validate_continuous_histogram() {
+        assert!(ContinuousBuckets {
             buckets: Some(vec![0, 1, 2])
         }
         .validate()
         .is_ok());
 
         assert_errors(
-            Histogram {
+            ContinuousBuckets {
                 buckets: Some(vec![0, 2, 1]),
             },
             "buckets",
             &["sorted"],
         );
 
         assert_errors(
-            Histogram {
+            ContinuousBuckets {
                 buckets: Some(vec![0, 0, 2]),
             },
             "buckets",
             &["unique"],
         );
     }
+
+    #[test]
+    fn validate_categorical_histogram() {
+        assert!(CategoricalBuckets {
"b".into()]) + } + .validate() + .is_ok()); + + assert_errors( + CategoricalBuckets { + buckets: Some(vec!["a".into(), "a".into()]), + }, + "buckets", + &["unique"], + ); + } } diff --git a/test-support/src/fixtures.rs b/test-support/src/fixtures.rs index 233228e1..550eb3e3 100644 --- a/test-support/src/fixtures.rs +++ b/test-support/src/fixtures.rs @@ -160,6 +160,7 @@ pub async fn aggregator(app: &DivviupApi, account: Option<&Account>) -> Aggregat role: Role::Either, query_types: Default::default(), vdafs: Default::default(), + protocol: Protocol::Dap05, } .into_active_model() .insert(app.db()) diff --git a/tests/vdaf.rs b/tests/vdaf.rs new file mode 100644 index 00000000..112ea093 --- /dev/null +++ b/tests/vdaf.rs @@ -0,0 +1,48 @@ +use divviup_api::entity::task::vdaf::Vdaf; +use test_support::{assert_eq, test, *}; +#[test] +pub fn histogram_representations() { + let scenarios = [ + ( + json!({"type": "histogram", "buckets": ["a", "b", "c"]}), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params": {}}]})), + ), + ( + json!({"type": "histogram", "buckets": ["a", "b", "c"]}), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length": 3}})), + ), + ( + json!({"type": "histogram", "buckets": [1, 2, 3]}), + Protocol::Dap04, + Ok(json!({"Prio3Histogram": {"buckets": [1, 2, 3]}})), + ), + ( + json!({"type": "histogram", "buckets": [1, 2, 3]}), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length": 4}})), + ), + ( + json!({"type": "histogram", "length": 3}), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "required", "message": null, "params":{}}]})), + ), + ( + json!({"type": "histogram", "length": 3}), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length": 3}})), + ), + ]; + + for (input, protocol, output) in scenarios { + let vdaf: Vdaf = serde_json::from_value(input.clone()).unwrap(); + assert_eq!( + output, + vdaf.representation_for_protocol(&protocol) + .map(|o| serde_json::to_value(o).unwrap()) + .map_err(|e| serde_json::to_value(e).unwrap()), + "{vdaf:?} {input} {protocol}" + ); + } +}