From 3b702a35a641e32aab69af1ebfb00e693ab001fd Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Thu, 17 Aug 2023 19:00:01 -0700 Subject: [PATCH 1/5] protocol-determined vdaf representation --- cli/src/tasks.rs | 37 +++- client/src/aggregator.rs | 2 + client/src/lib.rs | 4 +- client/src/protocol.rs | 42 +++++ client/src/task.rs | 12 +- migration/src/lib.rs | 2 + ...d_dap_versions_to_tasks_and_aggregators.rs | 89 ++++++++++ src/api_mocks/aggregator_api.rs | 9 +- src/clients/aggregator_client/api_types.rs | 110 +++++++----- src/entity.rs | 2 + src/entity/aggregator.rs | 2 + src/entity/aggregator/new_aggregator.rs | 1 + src/entity/protocol.rs | 67 ++++++++ src/entity/task.rs | 2 + src/entity/task/new_task.rs | 75 +++++++-- src/entity/task/provisionable_task.rs | 7 +- src/entity/task/vdaf.rs | 158 +++++++++++++++--- test-support/src/fixtures.rs | 2 + tests/vdaf.rs | 48 ++++++ 19 files changed, 576 insertions(+), 95 deletions(-) create mode 100644 client/src/protocol.rs create mode 100644 migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs create mode 100644 src/entity/protocol.rs create mode 100644 tests/vdaf.rs diff --git a/cli/src/tasks.rs b/cli/src/tasks.rs index 7fdd94675..54fdfc261 100644 --- a/cli/src/tasks.rs +++ b/cli/src/tasks.rs @@ -1,6 +1,6 @@ -use crate::{CliResult, DetermineAccountId, Output}; +use crate::{CliResult, DetermineAccountId, Error, Output}; use clap::Subcommand; -use divviup_client::{DivviupClient, NewTask, Uuid, Vdaf}; +use divviup_client::{DivviupClient, Histogram, NewTask, Uuid, Vdaf}; use humantime::{Duration, Timestamp}; use time::{OffsetDateTime, UtcOffset}; @@ -38,9 +38,11 @@ pub enum TaskAction { time_precision: Duration, #[arg(long)] hpke_config_id: Uuid, - #[arg(long, required_if_eq("vdaf", "histogram"), value_delimiter = ',')] - buckets: Option>, - #[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec")]))] + #[arg(long, value_delimiter = ',')] + categorical_buckets: Option>, + #[arg(long, value_delimiter = ',')] + continuous_buckets: Option>, + #[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec"), ("vdaf", "histogram")]))] length: Option, #[arg(long, required_if_eq_any([("vdaf", "sum"), ("vdaf", "sum_vec")]))] bits: Option, @@ -73,16 +75,33 @@ impl TaskAction { max_batch_size, expiration, hpke_config_id, - buckets, + categorical_buckets, + continuous_buckets, length, bits, time_precision, } => { let vdaf = match vdaf { VdafName::Count => Vdaf::Count, - VdafName::Histogram => Vdaf::Histogram { - buckets: buckets.unwrap(), - }, + VdafName::Histogram => { + match (length, categorical_buckets, continuous_buckets) { + (Some(length), None, None) => { + Vdaf::Histogram(Histogram::Length { length }) + } + (None, Some(buckets), None) => { + Vdaf::Histogram(Histogram::Categorical { buckets }) + } + (None, None, Some(buckets)) => { + Vdaf::Histogram(Histogram::Continuous { buckets }) + } + (None, None, None) => { + return Err(Error::Other("continuous-buckets, categorical-buckets, or length are required for histogram vdaf".into())); + } + _ => { + return Err(Error::Other("continuous-buckets, categorical-buckets, and length mutually exclusive".into())); + } + } + } VdafName::Sum => Vdaf::Sum { bits: bits.unwrap(), }, diff --git a/client/src/aggregator.rs b/client/src/aggregator.rs index 98cdf9d04..0478311d4 100644 --- a/client/src/aggregator.rs +++ b/client/src/aggregator.rs @@ -1,3 +1,4 @@ +use crate::Protocol; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use url::Url; @@ -29,6 +30,7 @@ pub struct Aggregator { pub is_first_party: bool, pub vdafs: Vec, pub query_types: Vec, + pub protocol: Protocol, } #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] diff --git a/client/src/lib.rs b/client/src/lib.rs index 6217b1b52..2bf2056ed 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -12,6 +12,7 @@ mod aggregator; mod api_token; mod hpke_configs; mod membership; +mod protocol; mod task; mod validation_errors; @@ -34,7 +35,8 @@ pub use janus_messages::{ HpkeConfig as HpkeConfigContents, HpkePublicKey, }; pub use membership::Membership; -pub use task::{NewTask, Task, Vdaf}; +pub use protocol::Protocol; +pub use task::{Histogram, NewTask, Task, Vdaf}; pub use time::OffsetDateTime; pub use trillium_client; pub use trillium_client::Client; diff --git a/client/src/protocol.rs b/client/src/protocol.rs new file mode 100644 index 000000000..1a758f851 --- /dev/null +++ b/client/src/protocol.rs @@ -0,0 +1,42 @@ +use serde::{Deserialize, Serialize}; +use std::{ + error::Error, + fmt::{self, Display, Formatter}, + str::FromStr, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum Protocol { + #[serde(rename = "DAP-04")] + Dap04, + #[serde(rename = "DAP-05")] + Dap05, +} + +impl AsRef for Protocol { + fn as_ref(&self) -> &str { + match self { + Self::Dap04 => "DAP-04", + Self::Dap05 => "DAP-05", + } + } +} + +#[derive(Debug)] +pub struct UnrecognizedProtocol(String); +impl Display for UnrecognizedProtocol { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{} was not a recognized protocol", self.0)) + } +} +impl Error for UnrecognizedProtocol {} +impl FromStr for Protocol { + type Err = UnrecognizedProtocol; + fn from_str(s: &str) -> Result { + match &*s.to_lowercase() { + "dap-04" => Ok(Self::Dap04), + "dap-05" => Ok(Self::Dap05), + unrecognized => Err(UnrecognizedProtocol(unrecognized.to_string())), + } + } +} diff --git a/client/src/task.rs b/client/src/task.rs index f1fef7eb1..6598ca7e6 100644 --- a/client/src/task.rs +++ b/client/src/task.rs @@ -1,3 +1,4 @@ +use crate::Protocol; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; @@ -22,6 +23,7 @@ pub struct Task { pub leader_aggregator_id: Uuid, pub helper_aggregator_id: Uuid, pub hpke_config_id: Uuid, + pub protocol: Protocol, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] @@ -45,7 +47,7 @@ pub enum Vdaf { Count, #[serde(rename = "histogram")] - Histogram { buckets: Vec }, + Histogram(Histogram), #[serde(rename = "sum")] Sum { bits: u8 }, @@ -56,3 +58,11 @@ pub enum Vdaf { #[serde(rename = "sum_vec")] SumVec { bits: u8, length: u64 }, } + +#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] +#[serde(untagged)] +pub enum Histogram { + Categorical { buckets: Vec }, + Continuous { buckets: Vec }, + Length { length: u64 }, +} diff --git a/migration/src/lib.rs b/migration/src/lib.rs index 1652de0b4..01953bc86 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -17,6 +17,7 @@ mod m20230703_201332_add_additional_fields_to_api_tokens; mod m20230725_220134_add_vdafs_and_query_types_to_aggregators; mod m20230731_181722_rename_aggregator_bearer_token; mod m20230808_204859_create_hpke_config; +mod m20230817_192017_add_dap_versions_to_tasks_and_aggregators; pub struct Migrator; @@ -41,6 +42,7 @@ impl MigratorTrait for Migrator { Box::new(m20230725_220134_add_vdafs_and_query_types_to_aggregators::Migration), Box::new(m20230731_181722_rename_aggregator_bearer_token::Migration), Box::new(m20230808_204859_create_hpke_config::Migration), + Box::new(m20230817_192017_add_dap_versions_to_tasks_and_aggregators::Migration), ] } } diff --git a/migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs b/migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs new file mode 100644 index 000000000..b49294d59 --- /dev/null +++ b/migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs @@ -0,0 +1,89 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, db: &SchemaManager) -> Result<(), DbErr> { + db.alter_table( + TableAlterStatement::new() + .table(Aggregator::Table) + .add_column(ColumnDef::new(Aggregator::Protocol).string().null()) + .to_owned(), + ) + .await?; + + db.exec_stmt( + Query::update() + .table(Aggregator::Table) + .value(Aggregator::Protocol, "DAP-04") + .to_owned(), + ) + .await?; + + db.alter_table( + TableAlterStatement::new() + .table(Aggregator::Table) + .modify_column(ColumnDef::new(Aggregator::Protocol).not_null()) + .to_owned(), + ) + .await?; + + db.alter_table( + TableAlterStatement::new() + .table(Task::Table) + .add_column(ColumnDef::new(Task::Protocol).string().null()) + .to_owned(), + ) + .await?; + + db.exec_stmt( + Query::update() + .table(Task::Table) + .value(Task::Protocol, "DAP-04") + .to_owned(), + ) + .await?; + + db.alter_table( + TableAlterStatement::new() + .table(Task::Table) + .modify_column(ColumnDef::new(Task::Protocol).not_null()) + .to_owned(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, db: &SchemaManager) -> Result<(), DbErr> { + db.alter_table( + TableAlterStatement::new() + .table(Aggregator::Table) + .drop_column(Aggregator::Protocol) + .to_owned(), + ) + .await?; + db.alter_table( + TableAlterStatement::new() + .table(Task::Table) + .drop_column(Task::Protocol) + .to_owned(), + ) + .await?; + Ok(()) + } +} + +#[derive(DeriveIden)] +enum Aggregator { + Table, + Protocol, +} + +#[derive(DeriveIden)] +enum Task { + Table, + Protocol, +} diff --git a/src/api_mocks/aggregator_api.rs b/src/api_mocks/aggregator_api.rs index 011485282..383099c50 100644 --- a/src/api_mocks/aggregator_api.rs +++ b/src/api_mocks/aggregator_api.rs @@ -1,8 +1,8 @@ use super::random_chars; use crate::clients::aggregator_client::api_types::{ - AggregatorApiConfig, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, - HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, TaskMetrics, - TaskResponse, VdafInstance, + AggregatorApiConfig, AggregatorVdaf, AuthenticationToken, HpkeAeadId, HpkeConfig, HpkeKdfId, + HpkeKemId, HpkePublicKey, JanusDuration, QueryType, Role, TaskCreate, TaskId, TaskIds, + TaskMetrics, TaskResponse, }; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use querystrong::QueryStrong; @@ -30,6 +30,7 @@ pub fn mock() -> impl Handler { role: random(), vdafs: Default::default(), query_types: Default::default(), + protocol: random(), }), ) .post("/tasks", api(post_task)) @@ -70,7 +71,7 @@ async fn get_task(conn: &mut Conn, (): ()) -> Json { task_id: task_id.parse().unwrap(), peer_aggregator_endpoint: "https://_".parse().unwrap(), query_type: QueryType::TimeInterval, - vdaf: VdafInstance::Prio3Count, + vdaf: AggregatorVdaf::Prio3Count, role: Role::Leader, vdaf_verify_key: random_chars(10), max_batch_query_count: 100, diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs index 1cb79ea63..92d8fc63a 100644 --- a/src/clients/aggregator_client/api_types.rs +++ b/src/clients/aggregator_client/api_types.rs @@ -1,8 +1,8 @@ use crate::{ entity::{ aggregator::{QueryTypeName, QueryTypeNameSet, Role as AggregatorRole, VdafNameSet}, - task::vdaf::{CountVec, Histogram, Sum, SumVec, Vdaf}, - Aggregator, ProvisionableTask, + task::vdaf::{BucketLength, ContinuousBuckets, CountVec, Histogram, Sum, SumVec, Vdaf}, + Aggregator, Protocol, ProvisionableTask, }, handler::Error, }; @@ -16,28 +16,80 @@ pub use janus_messages::{ HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime, }; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[non_exhaustive] -pub enum VdafInstance { +pub enum AggregatorVdaf { Prio3Count, Prio3Sum { bits: u8 }, - Prio3Histogram { buckets: Vec }, + Prio3Histogram(HistogramType), Prio3CountVec { length: u64 }, Prio3SumVec { bits: u8, length: u64 }, } -impl From for Vdaf { - fn from(value: VdafInstance) -> Self { +impl PartialEq for AggregatorVdaf { + fn eq(&self, other: &Vdaf) -> bool { + other.eq(self) + } +} + +impl PartialEq for Vdaf { + fn eq(&self, other: &AggregatorVdaf) -> bool { + match (self, other) { + (Vdaf::Count, AggregatorVdaf::Prio3Count) => true, + ( + Vdaf::Histogram(histogram), + AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }), + ) => histogram.length() == *length, + ( + Vdaf::Histogram(Histogram::Continuous(ContinuousBuckets { buckets: Some(lhs) })), + AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets: rhs }), + ) => lhs == rhs, + (Vdaf::Sum(Sum { bits: Some(lhs) }), AggregatorVdaf::Prio3Sum { bits: rhs }) => { + lhs == rhs + } + ( + Vdaf::CountVec(CountVec { length: Some(lhs) }), + AggregatorVdaf::Prio3CountVec { length: rhs }, + ) => lhs == rhs, + ( + Vdaf::SumVec(SumVec { + bits: Some(lhs_bits), + length: Some(lhs_length), + }), + AggregatorVdaf::Prio3SumVec { + bits: rhs_bits, + length: rhs_length, + }, + ) => lhs_bits == rhs_bits && lhs_length == rhs_length, + _ => false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(untagged)] +pub enum HistogramType { + Opaque { length: u64 }, + Buckets { buckets: Vec }, +} + +impl From for Vdaf { + fn from(value: AggregatorVdaf) -> Self { match value { - VdafInstance::Prio3Count => Self::Count, - VdafInstance::Prio3Sum { bits } => Self::Sum(Sum { bits: Some(bits) }), - VdafInstance::Prio3Histogram { buckets } => Self::Histogram(Histogram { - buckets: Some(buckets), - }), - VdafInstance::Prio3CountVec { length } => Self::CountVec(CountVec { + AggregatorVdaf::Prio3Count => Self::Count, + AggregatorVdaf::Prio3Sum { bits } => Self::Sum(Sum { bits: Some(bits) }), + AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { buckets }) => { + Self::Histogram(Histogram::Continuous(ContinuousBuckets { + buckets: Some(buckets), + })) + } + AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { length }) => { + Self::Histogram(Histogram::Opaque(BucketLength { length })) + } + AggregatorVdaf::Prio3CountVec { length } => Self::CountVec(CountVec { length: Some(length), }), - VdafInstance::Prio3SumVec { bits, length } => Self::SumVec(SumVec { + AggregatorVdaf::Prio3SumVec { bits, length } => Self::SumVec(SumVec { length: Some(length), bits: Some(bits), }), @@ -45,28 +97,6 @@ impl From for Vdaf { } } -impl From for VdafInstance { - fn from(value: Vdaf) -> Self { - match value { - Vdaf::Count => Self::Prio3Count, - Vdaf::Histogram(Histogram { buckets }) => Self::Prio3Histogram { - buckets: buckets.unwrap(), - }, - Vdaf::Sum(Sum { bits }) => Self::Prio3Sum { - bits: bits.unwrap(), - }, - Vdaf::CountVec(CountVec { length }) => Self::Prio3CountVec { - length: length.unwrap(), - }, - Vdaf::SumVec(SumVec { length, bits }) => Self::Prio3SumVec { - bits: bits.unwrap(), - length: length.unwrap(), - }, - Vdaf::Unrecognized => unreachable!(), - } - } -} - #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum QueryType { TimeInterval, @@ -142,7 +172,7 @@ pub struct TaskCreate { pub collector_auth_token: Option, pub peer_aggregator_endpoint: Url, pub query_type: QueryType, - pub vdaf: VdafInstance, + pub vdaf: AggregatorVdaf, pub role: Role, pub max_batch_query_count: u64, pub task_expiration: Option, @@ -169,7 +199,7 @@ impl TaskCreate { new_task.leader_aggregator.dap_url.clone().into() }, query_type: new_task.max_batch_size.into(), - vdaf: new_task.vdaf.clone().into(), + vdaf: new_task.aggregator_vdaf.clone(), role, max_batch_query_count: 1, task_expiration: new_task.expiration.map(|expiration| { @@ -193,7 +223,7 @@ pub struct TaskResponse { pub task_id: TaskId, pub peer_aggregator_endpoint: Url, pub query_type: QueryType, - pub vdaf: VdafInstance, + pub vdaf: AggregatorVdaf, pub role: Role, pub vdaf_verify_key: String, pub max_batch_query_count: u64, @@ -236,6 +266,8 @@ pub struct AggregatorApiConfig { pub role: AggregatorRole, pub vdafs: VdafNameSet, pub query_types: QueryTypeNameSet, + #[serde(default)] + pub protocol: Protocol, } #[cfg(test)] diff --git a/src/entity.rs b/src/entity.rs index 507348031..68219cdc9 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -5,6 +5,7 @@ mod codec; pub mod hpke_config; mod json; pub mod membership; +pub mod protocol; pub mod queue; pub mod session; pub mod task; @@ -28,6 +29,7 @@ pub use hpke_config::{ pub use membership::{ Column as MembershipColumn, CreateMembership, Entity as Memberships, Model as Membership, }; +pub use protocol::{Protocol, UnrecognizedProtocol}; pub use session::{Column as SessionColumn, Entity as Sessions, Model as Session}; pub use task::{ Column as TaskColumn, Entity as Tasks, Model as Task, NewTask, ProvisionableTask, UpdateTask, diff --git a/src/entity/aggregator.rs b/src/entity/aggregator.rs index c8b7da136..9efabf857 100644 --- a/src/entity/aggregator.rs +++ b/src/entity/aggregator.rs @@ -15,6 +15,7 @@ use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; +pub use super::protocol::{Protocol, UnrecognizedProtocol}; pub use new_aggregator::NewAggregator; pub use query_type_name::{QueryTypeName, QueryTypeNameSet}; pub use role::{Role, UnrecognizedRole}; @@ -42,6 +43,7 @@ pub struct Model { pub is_first_party: bool, pub query_types: Json, pub vdafs: Json, + pub protocol: Protocol, #[serde(skip)] pub encrypted_bearer_token: Vec, } diff --git a/src/entity/aggregator/new_aggregator.rs b/src/entity/aggregator/new_aggregator.rs index dc7484744..1025d697c 100644 --- a/src/entity/aggregator/new_aggregator.rs +++ b/src/entity/aggregator/new_aggregator.rs @@ -102,6 +102,7 @@ impl NewAggregator { is_first_party: account.is_none() && self.is_first_party.unwrap_or(true), query_types: aggregator_config.query_types.into(), vdafs: aggregator_config.vdafs.into(), + protocol: aggregator_config.protocol, } .into_active_model()) } diff --git a/src/entity/protocol.rs b/src/entity/protocol.rs new file mode 100644 index 000000000..6ba01b078 --- /dev/null +++ b/src/entity/protocol.rs @@ -0,0 +1,67 @@ +use rand::{distributions::Standard, prelude::Distribution}; +use sea_orm::{DeriveActiveEnum, EnumIter}; +use serde::{Deserialize, Serialize}; +use std::{ + error::Error, + fmt::{self, Display, Formatter}, + str::FromStr, +}; + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize, Default, +)] +#[sea_orm(rs_type = "String", db_type = "String(None)")] +pub enum Protocol { + #[sea_orm(string_value = "DAP-04")] + #[serde(rename = "DAP-04")] + #[default] + Dap04, + + #[sea_orm(string_value = "DAP-05")] + #[serde(rename = "DAP-05")] + Dap05, +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> Protocol { + if rng.gen() { + Protocol::Dap04 + } else { + Protocol::Dap05 + } + } +} + +impl Display for Protocol { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl AsRef for Protocol { + fn as_ref(&self) -> &str { + match self { + Self::Dap04 => "DAP-04", + Self::Dap05 => "DAP-05", + } + } +} + +#[derive(Debug)] +pub struct UnrecognizedProtocol(String); +impl Display for UnrecognizedProtocol { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{} was not a recognized protocol", self.0)) + } +} +impl Error for UnrecognizedProtocol {} +impl FromStr for Protocol { + type Err = UnrecognizedProtocol; + fn from_str(s: &str) -> Result { + match &*s.to_lowercase() { + "dap-04" => Ok(Self::Dap04), + "dap-05" => Ok(Self::Dap05), + unrecognized => Err(UnrecognizedProtocol(unrecognized.to_string())), + } + } +} diff --git a/src/entity/task.rs b/src/entity/task.rs index ef22094a9..fa4d988ae 100644 --- a/src/entity/task.rs +++ b/src/entity/task.rs @@ -22,6 +22,7 @@ pub use new_task::NewTask; mod update_task; pub use update_task::UpdateTask; mod provisionable_task; +use super::Protocol; pub use provisionable_task::{ProvisionableTask, TaskProvisioningError}; use super::json::Json; @@ -48,6 +49,7 @@ pub struct Model { pub leader_aggregator_id: Uuid, pub helper_aggregator_id: Uuid, pub hpke_config_id: Uuid, + pub protocol: Protocol, } impl Model { diff --git a/src/entity/task/new_task.rs b/src/entity/task/new_task.rs index 96f52a667..6103b4fd5 100644 --- a/src/entity/task/new_task.rs +++ b/src/entity/task/new_task.rs @@ -1,6 +1,6 @@ use super::*; use crate::{ - clients::aggregator_client::api_types::QueryType, + clients::aggregator_client::api_types::{AggregatorVdaf, QueryType}, entity::{aggregator::Role, Account, Aggregator, Aggregators, HpkeConfig, HpkeConfigColumn}, handler::Error, }; @@ -131,7 +131,7 @@ impl NewTask { account: &Account, db: &impl ConnectionTrait, errors: &mut ValidationErrors, - ) -> Option<(Aggregator, Aggregator)> { + ) -> Option<(Aggregator, Aggregator, Protocol)> { let leader = load_aggregator(account, self.leader_aggregator_id.as_deref(), db) .await .ok() @@ -168,6 +168,14 @@ impl NewTask { ); } + let resolved_protocol = if leader.protocol == helper.protocol { + leader.protocol + } else { + errors.add("leader_aggregator_id", ValidationError::new("protocol")); + errors.add("helper_aggregator_id", ValidationError::new("protocol")); + return None; + }; + if leader.role == Role::Helper { errors.add("leader_aggregator_id", ValidationError::new("role")) } @@ -177,7 +185,7 @@ impl NewTask { } if errors.is_empty() { - Some((leader, helper)) + Some((leader, helper, resolved_protocol)) } else { None } @@ -187,22 +195,48 @@ impl NewTask { &self, leader: &Aggregator, helper: &Aggregator, + protocol: &Protocol, errors: &mut ValidationErrors, - ) { + ) -> Option { let Some(vdaf) = self.vdaf.as_ref() else { - return; + return None; }; + let name = vdaf.name(); - if leader.vdafs.contains(&name) && helper.vdafs.contains(&name) { - return; - } - if let ValidationErrorsKind::Struct(errors) = errors - .errors_mut() - .entry("vdaf") - .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new()))) - { - errors.add("type", ValidationError::new("not-supported")); + let aggregator_vdaf = match vdaf.representation_for_protocol(protocol) { + Ok(vdaf) => vdaf, + Err(e) => { + let errors = errors.errors_mut().entry("vdaf").or_insert_with(|| { + ValidationErrorsKind::Struct(Box::new(ValidationErrors::new())) + }); + match errors { + ValidationErrorsKind::Struct(errors) => { + errors.errors_mut().extend(e.into_errors()) + } + other => *other = ValidationErrorsKind::Struct(Box::new(e)), + }; + return None; + } + }; + + if !leader.vdafs.contains(&name) || !helper.vdafs.contains(&name) { + let errors = errors + .errors_mut() + .entry("vdaf") + .or_insert_with(|| ValidationErrorsKind::Struct(Box::new(ValidationErrors::new()))); + match errors { + ValidationErrorsKind::Struct(errors) => { + errors.add("type", ValidationError::new("not-supported")); + } + other => { + let mut e = ValidationErrors::new(); + e.add("type", ValidationError::new("not-supported")); + *other = ValidationErrorsKind::Struct(Box::new(e)); + } + }; } + + Some(aggregator_vdaf) } fn validate_query_type_is_supported( @@ -226,10 +260,13 @@ impl NewTask { self.validate_min_lte_max(&mut errors); let hpke_config = self.validate_hpke_config(&account, db, &mut errors).await; let aggregators = self.validate_aggregators(&account, db, &mut errors).await; - if let Some((leader, helper)) = aggregators.as_ref() { - self.validate_vdaf_is_supported(leader, helper, &mut errors); + + let aggregator_vdaf = if let Some((leader, helper, protocol)) = aggregators.as_ref() { self.validate_query_type_is_supported(leader, helper, &mut errors); - } + self.validate_vdaf_is_supported(leader, helper, protocol, &mut errors) + } else { + None + }; if errors.is_empty() { // Unwrap safety: All of these unwraps below have previously @@ -238,7 +275,7 @@ impl NewTask { // disharmonious combination of Validate and the fact that we // need to use options for all fields so serde doesn't bail on // the first error. - let (leader_aggregator, helper_aggregator) = aggregators.unwrap(); + let (leader_aggregator, helper_aggregator, protocol) = aggregators.unwrap(); let (vdaf_verify_key, id) = generate_vdaf_verify_key_and_expected_task_id(); @@ -250,12 +287,14 @@ impl NewTask { leader_aggregator, helper_aggregator, vdaf: self.vdaf.clone().unwrap(), + aggregator_vdaf: aggregator_vdaf.unwrap(), min_batch_size: self.min_batch_size.unwrap(), max_batch_size: self.max_batch_size, expiration: self.expiration, time_precision_seconds: self.time_precision_seconds.unwrap(), hpke_config: hpke_config.unwrap(), aggregator_auth_token: None, + protocol, }) } else { Err(errors) diff --git a/src/entity/task/provisionable_task.rs b/src/entity/task/provisionable_task.rs index c9584152d..91292d4eb 100644 --- a/src/entity/task/provisionable_task.rs +++ b/src/entity/task/provisionable_task.rs @@ -1,6 +1,6 @@ use super::{ActiveModel, *}; use crate::{ - clients::aggregator_client::api_types::AuthenticationToken, + clients::aggregator_client::api_types::{AggregatorVdaf, AuthenticationToken}, entity::{Account, Aggregator, HpkeConfig, Task}, handler::Error, Crypter, @@ -24,12 +24,14 @@ pub struct ProvisionableTask { pub leader_aggregator: Aggregator, pub helper_aggregator: Aggregator, pub vdaf: Vdaf, + pub aggregator_vdaf: AggregatorVdaf, pub min_batch_size: u64, pub max_batch_size: Option, pub expiration: Option, pub time_precision_seconds: u64, pub hpke_config: HpkeConfig, pub aggregator_auth_token: Option, + pub protocol: Protocol, } fn assert_same( @@ -56,7 +58,7 @@ impl ProvisionableTask { .create_task(self) .await?; - assert_same(&self.vdaf, &response.vdaf.clone().into(), "vdaf")?; + assert_same(&self.aggregator_vdaf, &response.vdaf, "vdaf")?; assert_same( self.min_batch_size, response.min_batch_size, @@ -115,6 +117,7 @@ impl ProvisionableTask { leader_aggregator_id: self.leader_aggregator.id, helper_aggregator_id: self.helper_aggregator.id, hpke_config_id: self.hpke_config.id, + protocol: self.protocol, } .into_active_model()) } diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs index 6d8847919..fe612a2f8 100644 --- a/src/entity/task/vdaf.rs +++ b/src/entity/task/vdaf.rs @@ -1,31 +1,106 @@ -use crate::entity::aggregator::VdafName; +use crate::{ + clients::aggregator_client::api_types::{AggregatorVdaf, HistogramType}, + entity::{aggregator::VdafName, Protocol}, +}; use serde::{Deserialize, Serialize}; +use std::{collections::HashSet, hash::Hash}; use validator::{Validate, ValidationError, ValidationErrors}; -#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)] -pub struct Histogram { - #[validate(required, custom = "strictly_increasing")] - pub buckets: Option>, +#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] +#[serde(untagged)] +pub enum Histogram { + Opaque(BucketLength), + Categorical(CategoricalBuckets), + Continuous(ContinuousBuckets), } -fn strictly_increasing(buckets: &Vec) -> Result<(), ValidationError> { - let mut last_bucket = None; - for bucket in buckets { - let bucket = *bucket; - match last_bucket { - Some(last_bucket) if last_bucket == bucket => { - return Err(ValidationError::new("unique")); +impl Histogram { + pub fn length(&self) -> u64 { + match self { + Histogram::Categorical(CategoricalBuckets { + buckets: Some(buckets), + }) => buckets.len() as u64, + Histogram::Continuous(ContinuousBuckets { + buckets: Some(buckets), + }) => buckets.len() as u64, + Histogram::Opaque(BucketLength { length }) => *length, + _ => 0, + } + } + + fn representation_for_protocol( + &self, + protocol: &Protocol, + ) -> Result { + match (protocol, self) { + (Protocol::Dap05, histogram) => { + Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Opaque { + length: histogram.length(), + })) } - Some(last_bucket) if last_bucket > bucket => { - return Err(ValidationError::new("sorted")); + ( + Protocol::Dap04, + Self::Continuous(ContinuousBuckets { + buckets: Some(buckets), + }), + ) => Ok(AggregatorVdaf::Prio3Histogram(HistogramType::Buckets { + buckets: buckets.clone(), + })), + + (Protocol::Dap04, Self::Categorical(_)) => { + let mut errors = ValidationErrors::new(); + errors.add("buckets", ValidationError::new("must-be-numbers")); + Err(errors) } - _ => { - last_bucket = Some(bucket); + (Protocol::Dap04, _) => { + let mut errors = ValidationErrors::new(); + errors.add("buckets", ValidationError::new("required")); + Err(errors) } } } +} + +#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)] +pub struct ContinuousBuckets { + #[validate(required, length(min = 1), custom = "increasing", custom = "unique")] + pub buckets: Option>, +} + +#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)] +pub struct CategoricalBuckets { + #[validate(required, length(min = 1), custom = "unique")] + pub buckets: Option>, +} + +#[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq, Copy)] +pub struct BucketLength { + #[validate(range(min = 1))] + pub length: u64, +} + +fn unique(buckets: &[T]) -> Result<(), ValidationError> { + if buckets.len() == buckets.iter().collect::>().len() { + Ok(()) + } else { + Err(ValidationError::new("unique")) + } +} + +fn increasing(buckets: &[u64]) -> Result<(), ValidationError> { + let Some(mut last) = buckets.first().copied() else { + return Ok(()); + }; + + for bucket in &buckets[1..] { + if *bucket >= last { + last = *bucket; + } else { + return Err(ValidationError::new("sorted")); + } + } Ok(()) } @@ -83,13 +158,37 @@ impl Vdaf { Vdaf::Unrecognized => VdafName::Other("unsupported".into()), } } + + pub fn representation_for_protocol( + &self, + protocol: &Protocol, + ) -> Result { + match self { + Self::Histogram(histogram) => histogram.representation_for_protocol(protocol), + Self::Count => Ok(AggregatorVdaf::Prio3Count), + Self::Sum(Sum { bits: Some(bits) }) => Ok(AggregatorVdaf::Prio3Sum { bits: *bits }), + Self::SumVec(SumVec { + length: Some(length), + bits: Some(bits), + }) => Ok(AggregatorVdaf::Prio3SumVec { + bits: *bits, + length: *length, + }), + Self::CountVec(CountVec { + length: Some(length), + }) => Ok(AggregatorVdaf::Prio3CountVec { length: *length }), + _ => Err(ValidationErrors::new()), + } + } } impl Validate for Vdaf { fn validate(&self) -> Result<(), ValidationErrors> { match self { Vdaf::Count => Ok(()), - Vdaf::Histogram(h) => h.validate(), + Vdaf::Histogram(Histogram::Continuous(buckets)) => buckets.validate(), + Vdaf::Histogram(Histogram::Categorical(buckets)) => buckets.validate(), + Vdaf::Histogram(Histogram::Opaque(length)) => length.validate(), Vdaf::Sum(s) => s.validate(), Vdaf::SumVec(sv) => sv.validate(), Vdaf::CountVec(cv) => cv.validate(), @@ -108,15 +207,15 @@ mod tests { use crate::test::assert_errors; #[test] - fn validate_histogram() { - assert!(Histogram { + fn validate_continuous_histogram() { + assert!(ContinuousBuckets { buckets: Some(vec![0, 1, 2]) } .validate() .is_ok()); assert_errors( - Histogram { + ContinuousBuckets { buckets: Some(vec![0, 2, 1]), }, "buckets", @@ -124,11 +223,28 @@ mod tests { ); assert_errors( - Histogram { + ContinuousBuckets { buckets: Some(vec![0, 0, 2]), }, "buckets", &["unique"], ); } + + #[test] + fn validate_categorical_histogram() { + assert!(CategoricalBuckets { + buckets: Some(vec!["a".into(), "b".into()]) + } + .validate() + .is_ok()); + + assert_errors( + CategoricalBuckets { + buckets: Some(vec!["a".into(), "a".into()]), + }, + "buckets", + &["unique"], + ); + } } diff --git a/test-support/src/fixtures.rs b/test-support/src/fixtures.rs index 233228e1a..00b99f8f7 100644 --- a/test-support/src/fixtures.rs +++ b/test-support/src/fixtures.rs @@ -114,6 +114,7 @@ pub async fn task(app: &DivviupApi, account: &Account) -> Task { leader_aggregator_id: leader_aggregator.id, helper_aggregator_id: helper_aggregator.id, hpke_config_id: hpke_config.id, + protocol: Protocol::Dap05, } .into_active_model() .insert(app.db()) @@ -160,6 +161,7 @@ pub async fn aggregator(app: &DivviupApi, account: Option<&Account>) -> Aggregat role: Role::Either, query_types: Default::default(), vdafs: Default::default(), + protocol: Protocol::Dap05, } .into_active_model() .insert(app.db()) diff --git a/tests/vdaf.rs b/tests/vdaf.rs new file mode 100644 index 000000000..ce625588c --- /dev/null +++ b/tests/vdaf.rs @@ -0,0 +1,48 @@ +use divviup_api::entity::task::vdaf::Vdaf; +use test_support::{assert_eq, test, *}; +#[test] +pub fn histogram_representations() { + let scenarios = [ + ( + json!({"type": "histogram", "buckets": ["a", "b", "c"]}), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length": 3}})), + ), + ( + json!({"type": "histogram", "buckets": ["a", "b", "c"]}), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params":{}}]})), + ), + ( + json!({"type": "histogram", "buckets": [1,2,3]}), + Protocol::Dap04, + Ok(json!({ "Prio3Histogram":{"buckets": [1,2,3]}})), + ), + ( + json!({"type": "histogram", "buckets": [1,2,3]}), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length" : 3 }})), + ), + ( + json!({"type": "histogram", "length": 3 }), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "required", "message": null, "params":{}}]})), + ), + ( + json!({"type": "histogram", "length": 3 }), + Protocol::Dap05, + Ok(json!({"Prio3Histogram":{ "length" : 3} })), + ), + ]; + + for (input, protocol, output) in scenarios { + let vdaf: Vdaf = serde_json::from_value(input.clone()).unwrap(); + assert_eq!( + output, + vdaf.representation_for_protocol(&protocol) + .map(|o| serde_json::to_value(o).unwrap()) + .map_err(|e| serde_json::to_value(e).unwrap()), + "{vdaf:?} {input} {protocol}" + ); + } +} From 22d230dd8a23565ce189ce61fd1c413787b51492 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 23 Aug 2023 18:10:00 -0700 Subject: [PATCH 2/5] remove protocol from task --- client/src/task.rs | 2 - migration/src/lib.rs | 4 +- ...817_192017_add_protocol_to_aggregators.rs} | 37 ------------------- src/entity.rs | 4 +- src/entity/aggregator.rs | 3 +- src/entity/{ => aggregator}/protocol.rs | 0 src/entity/task.rs | 2 - src/entity/task/new_task.rs | 4 +- src/entity/task/provisionable_task.rs | 3 +- test-support/src/fixtures.rs | 1 - 10 files changed, 9 insertions(+), 51 deletions(-) rename migration/src/{m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs => m20230817_192017_add_protocol_to_aggregators.rs} (58%) rename src/entity/{ => aggregator}/protocol.rs (100%) diff --git a/client/src/task.rs b/client/src/task.rs index 6598ca7e6..b21116223 100644 --- a/client/src/task.rs +++ b/client/src/task.rs @@ -1,4 +1,3 @@ -use crate::Protocol; use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; @@ -23,7 +22,6 @@ pub struct Task { pub leader_aggregator_id: Uuid, pub helper_aggregator_id: Uuid, pub hpke_config_id: Uuid, - pub protocol: Protocol, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] diff --git a/migration/src/lib.rs b/migration/src/lib.rs index 01953bc86..96095f251 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -17,7 +17,7 @@ mod m20230703_201332_add_additional_fields_to_api_tokens; mod m20230725_220134_add_vdafs_and_query_types_to_aggregators; mod m20230731_181722_rename_aggregator_bearer_token; mod m20230808_204859_create_hpke_config; -mod m20230817_192017_add_dap_versions_to_tasks_and_aggregators; +mod m20230817_192017_add_protocol_to_aggregators; pub struct Migrator; @@ -42,7 +42,7 @@ impl MigratorTrait for Migrator { Box::new(m20230725_220134_add_vdafs_and_query_types_to_aggregators::Migration), Box::new(m20230731_181722_rename_aggregator_bearer_token::Migration), Box::new(m20230808_204859_create_hpke_config::Migration), - Box::new(m20230817_192017_add_dap_versions_to_tasks_and_aggregators::Migration), + Box::new(m20230817_192017_add_protocol_to_aggregators::Migration), ] } } diff --git a/migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs b/migration/src/m20230817_192017_add_protocol_to_aggregators.rs similarity index 58% rename from migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs rename to migration/src/m20230817_192017_add_protocol_to_aggregators.rs index b49294d59..470177890 100644 --- a/migration/src/m20230817_192017_add_dap_versions_to_tasks_and_aggregators.rs +++ b/migration/src/m20230817_192017_add_protocol_to_aggregators.rs @@ -30,30 +30,6 @@ impl MigrationTrait for Migration { ) .await?; - db.alter_table( - TableAlterStatement::new() - .table(Task::Table) - .add_column(ColumnDef::new(Task::Protocol).string().null()) - .to_owned(), - ) - .await?; - - db.exec_stmt( - Query::update() - .table(Task::Table) - .value(Task::Protocol, "DAP-04") - .to_owned(), - ) - .await?; - - db.alter_table( - TableAlterStatement::new() - .table(Task::Table) - .modify_column(ColumnDef::new(Task::Protocol).not_null()) - .to_owned(), - ) - .await?; - Ok(()) } @@ -65,13 +41,6 @@ impl MigrationTrait for Migration { .to_owned(), ) .await?; - db.alter_table( - TableAlterStatement::new() - .table(Task::Table) - .drop_column(Task::Protocol) - .to_owned(), - ) - .await?; Ok(()) } } @@ -81,9 +50,3 @@ enum Aggregator { Table, Protocol, } - -#[derive(DeriveIden)] -enum Task { - Table, - Protocol, -} diff --git a/src/entity.rs b/src/entity.rs index 68219cdc9..8e7807936 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -5,7 +5,6 @@ mod codec; pub mod hpke_config; mod json; pub mod membership; -pub mod protocol; pub mod queue; pub mod session; pub mod task; @@ -17,7 +16,7 @@ pub use account::{ }; pub use aggregator::{ Column as AggregatorColumn, Entity as Aggregators, Model as Aggregator, NewAggregator, - UpdateAggregator, + Protocol, Role, UnrecognizedProtocol, UnrecognizedRole, UpdateAggregator, }; pub use api_token::{ Column as ApiTokenColumn, Entity as ApiTokens, Model as ApiToken, UpdateApiToken, @@ -29,7 +28,6 @@ pub use hpke_config::{ pub use membership::{ Column as MembershipColumn, CreateMembership, Entity as Memberships, Model as Membership, }; -pub use protocol::{Protocol, UnrecognizedProtocol}; pub use session::{Column as SessionColumn, Entity as Sessions, Model as Session}; pub use task::{ Column as TaskColumn, Entity as Tasks, Model as Task, NewTask, ProvisionableTask, UpdateTask, diff --git a/src/entity/aggregator.rs b/src/entity/aggregator.rs index 9efabf857..7ac6acbe4 100644 --- a/src/entity/aggregator.rs +++ b/src/entity/aggregator.rs @@ -1,4 +1,5 @@ mod new_aggregator; +mod protocol; mod query_type_name; mod role; mod update_aggregator; @@ -15,8 +16,8 @@ use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use uuid::Uuid; -pub use super::protocol::{Protocol, UnrecognizedProtocol}; pub use new_aggregator::NewAggregator; +pub use protocol::{Protocol, UnrecognizedProtocol}; pub use query_type_name::{QueryTypeName, QueryTypeNameSet}; pub use role::{Role, UnrecognizedRole}; pub use update_aggregator::UpdateAggregator; diff --git a/src/entity/protocol.rs b/src/entity/aggregator/protocol.rs similarity index 100% rename from src/entity/protocol.rs rename to src/entity/aggregator/protocol.rs diff --git a/src/entity/task.rs b/src/entity/task.rs index fa4d988ae..ef22094a9 100644 --- a/src/entity/task.rs +++ b/src/entity/task.rs @@ -22,7 +22,6 @@ pub use new_task::NewTask; mod update_task; pub use update_task::UpdateTask; mod provisionable_task; -use super::Protocol; pub use provisionable_task::{ProvisionableTask, TaskProvisioningError}; use super::json::Json; @@ -49,7 +48,6 @@ pub struct Model { pub leader_aggregator_id: Uuid, pub helper_aggregator_id: Uuid, pub hpke_config_id: Uuid, - pub protocol: Protocol, } impl Model { diff --git a/src/entity/task/new_task.rs b/src/entity/task/new_task.rs index 6103b4fd5..4288a2e8c 100644 --- a/src/entity/task/new_task.rs +++ b/src/entity/task/new_task.rs @@ -1,7 +1,9 @@ use super::*; use crate::{ clients::aggregator_client::api_types::{AggregatorVdaf, QueryType}, - entity::{aggregator::Role, Account, Aggregator, Aggregators, HpkeConfig, HpkeConfigColumn}, + entity::{ + aggregator::Role, Account, Aggregator, Aggregators, HpkeConfig, HpkeConfigColumn, Protocol, + }, handler::Error, }; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; diff --git a/src/entity/task/provisionable_task.rs b/src/entity/task/provisionable_task.rs index 91292d4eb..301f72bbd 100644 --- a/src/entity/task/provisionable_task.rs +++ b/src/entity/task/provisionable_task.rs @@ -1,7 +1,7 @@ use super::{ActiveModel, *}; use crate::{ clients::aggregator_client::api_types::{AggregatorVdaf, AuthenticationToken}, - entity::{Account, Aggregator, HpkeConfig, Task}, + entity::{Account, Aggregator, HpkeConfig, Protocol, Task}, handler::Error, Crypter, }; @@ -117,7 +117,6 @@ impl ProvisionableTask { leader_aggregator_id: self.leader_aggregator.id, helper_aggregator_id: self.helper_aggregator.id, hpke_config_id: self.hpke_config.id, - protocol: self.protocol, } .into_active_model()) } diff --git a/test-support/src/fixtures.rs b/test-support/src/fixtures.rs index 00b99f8f7..550eb3e3f 100644 --- a/test-support/src/fixtures.rs +++ b/test-support/src/fixtures.rs @@ -114,7 +114,6 @@ pub async fn task(app: &DivviupApi, account: &Account) -> Task { leader_aggregator_id: leader_aggregator.id, helper_aggregator_id: helper_aggregator.id, hpke_config_id: hpke_config.id, - protocol: Protocol::Dap05, } .into_active_model() .insert(app.db()) From 25784601c1cdab364c5cc46c4f84e17d12b2c186 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Thu, 24 Aug 2023 12:57:54 -0700 Subject: [PATCH 3/5] fix clap histogram length arg configuration --- cli/src/tasks.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cli/src/tasks.rs b/cli/src/tasks.rs index 54fdfc261..05c07c9f8 100644 --- a/cli/src/tasks.rs +++ b/cli/src/tasks.rs @@ -42,7 +42,7 @@ pub enum TaskAction { categorical_buckets: Option>, #[arg(long, value_delimiter = ',')] continuous_buckets: Option>, - #[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec"), ("vdaf", "histogram")]))] + #[arg(long, required_if_eq_any([("vdaf", "count_vec"), ("vdaf", "sum_vec")]))] length: Option, #[arg(long, required_if_eq_any([("vdaf", "sum"), ("vdaf", "sum_vec")]))] bits: Option, @@ -98,7 +98,7 @@ impl TaskAction { return Err(Error::Other("continuous-buckets, categorical-buckets, or length are required for histogram vdaf".into())); } _ => { - return Err(Error::Other("continuous-buckets, categorical-buckets, and length mutually exclusive".into())); + return Err(Error::Other("continuous-buckets, categorical-buckets, and length are mutually exclusive".into())); } } } From eee0d3708ecc933788c7dfa19a0373967e1253e9 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Thu, 24 Aug 2023 13:01:08 -0700 Subject: [PATCH 4/5] Update src/entity/task/vdaf.rs Co-authored-by: David Cook --- src/entity/task/vdaf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs index fe612a2f8..a27f1aedc 100644 --- a/src/entity/task/vdaf.rs +++ b/src/entity/task/vdaf.rs @@ -22,7 +22,7 @@ impl Histogram { }) => buckets.len() as u64, Histogram::Continuous(ContinuousBuckets { buckets: Some(buckets), - }) => buckets.len() as u64, + }) => buckets.len() as u64 + 1, Histogram::Opaque(BucketLength { length }) => *length, _ => 0, } From d6be22622710ec6060a5104d6ffa678950c21885 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Thu, 24 Aug 2023 13:09:41 -0700 Subject: [PATCH 5/5] fix test to represent [1,2,3] having length 4 in DAP-05 this is because there's always an implicit top bucket greater than the highest bound --- tests/vdaf.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/vdaf.rs b/tests/vdaf.rs index ce625588c..112ea0936 100644 --- a/tests/vdaf.rs +++ b/tests/vdaf.rs @@ -5,33 +5,33 @@ pub fn histogram_representations() { let scenarios = [ ( json!({"type": "histogram", "buckets": ["a", "b", "c"]}), - Protocol::Dap05, - Ok(json!({"Prio3Histogram": {"length": 3}})), + Protocol::Dap04, + Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params": {}}]})), ), ( json!({"type": "histogram", "buckets": ["a", "b", "c"]}), - Protocol::Dap04, - Err(json!({"buckets": [{"code": "must-be-numbers", "message": null, "params":{}}]})), + Protocol::Dap05, + Ok(json!({"Prio3Histogram": {"length": 3}})), ), ( - json!({"type": "histogram", "buckets": [1,2,3]}), + json!({"type": "histogram", "buckets": [1, 2, 3]}), Protocol::Dap04, - Ok(json!({ "Prio3Histogram":{"buckets": [1,2,3]}})), + Ok(json!({"Prio3Histogram": {"buckets": [1, 2, 3]}})), ), ( - json!({"type": "histogram", "buckets": [1,2,3]}), + json!({"type": "histogram", "buckets": [1, 2, 3]}), Protocol::Dap05, - Ok(json!({"Prio3Histogram": {"length" : 3 }})), + Ok(json!({"Prio3Histogram": {"length": 4}})), ), ( - json!({"type": "histogram", "length": 3 }), + json!({"type": "histogram", "length": 3}), Protocol::Dap04, Err(json!({"buckets": [{"code": "required", "message": null, "params":{}}]})), ), ( - json!({"type": "histogram", "length": 3 }), + json!({"type": "histogram", "length": 3}), Protocol::Dap05, - Ok(json!({"Prio3Histogram":{ "length" : 3} })), + Ok(json!({"Prio3Histogram": {"length": 3}})), ), ];