From af77ebe1a41af5e60817c320a4d830abbe9cbef0 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 5 May 2025 16:39:42 +0200 Subject: [PATCH 01/29] API v0.5 draft Signed-off-by: Denis Varlakov --- Cargo.lock | 3 +- README.md | 2 +- .../random-generation-protocol/src/lib.rs | 63 +- round-based-derive/src/lib.rs | 42 +- round-based-tests/tests/rounds.rs | 96 +-- round-based/Cargo.toml | 2 + round-based/src/lib.rs | 22 +- round-based/src/mpc/mod.rs | 237 +++++++ round-based/src/mpc/party.rs | 233 +++++++ round-based/src/mpc/rounds_router/mod.rs | 517 ++++++++++++++ round-based/src/mpc/rounds_router/store.rs | 1 + round-based/src/{ => mpc}/runtime.rs | 20 +- round-based/src/party.rs | 178 ----- round-based/src/round/mod.rs | 47 ++ .../{rounds_router => round}/simple_store.rs | 25 +- round-based/src/rounds_router/mod.rs | 638 ------------------ round-based/src/rounds_router/store.rs | 132 ---- round-based/src/sim/async_env.rs | 65 +- round-based/src/sim/mod.rs | 16 +- round-based/src/state_machine/delivery.rs | 31 +- round-based/src/state_machine/mod.rs | 17 +- round-based/src/state_machine/runtime.rs | 7 +- .../tests/derive/compile-fail/wrong_usage.rs | 20 +- .../derive/compile-pass/correct_usage.rs | 6 +- 24 files changed, 1228 insertions(+), 1192 deletions(-) create mode 100644 round-based/src/mpc/mod.rs create mode 100644 round-based/src/mpc/party.rs create mode 100644 round-based/src/mpc/rounds_router/mod.rs create mode 100644 round-based/src/mpc/rounds_router/store.rs rename round-based/src/{ => mpc}/runtime.rs (77%) delete mode 100644 round-based/src/party.rs create mode 100644 round-based/src/round/mod.rs rename round-based/src/{rounds_router => round}/simple_store.rs (96%) delete mode 100644 round-based/src/rounds_router/mod.rs delete mode 100644 round-based/src/rounds_router/store.rs diff --git a/Cargo.lock b/Cargo.lock index a86357a..9c5ab72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -463,7 +463,7 @@ dependencies = [ [[package]] name = "round-based" -version = "0.4.0" +version = "0.4.1" dependencies = [ "anyhow", "futures", @@ -471,6 +471,7 @@ dependencies = [ "hex", "matches", "phantom-type", + "pin-project-lite", "rand", "rand_dev", "round-based-derive", diff --git a/README.md b/README.md index 5308be4..bec638e 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ the documentation of the protocol you're using), but usually they are: module * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync API, see `state_machine` module -* `derive` is needed to use `ProtocolMessage` proc macro +* `derive` is needed to use `ProtocolMsg` proc macro * `runtime-tokio` enables tokio-specific implementation of async runtime ## Join us in Discord! diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index 4f30ded..d1649e2 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -18,14 +18,13 @@ use alloc::{vec, vec::Vec}; use serde::{Deserialize, Serialize}; use sha2::{digest::Output, Digest, Sha256}; -use round_based::rounds_router::{ - simple_store::{RoundInput, RoundInputError}, - CompleteRoundError, RoundsRouter, +use round_based::{ + mpc::{Mpc, MpcExecution}, + MsgId, }; -use round_based::{Delivery, Mpc, MpcParty, MsgId, Outgoing, PartyIndex, ProtocolMessage, SinkExt}; /// Protocol message -#[derive(Clone, Debug, PartialEq, ProtocolMessage, Serialize, Deserialize)] +#[derive(round_based::ProtocolMsg, Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum Msg { /// Round 1 CommitMsg(CommitMsg), @@ -49,23 +48,19 @@ pub struct DecommitMsg { /// Carries out the randomness generation protocol pub async fn protocol_of_random_generation( - party: M, - i: PartyIndex, + mut mpc: M, + i: u16, n: u16, mut rng: R, -) -> Result<[u8; 32], Error> +) -> Result<[u8; 32], Error, M::SendErr>> where - M: Mpc, + M: Mpc, R: rand_core::RngCore, { - let MpcParty { delivery, .. } = party.into_party(); - let (incoming, mut outgoing) = delivery.split(); - // Define rounds - let mut rounds = RoundsRouter::::builder(); - let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); - let round2 = rounds.add_round(RoundInput::::broadcast(i, n)); - let mut rounds = rounds.listen(incoming); + let round1 = mpc.add_round(round_based::round::broadcast::(i, n)); + let round2 = mpc.add_round(round_based::round::broadcast::(i, n)); + let mut mpc = mpc.finish(); // --- The Protocol --- @@ -75,32 +70,22 @@ where // 2. Commit local randomness (broadcast m=sha256(randomness)) let commitment = Sha256::digest(local_randomness); - outgoing - .send(Outgoing::broadcast(Msg::CommitMsg(CommitMsg { - commitment, - }))) + mpc.send_broadcast(Msg::CommitMsg(CommitMsg { commitment })) .await .map_err(Error::Round1Send)?; // 3. Receive committed randomness from other parties - let commitments = rounds - .complete(round1) - .await - .map_err(Error::Round1Receive)?; + let commitments = mpc.complete(round1).await.map_err(Error::Round1Receive)?; // 4. Open local randomness - outgoing - .send(Outgoing::broadcast(Msg::DecommitMsg(DecommitMsg { - randomness: local_randomness, - }))) - .await - .map_err(Error::Round2Send)?; + mpc.send_broadcast(Msg::DecommitMsg(DecommitMsg { + randomness: local_randomness, + })) + .await + .map_err(Error::Round2Send)?; // 5. Receive opened local randomness from other parties, verify them, and output protocol randomness - let randomness = rounds - .complete(round2) - .await - .map_err(Error::Round2Receive)?; + let randomness = mpc.complete(round2).await.map_err(Error::Round2Receive)?; let mut guilty_parties = vec![]; let mut output = local_randomness; @@ -139,13 +124,13 @@ pub enum Error { Round1Send(#[source] SendErr), /// Couldn't receive a message in the first round #[error("receive messages at round 1")] - Round1Receive(#[source] CompleteRoundError), + Round1Receive(#[source] RecvErr), /// Couldn't send a message in the second round #[error("send a message at round 2")] Round2Send(#[source] SendErr), /// Couldn't receive a message in the second round #[error("receive messages at round 2")] - Round2Receive(#[source] CompleteRoundError), + Round2Receive(#[source] RecvErr), /// Some of the parties cheated #[error("malicious parties: {guilty_parties:?}")] @@ -155,11 +140,15 @@ pub enum Error { }, } +/// Error indicating that receiving message at certain round failed +pub type CompleteRoundErr = + round_based::mpc::CompleteRoundErr; + /// Blames a party in cheating during the protocol #[derive(Debug)] pub struct Blame { /// Index of the cheated party - pub guilty_party: PartyIndex, + pub guilty_party: u16, /// ID of the message that party sent in the first round pub commitment_msg: MsgId, /// ID of the message that party sent in the second round diff --git a/round-based-derive/src/lib.rs b/round-based-derive/src/lib.rs index 4b62029..c701e65 100644 --- a/round-based-derive/src/lib.rs +++ b/round-based-derive/src/lib.rs @@ -6,18 +6,18 @@ use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Token, Variant}; -#[proc_macro_derive(ProtocolMessage, attributes(protocol_message))] -pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +#[proc_macro_derive(ProtocolMsg, attributes(protocol_msg))] +pub fn protocol_msg(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let mut root = None; for attr in input.attrs { - if !attr.path.is_ident("protocol_message") { + if !attr.path.is_ident("protocol_msg") { continue; } if root.is_some() { - return quote_spanned! { attr.path.span() => compile_error!("#[protocol_message] attribute appears more than once"); }.into(); + return quote_spanned! { attr.path.span() => compile_error!("#[protocol_msg] attribute appears more than once"); }.into(); } let tokens = attr.tokens.into(); root = Some(parse_macro_input!(tokens as RootAttribute)); @@ -30,10 +30,10 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre let enum_data = match input.data { Data::Enum(e) => e, Data::Struct(s) => { - return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMessage");}.into() + return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMsg");}.into() } Data::Union(s) => { - return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMessage");}.into() + return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMsg");}.into() } }; @@ -46,15 +46,15 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre quote! { match *self {} } }; - let impl_protocol_message = quote! { - impl #impl_generics #root_path::ProtocolMessage for #name #ty_generics #where_clause { + let impl_protocol_msg = quote! { + impl #impl_generics #root_path::ProtocolMsg for #name #ty_generics #where_clause { fn round(&self) -> u16 { #round_method_impl } } }; - let impl_round_message = round_messages( + let impl_round_msg = round_msgs( &root_path, &name, &input.generics, @@ -62,8 +62,8 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre ); proc_macro::TokenStream::from(quote! { - #impl_protocol_message - #impl_round_message + #impl_protocol_msg + #impl_round_msg }) } @@ -73,11 +73,11 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator quote_spanned! { variant.ident.span() => - #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMessage"), + #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMsg"), }, Fields::Named(_) => quote_spanned! { variant.ident.span() => - #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMessage"), + #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMsg"), }, Fields::Unnamed(unnamed) => if unnamed.unnamed.len() == 1 { quote_spanned! { @@ -87,7 +87,7 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator - #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMessage"), + #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMsg"), } }, } @@ -99,7 +99,7 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator( +fn round_msgs<'v>( root_path: &RootPath, enum_name: &Ident, generics: &Generics, @@ -113,16 +113,16 @@ fn round_messages<'v>( let msg_type = &unnamed.unnamed[0].ty; quote_spanned! { variant.ident.span() => - impl #impl_generics #root_path::RoundMessage<#msg_type> for #enum_name #ty_generics #where_clause { + impl #impl_generics #root_path::RoundMsg<#msg_type> for #enum_name #ty_generics #where_clause { const ROUND: u16 = #i; - fn to_protocol_message(round_message: #msg_type) -> Self { - #enum_name::#variant_name(round_message) + fn to_protocol_msg(round_msg: #msg_type) -> Self { + #enum_name::#variant_name(round_msg) } - fn from_protocol_message(protocol_message: Self) -> Result<#msg_type, Self> { + fn from_protocol_msg(protocol_msg: Self) -> Result<#msg_type, Self> { #[allow(unreachable_patterns)] - match protocol_message { + match protocol_msg { #enum_name::#variant_name(msg) => Ok(msg), - _ => Err(protocol_message), + _ => Err(protocol_msg), } } } diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/rounds.rs index 00403d6..2fe81e3 100644 --- a/round-based-tests/tests/rounds.rs +++ b/round-based-tests/tests/rounds.rs @@ -1,6 +1,6 @@ use std::convert::Infallible; -use futures::{sink, stream, Sink, Stream}; +use futures::{sink, stream, SinkExt}; use hex_literal::hex; use matches::assert_matches; use rand_chacha::rand_core::SeedableRng; @@ -8,9 +8,10 @@ use rand_chacha::rand_core::SeedableRng; use random_generation_protocol::{ protocol_of_random_generation, CommitMsg, DecommitMsg, Error, Msg, }; -use round_based::rounds_router::errors::IoError; -use round_based::rounds_router::{simple_store::RoundInput, CompleteRoundError, RoundsRouter}; -use round_based::{Delivery, Incoming, MessageType, MpcParty, Outgoing}; +use round_based::{ + mpc::errors::{CompleteRoundError, WithIo}, + Incoming, MessageType, +}; const PARTY0_SEED: [u8; 32] = hex!("6772d079d5c984b3936a291e36b0d3dc6c474e36ed4afdfc973ef79a431ca870"); @@ -93,7 +94,9 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round1Receive(WithIo::Other( + CompleteRoundError::ProcessMsg(_) + ))) ) } @@ -137,7 +140,9 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round2Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round2Receive(WithIo::Other( + CompleteRoundError::ProcessMsg(_) + ))) ) } @@ -155,7 +160,9 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round1Receive(WithIo::Other( + CompleteRoundError::ProcessMsg(_) + ))) ) } @@ -291,7 +298,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { ]) .await; - assert_matches!(output, Err(Error::Round2Receive(CompleteRoundError::Io(_)))); + assert_matches!(output, Err(Error::Round2Receive(WithIo::Io(_)))); } #[tokio::test] @@ -333,7 +340,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { ]) .await; - assert_matches!(output, Err(Error::Round1Receive(CompleteRoundError::Io(_)))); + assert_matches!(output, Err(Error::Round1Receive(WithIo::Io(_)))); } #[tokio::test] @@ -366,33 +373,21 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { ]) .await; - assert_matches!( - output, - Err(Error::Round2Receive(CompleteRoundError::Io( - IoError::UnexpectedEof - ))) - ); -} - -#[tokio::test] -async fn all_non_completed_rounds_are_terminated_with_unexpected_eof_error_if_incoming_channel_suddenly_closed( -) { - let mut rounds = RoundsRouter::builder(); - let round1 = rounds.add_round(RoundInput::::new(0, 3, MessageType::P2P)); - let round2 = rounds.add_round(RoundInput::::new(0, 3, MessageType::P2P)); - let mut rounds = rounds.listen(stream::empty::, Infallible>>()); - - assert_matches!( - rounds.complete(round1).await, - Err(CompleteRoundError::Io(IoError::UnexpectedEof)) - ); - assert_matches!( - rounds.complete(round2).await, - Err(CompleteRoundError::Io(IoError::UnexpectedEof)) - ); + assert_matches!(output, Err(Error::Round2Receive(WithIo::UnexpectedEof))); } -async fn run_protocol(incomings: I) -> Result<[u8; 32], Error> +async fn run_protocol( + incomings: I, +) -> Result< + [u8; 32], + random_generation_protocol::Error< + round_based::mpc::errors::WithIo< + E, + round_based::mpc::errors::CompleteRoundError, + >, + E, + >, +> where I: IntoIterator, E>>, I::IntoIter: Send + 'static, @@ -400,38 +395,13 @@ where { let rng = rand_chacha::ChaCha8Rng::from_seed(PARTY0_SEED); - let party = MpcParty::connected(MockedDelivery::new(stream::iter(incomings), sink::drain())); + let party = round_based::mpc::connected_halves( + stream::iter(incomings), + sink::drain().sink_map_err(|e| match e {}), + ); protocol_of_random_generation(party, 0, 3, rng).await } -struct MockedDelivery { - incoming: I, - outgoing: O, -} - -impl MockedDelivery { - pub fn new(incoming: I, outgoing: O) -> Self { - Self { incoming, outgoing } - } -} - -impl Delivery for MockedDelivery -where - I: Stream, IErr>> + Send + Unpin + 'static, - O: Sink, Error = OErr> + Send + Unpin, - IErr: std::error::Error + Send + Sync + 'static, - OErr: std::error::Error + Send + Sync + 'static, -{ - type Send = O; - type Receive = I; - type SendError = OErr; - type ReceiveError = IErr; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.incoming, self.outgoing) - } -} - #[derive(Debug)] struct DummyError; diff --git a/round-based/Cargo.toml b/round-based/Cargo.toml index 3cea8fe..be42da1 100644 --- a/round-based/Cargo.toml +++ b/round-based/Cargo.toml @@ -26,6 +26,8 @@ round-based-derive = { version = "0.2", optional = true, path = "../round-based- tokio = { version = "1", features = ["rt"], optional = true } tokio-stream = { version = "0.1", features = ["sync"], optional = true } +pin-project-lite = "0.2" + [dev-dependencies] trybuild = "1" matches = "0.1" diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index a653be1..4cb649d 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -40,14 +40,15 @@ //! module //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync //! API, see [`state_machine`] module -//! * `derive` is needed to use [`ProtocolMessage`](macro@ProtocolMessage) proc macro +//! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro //! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](runtime) //! //! ## Join us in Discord! //! Feel free to reach out to us [in Discord](https://discordapp.com/channels/905194001349627914/1285268686147424388)! #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))] -#![forbid(unused_crate_dependencies, missing_docs)] +#![warn(unused_crate_dependencies, missing_docs)] +#![allow(async_fn_in_trait)] #![no_std] extern crate alloc; @@ -66,9 +67,8 @@ mod false_positives { } mod delivery; -pub mod party; -pub mod rounds_router; -pub mod runtime; +pub mod mpc; +pub mod round; #[cfg(feature = "state-machine")] pub mod state_machine; @@ -76,17 +76,15 @@ pub mod state_machine; pub mod sim; pub use self::delivery::*; +pub use self::mpc::MpcParty; #[doc(no_inline)] -pub use self::{ - party::{Mpc, MpcParty}, - rounds_router::{ProtocolMessage, RoundMessage}, -}; +pub use self::mpc::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; #[doc(hidden)] pub mod _docs; -/// Derives [`ProtocolMessage`] and [`RoundMessage`] traits +/// Derives [`ProtocolMsg`] and [`RoundMessage`] traits /// -/// See [`ProtocolMessage`] docs for more details +/// See [`ProtocolMsg`] docs for more details #[cfg(feature = "derive")] -pub use round_based_derive::ProtocolMessage; +pub use round_based_derive::ProtocolMsg; diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs new file mode 100644 index 0000000..a481c92 --- /dev/null +++ b/round-based/src/mpc/mod.rs @@ -0,0 +1,237 @@ +//! Party of MPC protocol +//! +//! [`MpcParty`] is party of MPC protocol, connected to network, ready to start carrying out the protocol. +//! +//! ```rust +//! use round_based::{Mpc, MpcParty, Delivery, PartyIndex}; +//! +//! # struct KeygenMsg; +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! async fn keygen(party: M, i: PartyIndex, n: u16) -> Result +//! where +//! M: Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! async fn connect() -> impl Delivery { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! +//! let delivery = connect().await; +//! let party = MpcParty::connected(delivery); +//! +//! # let (i, n) = (1, 3); +//! let keyshare = keygen(party, i, n).await?; +//! # Ok(()) } +//! ``` + +use crate::{round::RoundStore, Outgoing, PartyIndex}; + +mod party; +mod rounds_router; +pub mod runtime; + +pub use self::{ + party::{Halves, MpcParty}, + rounds_router::Round, +}; + +/// When something goes wrong +pub mod errors { + pub use super::{party::WithIo, rounds_router::errors::*}; +} + +/// Abstracts functionalities needed for MPC protocol execution +pub trait Mpc { + /// Protocol message + type Msg; + + /// Returned in [`Self::complete`] + type Exec: MpcExecution; + /// Error indicating that sending a message has failed + type SendErr; + + /// Registers a round + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg; + + /// Indicates that network setup is complete + /// + /// Once this method is called, no more rounds can be added, + /// but the protocol can receive and send messages. + fn finish(self) -> Self::Exec; +} + +/// Abstracts functionalities needed for MPC protocol execution +pub trait MpcExecution { + /// Witness that round was registered + /// + /// It is used to retrieve messages in [`MpcExecution::complete`]. + type Round; + + /// Protocol message + type Msg; + + /// Error indicating that completing a round has failed + type CompleteRoundErr; + /// Error indicating that sending a message has failed + type SendErr; + + /// Completes the round + async fn complete( + &mut self, + round: Self::Round, + ) -> Result> + where + R: RoundStore, + Self::Msg: RoundMsg; + + /// Sends a message + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; + + /// Sends a p2p message to another party + async fn send_p2p( + &mut self, + recipient: PartyIndex, + msg: Self::Msg, + ) -> Result<(), Self::SendErr> { + self.send(Outgoing::p2p(recipient, msg)).await + } + + /// Sends a broadcast message + async fn send_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::broadcast(msg)).await + } + + /// Yields execution + async fn yield_now(&self); +} + +/// Alias to `<::Exec as MpcExecution>::CompleteRoundErr` +pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRoundErr; + +/// Message of MPC protocol +/// +/// MPC protocols typically consist of several rounds, each round has differently typed message. +/// `ProtocolMsg` and [`RoundMsg`] traits are used to examine received message: `ProtocolMsg::round` +/// determines which round message belongs to, and then `RoundMessage` trait can be used to retrieve +/// actual round-specific message. +/// +/// You should derive these traits using proc macro (requires `derive` feature): +/// ```rust +/// use round_based::ProtocolMsg; +/// +/// #[derive(ProtocolMsg)] +/// pub enum Message { +/// Round1(Msg1), +/// Round2(Msg2), +/// // ... +/// } +/// +/// pub struct Msg1 { /* ... */ } +/// pub struct Msg2 { /* ... */ } +/// ``` +/// +/// This desugars into: +/// +/// ```rust +/// use round_based::rounds_router::{ProtocolMsg, RoundMessage}; +/// +/// pub enum Message { +/// Round1(Msg1), +/// Round2(Msg2), +/// // ... +/// } +/// +/// pub struct Msg1 { /* ... */ } +/// pub struct Msg2 { /* ... */ } +/// +/// impl ProtocolMsg for Message { +/// fn round(&self) -> u16 { +/// match self { +/// Message::Round1(_) => 1, +/// Message::Round2(_) => 2, +/// // ... +/// } +/// } +/// } +/// impl RoundMessage for Message { +/// const ROUND: u16 = 1; +/// fn to_protocol_msg(round_msg: Msg1) -> Self { +/// Message::Round1(round_msg) +/// } +/// fn from_protocol_msg(protocol_msg: Self) -> Result { +/// match protocol_msg { +/// Message::Round1(msg) => Ok(msg), +/// msg => Err(msg), +/// } +/// } +/// } +/// impl RoundMessage for Message { +/// const ROUND: u16 = 2; +/// fn to_protocol_msg(round_msg: Msg2) -> Self { +/// Message::Round2(round_message) +/// } +/// fn from_protocol_msg(protocol_msg: Self) -> Result { +/// match protocol_msg { +/// Message::Round2(msg) => Ok(msg), +/// msg => Err(msg), +/// } +/// } +/// } +/// ``` +pub trait ProtocolMsg: Sized { + /// Number of round this message originates from + fn round(&self) -> u16; +} + +/// Round message +/// +/// See [`ProtocolMsg`] trait documentation. +pub trait RoundMsg: ProtocolMsg { + /// Number of the round this message belongs to + const ROUND: u16; + + /// Converts round message into protocol message (never fails) + fn to_protocol_msg(round_msg: M) -> Self; + /// Extracts round message from protocol message + /// + /// Returns `Err(protocol_message)` if `protocol_message.round() != Self::ROUND`, otherwise + /// returns `Ok(round_message)` + fn from_protocol_msg(protocol_msg: Self) -> Result; +} + +/// Construct an [`MpcParty`] that can be used to carry out MPC protocol +/// +/// Accepts a channels with incoming and outgoing messages. +/// +/// Alias to [`MpcParty::connected`] +pub fn connected(delivery: D) -> MpcParty +where + M: ProtocolMsg + 'static, + D: futures_util::Stream, D::Error>> + Unpin, + D: futures_util::Sink> + Unpin, +{ + MpcParty::connected(delivery) +} + +/// Construct an [`MpcParty`] that can be used to carry out MPC protocol +/// +/// Accepts separately a channel for incoming and a channel for outgoing messages. +/// +/// Alias to [`MpcParty::connected_halves`] +pub fn connected_halves(incomings: In, outgoings: Out) -> MpcParty> +where + M: ProtocolMsg + 'static, + In: futures_util::Stream, Out::Error>> + Unpin, + Out: futures_util::Sink> + Unpin, +{ + MpcParty::connected_halves(incomings, outgoings) +} diff --git a/round-based/src/mpc/party.rs b/round-based/src/mpc/party.rs new file mode 100644 index 0000000..20b1a3d --- /dev/null +++ b/round-based/src/mpc/party.rs @@ -0,0 +1,233 @@ +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +use crate::{round::RoundStore, Incoming, Outgoing}; + +use super::{rounds_router, runtime, Mpc, MpcExecution, ProtocolMsg, RoundMsg}; + +/// MPC engine, carries out the protocol +/// +/// Can be constructed via [`MpcParty::connected`] or [`MpcParty::connected_halves`], which wraps +/// a channel of incoming and outgoing messages, and implements additional logic on top of this +/// to facilitate the MPC protocol execution, such as routing incoming messages between round +/// stores. +/// +/// Implements [`Mpc`] and [`MpcExecution`]. +pub struct MpcParty { + router: rounds_router::RoundsRouter, + io: D, + runtime: R, +} + +impl MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, +{ + /// Constructs [`MpcParty`] + pub fn connected(delivery: D) -> Self { + Self { + router: rounds_router::RoundsRouter::new(), + io: delivery, + runtime: runtime::DefaultRuntime::default(), + } + } +} + +impl MpcParty> +where + M: ProtocolMsg + 'static, + In: Stream, E>> + Unpin, + Out: Sink, Error = E> + Unpin, +{ + /// Constructs [`MpcParty`] + pub fn connected_halves(incomings: In, outgoings: Out) -> Self { + Self::connected(Halves::new(incomings, outgoings)) + } +} + +impl MpcParty { + /// Changes which async runtime to use + pub fn with_runtime(self, runtime: R) -> MpcParty { + MpcParty { + router: self.router, + io: self.io, + runtime, + } + } +} + +impl Mpc for MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Msg = M; + + type Exec = MpcParty; + + type SendErr = E; + + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg, + { + self.router.add_round(round) + } + + fn finish(self) -> Self::Exec { + MpcParty { + router: self.router, + io: self.io, + runtime: self.runtime, + } + } +} + +impl MpcExecution for MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, IoErr>> + Unpin, + D: Sink, Error = IoErr> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Round = rounds_router::Round; + + type Msg = M; + + type CompleteRoundErr = WithIo>; + + type SendErr = IoErr; + + async fn complete( + &mut self, + mut round: Self::Round, + ) -> Result> + where + R: RoundStore, + Self::Msg: RoundMsg, + { + // Check if round is already completed + round = match self.router.complete_round(round) { + Ok(output) => return output.map_err(WithIo::Other), + Err(w) => w, + }; + + // Round is not completed - we need more messages + loop { + let incoming = self + .io + .next() + .await + .ok_or(WithIo::UnexpectedEof)? + .map_err(WithIo::Io)?; + self.router + .received_msg(incoming) + .map_err(|err| WithIo::Other(err.into()))?; + + // Check if round was just completed + round = match self.router.complete_round(round) { + Ok(output) => return output.map_err(WithIo::Other), + Err(w) => w, + }; + } + } + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.io.send(msg).await + } + + async fn yield_now(&self) { + self.runtime.yield_now().await + } +} + +/// Error indicating that either `IoErr` occurred, or `OtherErr` +#[derive(Debug, thiserror::Error)] +pub enum WithIo { + /// IO error + #[error(transparent)] + Io(IoErr), + /// Unexpected EOF + #[error("unexpected eof")] + UnexpectedEof, + /// Other error + #[error(transparent)] + Other(OtherErr), +} + +pin_project_lite::pin_project! { + /// Merges a stream and a sink into one structure that implements both [`Stream`] and [`Sink`] + pub struct Halves { + #[pin] + incomings: In, + #[pin] + outgoings: Out, + } +} + +impl Halves { + /// Constructs `Halves` + pub fn new(incomings: In, outgoings: Out) -> Self { + Self { + incomings, + outgoings, + } + } + + /// Deconstructs back into halves + pub fn into_inner(self) -> (In, Out) { + (self.incomings, self.outgoings) + } +} + +impl Stream for Halves +where + In: Stream>, +{ + type Item = Result; + + fn poll_next( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.incomings.poll_next(cx) + } +} + +impl Sink for Halves +where + Out: Sink, +{ + type Error = E; + + fn poll_ready( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_ready(cx) + } + fn start_send(self: core::pin::Pin<&mut Self>, item: M) -> Result<(), Self::Error> { + let this = self.project(); + this.outgoings.start_send(item) + } + fn poll_flush( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_flush(cx) + } + fn poll_close( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_close(cx) + } +} diff --git a/round-based/src/mpc/rounds_router/mod.rs b/round-based/src/mpc/rounds_router/mod.rs new file mode 100644 index 0000000..09b5a1b --- /dev/null +++ b/round-based/src/mpc/rounds_router/mod.rs @@ -0,0 +1,517 @@ +//! Routes incoming MPC messages between rounds +//! +//! [`RoundsRouter`] is an essential building block of MPC protocol, it processes incoming messages, groups +//! them by rounds, and provides convenient API for retrieving received messages at certain round. +//! +//! ## Example +//! +//! ```rust +//! use round_based::{Mpc, MpcParty, ProtocolMsg, Delivery, PartyIndex}; +//! use round_based::rounds_router::{RoundsRouter, simple_store::{RoundInput, RoundMsgs}}; +//! +//! #[derive(ProtocolMsg)] +//! pub enum Msg { +//! Round1(Msg1), +//! Round2(Msg2), +//! } +//! +//! pub struct Msg1 { /* ... */ } +//! pub struct Msg2 { /* ... */ } +//! +//! pub async fn some_mpc_protocol(party: M, i: PartyIndex, n: u16) -> Result +//! where +//! M: Mpc, +//! { +//! let MpcParty{ delivery, .. } = party.into_party(); +//! +//! let (incomings, _outgoings) = delivery.split(); +//! +//! // Build `Rounds` +//! let mut rounds = RoundsRouter::builder(); +//! let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); +//! let round2 = rounds.add_round(RoundInput::::p2p(i, n)); +//! let mut rounds = rounds.listen(incomings); +//! +//! // Receive messages from round 1 +//! let msgs: RoundMsgs = rounds.complete(round1).await?; +//! +//! // ... process received messages +//! +//! // Receive messages from round 2 +//! let msgs = rounds.complete(round2).await?; +//! +//! // ... +//! # todo!() +//! } +//! # type Output = (); +//! # type Error = Box; +//! ``` + +use alloc::{boxed::Box, collections::BTreeMap}; +use core::{any::Any, mem}; + +use phantom_type::PhantomType; +use tracing::{error, trace_span, warn}; + +use crate::{round::RoundStore, Incoming, ProtocolMsg, RoundMsg}; + +/// Routes received messages between protocol rounds +pub struct RoundsRouter { + rounds: BTreeMap>>>, +} + +impl RoundsRouter +where + M: ProtocolMsg + 'static, +{ + pub fn new() -> Self { + Self { + rounds: Default::default(), + } + } + + /// Registers new round + /// + /// ## Panics + /// Panics if round `R` was already registered + pub fn add_round(&mut self, message_store: R) -> Round + where + R: RoundStore, + M: RoundMsg, + { + let overridden_round = self.rounds.insert( + M::ROUND, + Some(Box::new(ProcessRoundMessageImpl::new(message_store))), + ); + if overridden_round.is_some() { + panic!("round {} is overridden", M::ROUND); + } + Round { + _ph: PhantomType::new(), + } + } + + pub fn received_msg(&mut self, incoming: Incoming) -> Result<(), errors::UnregisteredRound> { + let msg_round_n = incoming.msg.round(); + let span = trace_span!( + "Round::received_msg", + round = %msg_round_n, + sender = %incoming.sender, + ty = ?incoming.msg_type + ); + let _guard = span.enter(); + + let message_round = match self.rounds.get_mut(&msg_round_n) { + Some(Some(round)) => round, + Some(None) => { + warn!("got message for the round that was already completed, ignoring it"); + return Ok(()); + } + None => { + return Err(errors::UnregisteredRound { + n: msg_round_n, + witness_provided: false, + }) + } + }; + if message_round.needs_more_messages().no() { + warn!("received message for the round that was already completed, ignoring it"); + return Ok(()); + } + message_round.process_message(incoming); + Ok(()) + } + + pub fn complete_round( + &mut self, + round: Round, + ) -> Result>, Round> + where + R: RoundStore, + M: RoundMsg, + { + let message_round = match self.rounds.get_mut(&M::ROUND) { + Some(Some(round)) => round, + Some(None) => { + return Ok(Err( + errors::Bug::RoundGoneButWitnessExists { n: M::ROUND }.into() + )); + } + None => { + return Ok(Err(errors::UnregisteredRound { + n: M::ROUND, + witness_provided: true, + } + .into())) + } + }; + if message_round.needs_more_messages().yes() { + return Err(round); + } + Ok(Self::retrieve_round_output::(message_round)) + } + + fn retrieve_round_output( + round: &mut Box>, + ) -> Result> + where + R: RoundStore, + M: RoundMsg, + { + match round.take_output() { + Ok(Ok(any)) => Ok(*any + .downcast::() + .or(Err(errors::Bug::MismatchedOutputType))?), + Ok(Err(any)) => Err(*any + .downcast::>() + .or(Err(errors::Bug::MismatchedErrorType))?), + Err(err) => Err(errors::Bug::TakeRoundResult(err).into()), + } + } +} + +/// A witness that round has been registered in the router +/// +/// Can be used later to claim messages received in this round +pub struct Round { + _ph: PhantomType, +} + +impl core::fmt::Debug for Round { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Round").finish_non_exhaustive() + } +} + +trait ProcessRoundMessage { + type Msg; + + /// Processes round message + /// + /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`, + /// otherwise calling this method is unexpected. + fn process_message(&mut self, msg: Incoming); + + /// Indicated whether the store needs more messages + /// + /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No` + /// then you need to take the round output by calling `.take_output()`. + fn needs_more_messages(&self) -> NeedsMoreMessages; + + /// Tries to obtain round output + /// + /// Can be called once `process_message()` returned `NeedMoreMessages::No`. + /// + /// Returns: + /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output` + /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError` + /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`] + #[allow(clippy::type_complexity)] + fn take_output(&mut self) -> Result, Box>, TakeOutputError>; +} + +#[derive(Debug, thiserror::Error)] +enum TakeOutputError { + #[error("output is already taken")] + AlreadyTaken, + #[error("output is not ready yet, more messages are needed")] + NotReady, +} + +enum ProcessRoundMessageImpl> { + InProgress { store: S, _ph: PhantomType }, + Completed(Result>), + Gone, +} + +impl> ProcessRoundMessageImpl { + pub fn new(store: S) -> Self { + if store.wants_more() { + Self::InProgress { + store, + _ph: Default::default(), + } + } else { + Self::Completed( + store + .output() + .map_err(|_| errors::ImproperRoundStore::StoreDidntOutput.into()), + ) + } + } +} + +impl ProcessRoundMessageImpl +where + S: RoundStore, + M: ProtocolMsg + RoundMsg, +{ + fn _process_message( + store: &mut S, + msg: Incoming, + ) -> Result<(), errors::CompleteRoundError> { + let msg = msg.try_map(M::from_protocol_msg).map_err(|msg| { + errors::Bug::MessageFromAnotherRound { + actual_number: msg.round(), + expected_round: M::ROUND, + } + })?; + + store + .add_message(msg) + .map_err(errors::CompleteRoundError::ProcessMsg)?; + Ok(()) + } +} + +impl ProcessRoundMessage for ProcessRoundMessageImpl +where + S: RoundStore, + M: ProtocolMsg + RoundMsg, +{ + type Msg = M; + + fn process_message(&mut self, msg: Incoming) { + let store = match self { + Self::InProgress { store, .. } => store, + _ => { + return; + } + }; + + match Self::_process_message(store, msg) { + Ok(()) => { + if store.wants_more() { + return; + } + + let store = match mem::replace(self, Self::Gone) { + Self::InProgress { store, .. } => store, + _ => { + *self = Self::Completed(Err(errors::Bug::IncoherentState { + expected: "InProgress", + justification: + "we checked at beginning of the function that `state` is InProgress", + }.into())); + return; + } + }; + + match store.output() { + Ok(output) => *self = Self::Completed(Ok(output)), + Err(_err) => { + *self = Self::Completed(Err( + errors::ImproperRoundStore::StoreDidntOutput.into() + )) + } + } + } + Err(err) => { + *self = Self::Completed(Err(err)); + } + } + } + + fn needs_more_messages(&self) -> NeedsMoreMessages { + match self { + Self::InProgress { .. } => NeedsMoreMessages::Yes, + _ => NeedsMoreMessages::No, + } + } + + fn take_output(&mut self) -> Result, Box>, TakeOutputError> { + match self { + Self::InProgress { .. } => return Err(TakeOutputError::NotReady), + Self::Gone => return Err(TakeOutputError::AlreadyTaken), + _ => (), + } + match mem::replace(self, Self::Gone) { + Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))), + Self::Completed(Err(err)) => Ok(Err(Box::new(err))), + _ => unreachable!("it's checked to be completed"), + } + } +} + +enum NeedsMoreMessages { + Yes, + No, +} + +#[allow(dead_code)] +impl NeedsMoreMessages { + pub fn yes(&self) -> bool { + matches!(self, Self::Yes) + } + pub fn no(&self) -> bool { + matches!(self, Self::No) + } +} + +/// When something goes wrong +pub mod errors { + use super::TakeOutputError; + + #[derive(Debug, thiserror::Error)] + #[error("received a message for unregistered round")] + pub(in crate::mpc) struct UnregisteredRound { + pub n: u16, + pub(super) witness_provided: bool, + } + + /// Error returned when processing incoming messages at certain round + /// + /// May indicate malicious behavior (e.g. adversary sent a message that aborts protocol execution) + /// or some misconfiguration of the protocol network (e.g. received a message from the round that + /// was not registered via [`Mpc::add_round`](crate::Mpc::add_round)). + #[derive(Debug, thiserror::Error)] + pub enum CompleteRoundError { + /// [`RoundStore`](crate::round::RoundStore) returned an error + /// + /// Refer to this rounds store documentation to understand why it could fail + #[error(transparent)] + ProcessMsg(ProcessErr), + + /// Router error + /// + /// Indicates that for some reason router was not able to process a message. This can be the case of: + /// - Router API misuse \ + /// E.g. when received a message from the round that was not registered in the router + /// - Improper [`RoundStore`](crate::round::RoundStore) implementation \ + /// Indicates that round store is not properly implemented and contains a flaw. \ + /// For instance, this error is returned when round store indicates that it doesn't need + /// any more messages ([`RoundStore::wants_more`](crate::round::RoundStore::wants_more) + /// returns `false`), but then it didn't output anything ([`RoundStore::output`](crate::round::RoundStore::output) + /// returns `Err(_)`) + /// - Bug in the router + /// + /// This error is always related to some implementation flaw or bug: either in the code that uses + /// the router, or in the round store implementation, or in the router itself. When implementation + /// is correct, this error never appears. Thus, it should not be possible for the adversary to "make + /// this error happen." + Router(RouterError), + } + + /// Router error + /// + /// Refer to [`CompleteRound::Router`] docs + #[derive(Debug, thiserror::Error)] + #[error(transparent)] + pub struct RouterError(Reason); + + #[derive(Debug, thiserror::Error)] + pub(super) enum Reason { + /// Router API has been misused + /// + /// For instance, this error is returned when protocol implementation does not register + /// certain round of the protocol, but then a message from this round is received. In + /// this case, router doesn't have anywhere to route the message to, so an [`ApiMisuse`] + /// error is returned. + #[error("api misuse")] + ApiMisuse(#[source] ApiMisuse), + /// Improper [`RoundStore`](crate::round::RoundStore) implementation + /// + /// For instance, this error is returned when round store indicates that it doesn't need + /// any more messages ([`RoundStore::wants_more`](crate::round::RoundStore::wants_more) + /// returns `false`), but then it didn't output anything ([`RoundStore::output`](crate::round::RoundStore::output) + /// returns `Err(_)`) + #[error("improper round store")] + ImproperRoundStore(#[source] ImproperRoundStore), + /// Indicates that there's a bug in the router implementation + #[error("bug (please, open an issue)")] + Bug(#[source] Bug), + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum ApiMisuse { + #[error(transparent)] + UnregisteredRound(#[from] UnregisteredRound), + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum ImproperRoundStore { + /// Store indicated that it received enough messages but didn't output + /// + /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`. + #[error("store didn't output")] + StoreDidntOutput, + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum Bug { + #[error("round is gone, but witness exists")] + RoundGoneButWitnessExists { n: u16 }, + #[error( + "message originates from another round: we process messages from round \ + {expected_round}, got message from round {actual_number}" + )] + MessageFromAnotherRound { + expected_round: u16, + actual_number: u16, + }, + #[error("state is incoherent, it's expected to be {expected}: {justification}")] + IncoherentState { + expected: &'static str, + justification: &'static str, + }, + #[error("take round result")] + TakeRoundResult(#[source] TakeOutputError), + #[error("mismatched output type")] + MismatchedOutputType, + #[error("mismatched error type")] + MismatchedErrorType, + } + + macro_rules! impl_round_complete_from { + ($(|$err:ident: $err_ty:ty| $err_fn:expr),+$(,)?) => {$( + impl From<$err_ty> for CompleteRoundError { + fn from($err: $err_ty) -> Self { + $err_fn + } + } + )+}; + } + + impl_round_complete_from! { + |err: ApiMisuse| CompleteRoundError::Router(RouterError(Reason::ApiMisuse(err))), + |err: ImproperRoundStore| CompleteRoundError::Router(RouterError(Reason::ImproperRoundStore(err))), + |err: Bug| CompleteRoundError::Router(RouterError(Reason::Bug(err))), + |err: UnregisteredRound| ApiMisuse::UnregisteredRound(err).into(), + } +} + +#[cfg(test)] +mod tests { + struct Store; + + #[derive(crate::ProtocolMsg)] + #[protocol_msg(root = crate)] + enum FakeProtocolMsg { + R1(Msg1), + } + struct Msg1; + + impl super::RoundStore for Store { + type Msg = Msg1; + type Output = (); + type Error = core::convert::Infallible; + + fn add_message(&mut self, _msg: crate::Incoming) -> Result<(), Self::Error> { + Ok(()) + } + fn wants_more(&self) -> bool { + false + } + fn output(self) -> Result { + Ok(()) + } + } + + #[test] + fn complete_round_that_expects_no_messages() { + let mut rounds = super::RoundsRouter::::new(); + let round1 = rounds.add_round(Store); + + rounds.complete_round(round1).unwrap().unwrap(); + } +} diff --git a/round-based/src/mpc/rounds_router/store.rs b/round-based/src/mpc/rounds_router/store.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/round-based/src/mpc/rounds_router/store.rs @@ -0,0 +1 @@ + diff --git a/round-based/src/runtime.rs b/round-based/src/mpc/runtime.rs similarity index 77% rename from round-based/src/runtime.rs rename to round-based/src/mpc/runtime.rs index 7daa68e..1852186 100644 --- a/round-based/src/runtime.rs +++ b/round-based/src/mpc/runtime.rs @@ -8,15 +8,12 @@ /// Abstracts async runtime like [tokio]. Currently only exposes a [yield_now](Self::yield_now) /// function. pub trait AsyncRuntime { - /// Future type returned by [yield_now](Self::yield_now) - type YieldNowFuture: core::future::Future; - /// Yields the execution back to the runtime /// /// If the protocol performs a long computation, it might be better for performance - /// to split it with yield points, so the signle computation does not starve other + /// to split it with yield points, so the single computation does not starve other /// tasks. - fn yield_now(&self) -> Self::YieldNowFuture; + async fn yield_now(&self); } /// [Tokio](tokio)-specific async runtime @@ -26,11 +23,8 @@ pub struct TokioRuntime; #[cfg(feature = "runtime-tokio")] impl AsyncRuntime for TokioRuntime { - type YieldNowFuture = - core::pin::Pin + Send>>; - - fn yield_now(&self) -> Self::YieldNowFuture { - alloc::boxed::Box::pin(tokio::task::yield_now()) + async fn yield_now(&self) { + tokio::task::yield_now().await } } @@ -56,10 +50,8 @@ pub mod unknown_runtime { pub struct UnknownRuntime; impl super::AsyncRuntime for UnknownRuntime { - type YieldNowFuture = YieldNow; - - fn yield_now(&self) -> Self::YieldNowFuture { - YieldNow(false) + async fn yield_now(&self) { + YieldNow(false).await } } diff --git a/round-based/src/party.rs b/round-based/src/party.rs deleted file mode 100644 index 6ab2199..0000000 --- a/round-based/src/party.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Party of MPC protocol -//! -//! [`MpcParty`] is party of MPC protocol, connected to network, ready to start carrying out the protocol. -//! -//! ```rust -//! use round_based::{Mpc, MpcParty, Delivery, PartyIndex}; -//! -//! # struct KeygenMsg; -//! # struct KeyShare; -//! # struct Error; -//! # type Result = std::result::Result; -//! # async fn doc() -> Result<()> { -//! async fn keygen(party: M, i: PartyIndex, n: u16) -> Result -//! where -//! M: Mpc -//! { -//! // ... -//! # unimplemented!() -//! } -//! async fn connect() -> impl Delivery { -//! // ... -//! # round_based::_docs::fake_delivery() -//! } -//! -//! let delivery = connect().await; -//! let party = MpcParty::connected(delivery); -//! -//! # let (i, n) = (1, 3); -//! let keyshare = keygen(party, i, n).await?; -//! # Ok(()) } -//! ``` - -use phantom_type::PhantomType; - -use crate::delivery::Delivery; -use crate::runtime::{self, AsyncRuntime}; - -/// Party of MPC protocol (trait) -/// -/// [`MpcParty`] is the only struct that implement this trait. Motivation to have this trait is to fewer amount of -/// generic bounds that are needed to be specified. -/// -/// Typical usage of this trait when implementing MPC protocol: -/// -/// ```rust -/// use round_based::{Mpc, MpcParty, PartyIndex}; -/// -/// # struct Msg; -/// async fn keygen(party: M, i: PartyIndex, n: u16) -/// where -/// M: Mpc -/// { -/// let MpcParty{ delivery, .. } = party.into_party(); -/// // ... -/// } -/// ``` -/// -/// If we didn't have this trait, generics would be less readable: -/// ```rust -/// use round_based::{MpcParty, Delivery, runtime::AsyncRuntime, PartyIndex}; -/// -/// # struct Msg; -/// async fn keygen(party: MpcParty, i: PartyIndex, n: u16) -/// where -/// D: Delivery, -/// R: AsyncRuntime -/// { -/// // ... -/// } -/// ``` -pub trait Mpc: internal::Sealed { - /// MPC message - type ProtocolMessage; - /// Transport layer implementation - type Delivery: Delivery< - Self::ProtocolMessage, - SendError = Self::SendError, - ReceiveError = Self::ReceiveError, - >; - /// Async runtime - type Runtime: AsyncRuntime; - - /// Sending message error - type SendError: core::error::Error + Send + Sync + 'static; - /// Receiving message error - type ReceiveError: core::error::Error + Send + Sync + 'static; - - /// Converts into [`MpcParty`] - fn into_party(self) -> MpcParty; -} - -mod internal { - pub trait Sealed {} -} - -/// Party of MPC protocol -#[non_exhaustive] -pub struct MpcParty { - /// Defines transport layer - pub delivery: D, - /// Defines how computationally heavy tasks should be handled - pub runtime: R, - _msg: PhantomType, -} - -impl MpcParty -where - D: Delivery, -{ - /// Party connected to the network - /// - /// Takes the delivery object determining how to deliver/receive other parties' messages - pub fn connected(delivery: D) -> Self { - Self { - delivery, - runtime: Default::default(), - _msg: PhantomType::new(), - } - } -} - -impl MpcParty -where - D: Delivery, -{ - /// Modify the delivery of this party while keeping everything else the same - pub fn map_delivery(self, f: impl FnOnce(D) -> D2) -> MpcParty { - let delivery = f(self.delivery); - MpcParty { - delivery, - runtime: self.runtime, - _msg: self._msg, - } - } - - /// Modify the runtime of this party while keeping everything else the same - pub fn map_runtime(self, f: impl FnOnce(X) -> R) -> MpcParty { - let runtime = f(self.runtime); - MpcParty { - delivery: self.delivery, - runtime, - _msg: self._msg, - } - } - - /// Specifies a [async runtime](runtime) - pub fn set_runtime(self, runtime: R) -> MpcParty - where - R: AsyncRuntime, - { - MpcParty { - delivery: self.delivery, - runtime, - _msg: self._msg, - } - } -} - -impl internal::Sealed for MpcParty {} - -impl Mpc for MpcParty -where - D: Delivery, - D::SendError: core::error::Error + Send + Sync + 'static, - D::ReceiveError: core::error::Error + Send + Sync + 'static, - R: AsyncRuntime, -{ - type ProtocolMessage = M; - type Delivery = D; - type Runtime = R; - - type SendError = D::SendError; - type ReceiveError = D::ReceiveError; - - fn into_party(self) -> MpcParty { - self - } -} diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs new file mode 100644 index 0000000..a4e8d93 --- /dev/null +++ b/round-based/src/round/mod.rs @@ -0,0 +1,47 @@ +//! Primitives that process and collect messages received at certain round + +use crate::Incoming; + +pub use self::simple_store::{broadcast, p2p, RoundInput, RoundInputError, RoundMsgs}; + +mod simple_store; + +/// Stores messages received at particular round +/// +/// In MPC protocol, party at every round usually needs to receive up to `n` messages. `RoundsStore` +/// is a container that stores messages, it knows how many messages are expected to be received, +/// and should implement extra measures against malicious parties (e.g. prohibit message overwrite). +/// +/// ## Flow +/// `RoundStore` stores received messages. Once enough messages are received, it outputs [`RoundStore::Output`]. +/// In order to save received messages, [`.add_message(msg)`] is called. Then, [`.wants_more()`] tells whether more +/// messages are needed to be received. If it returned `false`, then output can be retrieved by calling [`.output()`]. +/// +/// [`.add_message(msg)`]: Self::add_message +/// [`.wants_more()`]: Self::wants_more +/// [`.output()`]: Self::output +/// +/// ## Example +/// [`RoundInput`](super::simple_store::RoundInput) is an simple messages store. Refer to its docs to see usage examples. +pub trait RoundStore: Sized + 'static { + /// Message type + type Msg; + /// Store output (e.g. `Vec<_>` of received messages) + type Output; + /// Store error + type Error: core::error::Error; + + /// Adds received message to the store + /// + /// Returns error if message cannot be processed. Usually it means that sender behaves maliciously. + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error>; + /// Indicates if store expects more messages to receive + fn wants_more(&self) -> bool; + /// Retrieves store output if enough messages are received + /// + /// Returns `Err(self)` if more message are needed to be received. + /// + /// If store indicated that it needs no more messages (ie `store.wants_more() == false`), then + /// this function must return `Ok(_)`. + fn output(self) -> Result; +} diff --git a/round-based/src/rounds_router/simple_store.rs b/round-based/src/round/simple_store.rs similarity index 96% rename from round-based/src/rounds_router/simple_store.rs rename to round-based/src/round/simple_store.rs index 265dd80..7a3c4a9 100644 --- a/round-based/src/rounds_router/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -1,13 +1,13 @@ -//! Simple implementation of `MessagesStore` +//! Simple implementation of [`RoundStore`] use alloc::{vec, vec::Vec}; use core::iter; use crate::{Incoming, MessageType, MsgId, PartyIndex}; -use super::MessagesStore; +use super::RoundStore; -/// Simple implementation of [MessagesStore] that waits for all parties to send a message +/// Simple implementation of [`RoundStore`] that waits for all parties to send a message /// /// Round is considered complete when the store received a message from every party. Note that the /// store will ignore all the messages such as `msg.sender == local_party_index`. @@ -102,7 +102,7 @@ impl RoundInput { } } -impl MessagesStore for RoundInput +impl RoundStore for RoundInput where M: 'static, { @@ -251,7 +251,7 @@ impl RoundMsgs { } } -/// Error explaining why `RoundInput` wasn't able to process a message +/// Error explaining why [`RoundInput`] wasn't able to process a message #[derive(Debug, thiserror::Error)] pub enum RoundInputError { /// Party sent two messages in one round @@ -292,12 +292,25 @@ pub enum RoundInputError { }, } +/// Round messages store for p2p round +/// +/// Alias to [`RoundInput::p2p`] +pub fn p2p(i: u16, n: u16) -> RoundInput { + RoundInput::p2p(i, n) +} +/// Round messages store for broadcast round +/// +/// Alias to [`RoundInput::broadcast`] +pub fn broadcast(i: u16, n: u16) -> RoundInput { + RoundInput::broadcast(i, n) +} + #[cfg(test)] mod tests { use alloc::vec::Vec; use matches::assert_matches; - use crate::rounds_router::store::MessagesStore; + use crate::round::RoundStore; use crate::{Incoming, MessageType}; use super::{RoundInput, RoundInputError}; diff --git a/round-based/src/rounds_router/mod.rs b/round-based/src/rounds_router/mod.rs deleted file mode 100644 index ad52299..0000000 --- a/round-based/src/rounds_router/mod.rs +++ /dev/null @@ -1,638 +0,0 @@ -//! Routes incoming MPC messages between rounds -//! -//! [`RoundsRouter`] is an essential building block of MPC protocol, it processes incoming messages, groups -//! them by rounds, and provides convenient API for retrieving received messages at certain round. -//! -//! ## Example -//! -//! ```rust -//! use round_based::{Mpc, MpcParty, ProtocolMessage, Delivery, PartyIndex}; -//! use round_based::rounds_router::{RoundsRouter, simple_store::{RoundInput, RoundMsgs}}; -//! -//! #[derive(ProtocolMessage)] -//! pub enum Msg { -//! Round1(Msg1), -//! Round2(Msg2), -//! } -//! -//! pub struct Msg1 { /* ... */ } -//! pub struct Msg2 { /* ... */ } -//! -//! pub async fn some_mpc_protocol(party: M, i: PartyIndex, n: u16) -> Result -//! where -//! M: Mpc, -//! { -//! let MpcParty{ delivery, .. } = party.into_party(); -//! -//! let (incomings, _outgoings) = delivery.split(); -//! -//! // Build `Rounds` -//! let mut rounds = RoundsRouter::builder(); -//! let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); -//! let round2 = rounds.add_round(RoundInput::::p2p(i, n)); -//! let mut rounds = rounds.listen(incomings); -//! -//! // Receive messages from round 1 -//! let msgs: RoundMsgs = rounds.complete(round1).await?; -//! -//! // ... process received messages -//! -//! // Receive messages from round 2 -//! let msgs = rounds.complete(round2).await?; -//! -//! // ... -//! # todo!() -//! } -//! # type Output = (); -//! # type Error = Box; -//! ``` - -use alloc::{boxed::Box, collections::BTreeMap}; -use core::{any::Any, convert::Infallible, mem}; - -use futures_util::{Stream, StreamExt}; -use phantom_type::PhantomType; -use tracing::{debug, error, trace, trace_span, warn, Span}; - -use crate::Incoming; - -#[doc(inline)] -pub use self::errors::CompleteRoundError; -pub use self::store::*; - -pub mod simple_store; -mod store; - -/// Routes received messages between protocol rounds -/// -/// See [module level](self) documentation to learn more about it. -pub struct RoundsRouter { - incomings: S, - rounds: BTreeMap + Send>>>, -} - -impl RoundsRouter { - /// Instantiates [`RoundsRouterBuilder`] - pub fn builder() -> RoundsRouterBuilder { - RoundsRouterBuilder::new() - } -} - -impl RoundsRouter -where - M: ProtocolMessage, - S: Stream, E>> + Unpin, - E: core::error::Error, -{ - /// Completes specified round - /// - /// Waits until all messages at specified round are received. Returns received - /// messages if round is successfully completed, or error otherwise. - #[inline(always)] - pub async fn complete( - &mut self, - round: Round, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let round_number = >::ROUND; - let span = trace_span!("Round", n = round_number); - debug!(parent: &span, "pending round to complete"); - - match self.complete_with_span(&span, round).await { - Ok(output) => { - trace!(parent: &span, "round successfully completed"); - Ok(output) - } - Err(err) => { - error!(parent: &span, %err, "round terminated with error"); - Err(err) - } - } - } - - async fn complete_with_span( - &mut self, - span: &Span, - _round: Round, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let pending_round = >::ROUND; - if let Some(output) = self.retrieve_round_output_if_its_completed::() { - return output; - } - - loop { - let incoming = match self.incomings.next().await { - Some(Ok(msg)) => msg, - Some(Err(err)) => return Err(errors::IoError::Io(err).into()), - None => return Err(errors::IoError::UnexpectedEof.into()), - }; - let message_round_n = incoming.msg.round(); - - let message_round = match self.rounds.get_mut(&message_round_n) { - Some(Some(round)) => round, - Some(None) => { - warn!( - parent: span, - n = message_round_n, - "got message for the round that was already completed, ignoring it" - ); - continue; - } - None => { - return Err( - errors::RoundsMisuse::UnregisteredRound { n: message_round_n }.into(), - ) - } - }; - if message_round.needs_more_messages().no() { - warn!( - parent: span, - n = message_round_n, - "received message for the round that was already completed, ignoring it" - ); - continue; - } - message_round.process_message(incoming); - - if pending_round == message_round_n { - if let Some(output) = self.retrieve_round_output_if_its_completed::() { - return output; - } - } - } - } - - #[allow(clippy::type_complexity)] - fn retrieve_round_output_if_its_completed( - &mut self, - ) -> Option>> - where - R: MessagesStore, - M: RoundMessage, - { - let round_number = >::ROUND; - let round_slot = match self - .rounds - .get_mut(&round_number) - .ok_or(errors::RoundsMisuse::UnregisteredRound { n: round_number }) - { - Ok(slot) => slot, - Err(err) => return Some(Err(err.into())), - }; - let round = match round_slot - .as_mut() - .ok_or(errors::RoundsMisuse::RoundAlreadyCompleted) - { - Ok(round) => round, - Err(err) => return Some(Err(err.into())), - }; - if round.needs_more_messages().no() { - Some(Self::retrieve_round_output::(round_slot)) - } else { - None - } - } - - fn retrieve_round_output( - slot: &mut Option + Send>>, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let mut round = slot.take().ok_or(errors::RoundsMisuse::UnregisteredRound { - n: >::ROUND, - })?; - match round.take_output() { - Ok(Ok(any)) => Ok(*any - .downcast::() - .or(Err(CompleteRoundError::from( - errors::Bug::MismatchedOutputType, - )))?), - Ok(Err(any)) => Err(any - .downcast::>() - .or(Err(CompleteRoundError::from( - errors::Bug::MismatchedErrorType, - )))? - .map_io_err(|e| match e {})), - Err(err) => Err(errors::Bug::TakeRoundResult(err).into()), - } - } -} - -/// Builds [`RoundsRouter`] -pub struct RoundsRouterBuilder { - rounds: BTreeMap + Send>>>, -} - -impl Default for RoundsRouterBuilder -where - M: ProtocolMessage + 'static, -{ - fn default() -> Self { - Self::new() - } -} - -impl RoundsRouterBuilder -where - M: ProtocolMessage + 'static, -{ - /// Constructs [`RoundsRouterBuilder`] - /// - /// Alias to [`RoundsRouter::builder`] - pub fn new() -> Self { - Self { - rounds: BTreeMap::new(), - } - } - - /// Registers new round - /// - /// ## Panics - /// Panics if round `R` was already registered - pub fn add_round(&mut self, message_store: R) -> Round - where - R: MessagesStore + Send + 'static, - R::Output: Send, - R::Error: Send, - M: RoundMessage, - { - let overridden_round = self.rounds.insert( - M::ROUND, - Some(Box::new(ProcessRoundMessageImpl::new(message_store))), - ); - if overridden_round.is_some() { - panic!("round {} is overridden", M::ROUND); - } - Round { - _ph: PhantomType::new(), - } - } - - /// Builds [`RoundsRouter`] - /// - /// Takes a stream of incoming messages which will be routed between registered rounds - pub fn listen(self, incomings: S) -> RoundsRouter - where - S: Stream, E>>, - { - RoundsRouter { - incomings, - rounds: self.rounds, - } - } -} - -/// A round of MPC protocol -/// -/// `Round` can be used to retrieve messages received at this round by calling [`RoundsRouter::complete`]. See -/// [module level](self) documentation to see usage. -pub struct Round { - _ph: PhantomType, -} - -trait ProcessRoundMessage { - type Msg; - - /// Processes round message - /// - /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`, - /// otherwise calling this method is unexpected. - fn process_message(&mut self, msg: Incoming); - - /// Indicated whether the store needs more messages - /// - /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No` - /// then you need to take the round output by calling `.take_output()`. - fn needs_more_messages(&self) -> NeedsMoreMessages; - - /// Tries to obtain round output - /// - /// Can be called once `process_message()` returned `NeedMoreMessages::No`. - /// - /// Returns: - /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output` - /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError` - /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`] - #[allow(clippy::type_complexity)] - fn take_output(&mut self) -> Result, Box>, TakeOutputError>; -} - -#[derive(Debug, thiserror::Error)] -enum TakeOutputError { - #[error("output is already taken")] - AlreadyTaken, - #[error("output is not ready yet, more messages are needed")] - NotReady, -} - -enum ProcessRoundMessageImpl> { - InProgress { store: S, _ph: PhantomType }, - Completed(Result>), - Gone, -} - -impl> ProcessRoundMessageImpl { - pub fn new(store: S) -> Self { - if store.wants_more() { - Self::InProgress { - store, - _ph: Default::default(), - } - } else { - Self::Completed( - store - .output() - .map_err(|_| errors::ImproperStoreImpl::StoreDidntOutput.into()), - ) - } - } -} - -impl ProcessRoundMessageImpl -where - S: MessagesStore, - M: ProtocolMessage + RoundMessage, -{ - fn _process_message( - store: &mut S, - msg: Incoming, - ) -> Result<(), CompleteRoundError> { - let msg = msg.try_map(M::from_protocol_message).map_err(|msg| { - errors::Bug::MessageFromAnotherRound { - actual_number: msg.round(), - expected_round: M::ROUND, - } - })?; - - store - .add_message(msg) - .map_err(CompleteRoundError::ProcessMessage)?; - Ok(()) - } -} - -impl ProcessRoundMessage for ProcessRoundMessageImpl -where - S: MessagesStore, - M: ProtocolMessage + RoundMessage, -{ - type Msg = M; - - fn process_message(&mut self, msg: Incoming) { - let store = match self { - Self::InProgress { store, .. } => store, - _ => { - return; - } - }; - - match Self::_process_message(store, msg) { - Ok(()) => { - if store.wants_more() { - return; - } - - let store = match mem::replace(self, Self::Gone) { - Self::InProgress { store, .. } => store, - _ => { - *self = Self::Completed(Err(errors::Bug::IncoherentState { - expected: "InProgress", - justification: - "we checked at beginning of the function that `state` is InProgress", - } - .into())); - return; - } - }; - - match store.output() { - Ok(output) => *self = Self::Completed(Ok(output)), - Err(_err) => { - *self = - Self::Completed(Err(errors::ImproperStoreImpl::StoreDidntOutput.into())) - } - } - } - Err(err) => { - *self = Self::Completed(Err(err)); - } - } - } - - fn needs_more_messages(&self) -> NeedsMoreMessages { - match self { - Self::InProgress { .. } => NeedsMoreMessages::Yes, - _ => NeedsMoreMessages::No, - } - } - - fn take_output(&mut self) -> Result, Box>, TakeOutputError> { - match self { - Self::InProgress { .. } => return Err(TakeOutputError::NotReady), - Self::Gone => return Err(TakeOutputError::AlreadyTaken), - _ => (), - } - match mem::replace(self, Self::Gone) { - Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))), - Self::Completed(Err(err)) => Ok(Err(Box::new(err))), - _ => unreachable!("it's checked to be completed"), - } - } -} - -enum NeedsMoreMessages { - Yes, - No, -} - -#[allow(dead_code)] -impl NeedsMoreMessages { - pub fn yes(&self) -> bool { - matches!(self, Self::Yes) - } - pub fn no(&self) -> bool { - matches!(self, Self::No) - } -} - -/// When something goes wrong -pub mod errors { - use super::TakeOutputError; - - /// Error indicating that `Rounds` failed to complete certain round - #[derive(Debug, thiserror::Error)] - pub enum CompleteRoundError { - /// [`MessagesStore`](super::MessagesStore) failed to process this message - #[error("failed to process the message")] - ProcessMessage(#[source] ProcessErr), - /// Receiving next message resulted into i/o error - #[error("receive next message")] - Io(#[from] IoError), - /// Some implementation specific error - /// - /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug - /// in `Rounds` implementation - #[error("implementation error")] - Other(#[source] OtherError), - } - - /// Error indicating that receiving next message resulted into i/o error - #[derive(Debug, thiserror::Error)] - pub enum IoError { - /// I/O error - #[error("i/o error")] - Io(#[source] E), - /// Encountered unexpected EOF - #[error("unexpected eof")] - UnexpectedEof, - } - - /// Some implementation specific error - /// - /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug - /// in `Rounds` implementation - #[derive(Debug, thiserror::Error)] - #[error(transparent)] - pub struct OtherError(OtherReason); - - #[derive(Debug, thiserror::Error)] - pub(super) enum OtherReason { - #[error("improper `MessagesStore` implementation")] - ImproperStoreImpl(#[source] ImproperStoreImpl), - #[error("`Rounds` API misuse")] - RoundsMisuse(#[source] RoundsMisuse), - #[error("bug in `Rounds` (please, open a issue)")] - Bug(#[source] Bug), - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum ImproperStoreImpl { - /// Store indicated that it received enough messages but didn't output - /// - /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`. - #[error("store didn't output")] - StoreDidntOutput, - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum RoundsMisuse { - #[error("round is already completed")] - RoundAlreadyCompleted, - #[error("round {n} is not registered")] - UnregisteredRound { n: u16 }, - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum Bug { - #[error( - "message originates from another round: we process messages from round \ - {expected_round}, got message from round {actual_number}" - )] - MessageFromAnotherRound { - expected_round: u16, - actual_number: u16, - }, - #[error("state is incoherent, it's expected to be {expected}: {justification}")] - IncoherentState { - expected: &'static str, - justification: &'static str, - }, - #[error("mismatched output type")] - MismatchedOutputType, - #[error("mismatched error type")] - MismatchedErrorType, - #[error("take round result")] - TakeRoundResult(#[source] TakeOutputError), - } - - impl CompleteRoundError { - pub(super) fn map_io_err(self, f: F) -> CompleteRoundError - where - F: FnOnce(IoErr) -> E, - { - match self { - CompleteRoundError::Io(err) => CompleteRoundError::Io(err.map_err(f)), - CompleteRoundError::ProcessMessage(err) => CompleteRoundError::ProcessMessage(err), - CompleteRoundError::Other(err) => CompleteRoundError::Other(err), - } - } - } - - impl IoError { - pub(super) fn map_err(self, f: F) -> IoError - where - F: FnOnce(E) -> B, - { - match self { - IoError::Io(e) => IoError::Io(f(e)), - IoError::UnexpectedEof => IoError::UnexpectedEof, - } - } - } - - macro_rules! impl_from_other_error { - ($($err:ident),+,) => {$( - impl From<$err> for CompleteRoundError { - fn from(err: $err) -> Self { - Self::Other(OtherError(OtherReason::$err(err))) - } - } - )+}; - } - - impl_from_other_error! { - ImproperStoreImpl, - RoundsMisuse, - Bug, - } -} - -#[cfg(test)] -mod tests { - struct Store; - - #[derive(crate::ProtocolMessage)] - #[protocol_message(root = crate)] - enum FakeProtocolMsg { - R1(Msg1), - } - struct Msg1; - - impl super::MessagesStore for Store { - type Msg = Msg1; - type Output = (); - type Error = core::convert::Infallible; - - fn add_message(&mut self, _msg: crate::Incoming) -> Result<(), Self::Error> { - Ok(()) - } - fn wants_more(&self) -> bool { - false - } - fn output(self) -> Result { - Ok(()) - } - } - - #[tokio::test] - async fn complete_round_that_expects_no_messages() { - let incomings = futures::stream::pending::< - Result, core::convert::Infallible>, - >(); - - let mut rounds = super::RoundsRouter::builder(); - let round1 = rounds.add_round(Store); - let mut rounds = rounds.listen(incomings); - - rounds.complete(round1).await.unwrap(); - } -} diff --git a/round-based/src/rounds_router/store.rs b/round-based/src/rounds_router/store.rs deleted file mode 100644 index a7758df..0000000 --- a/round-based/src/rounds_router/store.rs +++ /dev/null @@ -1,132 +0,0 @@ -use crate::Incoming; - -/// Stores messages received at particular round -/// -/// In MPC protocol, party at every round usually needs to receive up to `n` messages. `MessagesStore` -/// is a container that stores messages, it knows how many messages are expected to be received, -/// and should implement extra measures against malicious parties (e.g. prohibit message overwrite). -/// -/// ## Procedure -/// `MessagesStore` stores received messages. Once enough messages are received, it outputs [`MessagesStore::Output`]. -/// In order to save received messages, [`.add_message(msg)`] is called. Then, [`.wants_more()`] tells whether more -/// messages are needed to be received. If it returned `false`, then output can be retrieved by calling [`.output()`]. -/// -/// [`.add_message(msg)`]: Self::add_message -/// [`.wants_more()`]: Self::wants_more -/// [`.output()`]: Self::output -/// -/// ## Example -/// [`RoundInput`](super::simple_store::RoundInput) is an simple messages store. Refer to its docs to see usage examples. -pub trait MessagesStore: Sized + 'static { - /// Message type - type Msg; - /// Store output (e.g. `Vec<_>` of received messages) - type Output; - /// Store error - type Error: core::error::Error; - - /// Adds received message to the store - /// - /// Returns error if message cannot be processed. Usually it means that sender behaves maliciously. - fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error>; - /// Indicates if store expects more messages to receive - fn wants_more(&self) -> bool; - /// Retrieves store output if enough messages are received - /// - /// Returns `Err(self)` if more message are needed to be received. - /// - /// If store indicated that it needs no more messages (ie `store.wants_more() == false`), then - /// this function must return `Ok(_)`. - fn output(self) -> Result; -} - -/// Message of MPC protocol -/// -/// MPC protocols typically consist of several rounds, each round has differently typed message. -/// `ProtocolMessage` and [`RoundMessage`] traits are used to examine received message: `ProtocolMessage::round` -/// determines which round message belongs to, and then `RoundMessage` trait can be used to retrieve -/// actual round-specific message. -/// -/// You should derive these traits using proc macro (requires `derive` feature): -/// ```rust -/// use round_based::ProtocolMessage; -/// -/// #[derive(ProtocolMessage)] -/// pub enum Message { -/// Round1(Msg1), -/// Round2(Msg2), -/// // ... -/// } -/// -/// pub struct Msg1 { /* ... */ } -/// pub struct Msg2 { /* ... */ } -/// ``` -/// -/// This desugars into: -/// -/// ```rust -/// use round_based::rounds_router::{ProtocolMessage, RoundMessage}; -/// -/// pub enum Message { -/// Round1(Msg1), -/// Round2(Msg2), -/// // ... -/// } -/// -/// pub struct Msg1 { /* ... */ } -/// pub struct Msg2 { /* ... */ } -/// -/// impl ProtocolMessage for Message { -/// fn round(&self) -> u16 { -/// match self { -/// Message::Round1(_) => 1, -/// Message::Round2(_) => 2, -/// // ... -/// } -/// } -/// } -/// impl RoundMessage for Message { -/// const ROUND: u16 = 1; -/// fn to_protocol_message(round_message: Msg1) -> Self { -/// Message::Round1(round_message) -/// } -/// fn from_protocol_message(protocol_message: Self) -> Result { -/// match protocol_message { -/// Message::Round1(msg) => Ok(msg), -/// msg => Err(msg), -/// } -/// } -/// } -/// impl RoundMessage for Message { -/// const ROUND: u16 = 2; -/// fn to_protocol_message(round_message: Msg2) -> Self { -/// Message::Round2(round_message) -/// } -/// fn from_protocol_message(protocol_message: Self) -> Result { -/// match protocol_message { -/// Message::Round2(msg) => Ok(msg), -/// msg => Err(msg), -/// } -/// } -/// } -/// ``` -pub trait ProtocolMessage: Sized { - /// Number of round this message originates from - fn round(&self) -> u16; -} - -/// Round message -/// -/// See [`ProtocolMessage`] trait documentation. -pub trait RoundMessage: ProtocolMessage { - /// Number of the round this message belongs to - const ROUND: u16; - - /// Converts round message into protocol message (never fails) - fn to_protocol_message(round_message: M) -> Self; - /// Extracts round message from protocol message - /// - /// Returns `Err(protocol_message)` if `protocol_message.round() != Self::ROUND`, otherwise - /// returns `Ok(round_message)` - fn from_protocol_message(protocol_message: Self) -> Result; -} diff --git a/round-based/src/sim/async_env.rs b/round-based/src/sim/async_env.rs index 8bb43df..127862f 100644 --- a/round-based/src/sim/async_env.rs +++ b/round-based/src/sim/async_env.rs @@ -41,7 +41,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -75,7 +75,10 @@ use futures_util::{Sink, Stream}; use tokio::sync::broadcast; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use crate::delivery::{Delivery, Incoming, Outgoing}; +use crate::{ + delivery::{Incoming, Outgoing}, + ProtocolMsg, +}; use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex}; use super::SimResult; @@ -91,7 +94,7 @@ pub struct Network { impl Network where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, { /// Instantiates a new simulation pub fn new() -> Self { @@ -126,23 +129,23 @@ where let local_party_idx = self.next_party_idx; self.next_party_idx += 1; - MockedDelivery { - incoming: MockedIncoming { + MockedDelivery::new( + MockedIncoming { local_party_idx, receiver: BroadcastStream::new(self.channel.subscribe()), }, - outgoing: MockedOutgoing { + MockedOutgoing { local_party_idx, sender: self.channel.clone(), next_msg_id: self.next_msg_id.clone(), }, - } + ) } } impl Default for Network where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, { fn default() -> Self { Self::new() @@ -150,23 +153,17 @@ where } /// Mocked networking -pub struct MockedDelivery { - incoming: MockedIncoming, - outgoing: MockedOutgoing, -} +pub type MockedDelivery = crate::mpc::Halves, MockedOutgoing>; -impl Delivery for MockedDelivery -where - M: Clone + Send + Unpin + 'static, -{ - type Send = MockedOutgoing; - type Receive = MockedIncoming; - type SendError = broadcast::error::SendError<()>; - type ReceiveError = BroadcastStreamRecvError; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.incoming, self.outgoing) - } +/// Delivery error +#[derive(Debug, thiserror::Error)] +pub enum MockedDeliveryError { + /// Error occurred when sending a message + #[error(transparent)] + Recv(BroadcastStreamRecvError), + /// Error occurred when receiving a message + #[error(transparent)] + Send(broadcast::error::SendError<()>), } /// Incoming channel of mocked network @@ -179,13 +176,13 @@ impl Stream for MockedIncoming where M: Clone + Send + 'static, { - type Item = Result, BroadcastStreamRecvError>; + type Item = Result, MockedDeliveryError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) { Some(Ok(m)) => m, - Some(Err(e)) => return Poll::Ready(Some(Err(e))), + Some(Err(e)) => return Poll::Ready(Some(Err(MockedDeliveryError::Recv(e)))), None => return Poll::Ready(None), }; if msg.recipient.is_p2p() @@ -206,7 +203,7 @@ pub struct MockedOutgoing { } impl Sink> for MockedOutgoing { - type Error = broadcast::error::SendError<()>; + type Error = MockedDeliveryError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -224,7 +221,7 @@ impl Sink> for MockedOutgoing { msg_type, msg: m, })) - .map_err(|_| broadcast::error::SendError(()))?; + .map_err(|_| MockedDeliveryError::Send(broadcast::error::SendError(())))?; Ok(()) } @@ -268,7 +265,7 @@ impl NextMessageId { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -294,7 +291,7 @@ pub async fn run( party_start: impl FnMut(u16, MpcParty>) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity(DEFAULT_CAPACITY, n, party_start).await @@ -311,7 +308,7 @@ pub async fn run_with_capacity( mut party_start: impl FnMut(u16, MpcParty>) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity_and_setup( @@ -346,7 +343,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -372,7 +369,7 @@ pub async fn run_with_setup( party_start: impl FnMut(u16, MpcParty>, S) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity_and_setup::(DEFAULT_CAPACITY, setups, party_start).await @@ -389,7 +386,7 @@ pub async fn run_with_capacity_and_setup( mut party_start: impl FnMut(u16, MpcParty>, S) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { let mut network = Network::::with_capacity(capacity); diff --git a/round-based/src/sim/mod.rs b/round-based/src/sim/mod.rs index 56aa794..7900052 100644 --- a/round-based/src/sim/mod.rs +++ b/round-based/src/sim/mod.rs @@ -43,7 +43,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -67,7 +67,9 @@ use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec}; use core::future::Future; -use crate::{state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing}; +use crate::{ + state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing, ProtocolMsg, +}; #[cfg(feature = "sim-async")] pub mod async_env; @@ -229,7 +231,7 @@ enum Party<'a, O, M> { impl<'a, O, M> Simulation<'a, O, M> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, { /// Creates empty simulation containing no parties /// @@ -479,7 +481,7 @@ impl MessagesQueue { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -504,7 +506,7 @@ pub fn run( mut party_start: impl FnMut(u16, crate::state_machine::MpcParty) -> F, ) -> Result, SimError> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, F: Future, { run_with_setup(core::iter::repeat(()).take(n.into()), |i, party, ()| { @@ -534,7 +536,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -559,7 +561,7 @@ pub fn run_with_setup( mut party_start: impl FnMut(u16, crate::state_machine::MpcParty, S) -> F, ) -> Result, SimError> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, F: Future, { let mut sim = Simulation::empty(); diff --git a/round-based/src/state_machine/delivery.rs b/round-based/src/state_machine/delivery.rs index 34b5936..50826ef 100644 --- a/round-based/src/state_machine/delivery.rs +++ b/round-based/src/state_machine/delivery.rs @@ -1,18 +1,18 @@ use core::task::{ready, Poll}; -/// Stream of incoming messages -pub struct Incomings { +/// Provides a stream of incoming and sink for outgoing messages +pub struct Delivery { shared_state: super::shared_state::SharedStateRef, } -impl Incomings { +impl Delivery { pub(super) fn new(shared_state: super::shared_state::SharedStateRef) -> Self { Self { shared_state } } } -impl crate::Stream for Incomings { - type Item = Result, core::convert::Infallible>; +impl futures_util::Stream for Delivery { + type Item = Result, DeliveryErr>; fn poll_next( self: core::pin::Pin<&mut Self>, @@ -26,19 +26,8 @@ impl crate::Stream for Incomings { } } -/// Sink for outgoing messages -pub struct Outgoings { - shared_state: super::shared_state::SharedStateRef, -} - -impl Outgoings { - pub(super) fn new(shared_state: super::shared_state::SharedStateRef) -> Self { - Self { shared_state } - } -} - -impl crate::Sink> for Outgoings { - type Error = SendErr; +impl crate::Sink> for Delivery { + type Error = DeliveryErr; fn poll_ready( self: core::pin::Pin<&mut Self>, @@ -54,7 +43,7 @@ impl crate::Sink> for Outgoings { ) -> Result<(), Self::Error> { self.shared_state .protocol_saves_msg_to_be_sent(msg) - .map_err(|_| SendErr(SendErrReason::NotReady)) + .map_err(|_| DeliveryErr(Reason::NotReady)) } fn poll_flush( @@ -76,10 +65,10 @@ impl crate::Sink> for Outgoings { /// Error returned by [`Outgoings`] sink #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct SendErr(SendErrReason); +pub struct DeliveryErr(Reason); #[derive(Debug, thiserror::Error)] -enum SendErrReason { +enum Reason { #[error("sink is not ready")] NotReady, } diff --git a/round-based/src/state_machine/mod.rs b/round-based/src/state_machine/mod.rs index d674122..ff32c74 100644 --- a/round-based/src/state_machine/mod.rs +++ b/round-based/src/state_machine/mod.rs @@ -20,7 +20,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -67,8 +67,10 @@ mod shared_state; use core::{future::Future, task::Poll}; +use crate::ProtocolMsg; + pub use self::{ - delivery::{Incomings, Outgoings, SendErr}, + delivery::{Delivery, DeliveryErr}, runtime::{Runtime, YieldNow}, }; @@ -226,9 +228,6 @@ where } } -/// Delivery implementation used in the state machine -pub type Delivery = (Incomings, Outgoings); - /// MpcParty instantiated with state machine implementation of delivery and async runtime pub type MpcParty = crate::MpcParty, Runtime>; @@ -245,15 +244,13 @@ pub fn wrap_protocol<'a, M, F>( ) -> impl StateMachine + 'a where F: Future + 'a, - M: 'static, + M: ProtocolMsg + 'static, { let shared_state = shared_state::SharedStateRef::new(); - let incomings = Incomings::new(shared_state.clone()); - let outgoings = Outgoings::new(shared_state.clone()); - let delivery = (incomings, outgoings); + let delivery = Delivery::new(shared_state.clone()); let runtime = Runtime::new(shared_state.clone()); - let future = protocol(crate::MpcParty::connected(delivery).set_runtime(runtime)); + let future = protocol(crate::mpc::connected(delivery).with_runtime(runtime)); let future = alloc::boxed::Box::pin(future); StateMachineImpl { diff --git a/round-based/src/state_machine/runtime.rs b/round-based/src/state_machine/runtime.rs index 005796c..81229bd 100644 --- a/round-based/src/state_machine/runtime.rs +++ b/round-based/src/state_machine/runtime.rs @@ -11,14 +11,13 @@ impl Runtime { } } -impl crate::runtime::AsyncRuntime for Runtime { - type YieldNowFuture = YieldNow; - - fn yield_now(&self) -> Self::YieldNowFuture { +impl crate::mpc::runtime::AsyncRuntime for Runtime { + async fn yield_now(&self) { YieldNow { shared_state: self.shared_state.clone(), yielded: false, } + .await } } diff --git a/round-based/tests/derive/compile-fail/wrong_usage.rs b/round-based/tests/derive/compile-fail/wrong_usage.rs index d081241..93cc030 100644 --- a/round-based/tests/derive/compile-fail/wrong_usage.rs +++ b/round-based/tests/derive/compile-fail/wrong_usage.rs @@ -1,9 +1,9 @@ -use round_based::ProtocolMessage; +use round_based::ProtocolMsg; -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] enum Msg { // Unnamed variant with single field is the only correct enum variant - // that doesn't contradicts with ProtocolMessage derivation + // that doesn't contradicts with ProtocolMsg derivation VariantA(u16), // Error: You can't have named variants VariantB { n: u32 }, @@ -15,20 +15,20 @@ enum Msg { VariantE, } -// Structure cannot implement ProtocolMessage -#[derive(ProtocolMessage)] +// Structure cannot implement ProtocolMsg +#[derive(ProtocolMsg)] struct Msg2 { some_field: u64, } -// Union cannot implement ProtocolMessage -#[derive(ProtocolMessage)] +// Union cannot implement ProtocolMsg +#[derive(ProtocolMsg)] union Msg3 { variant: u64, } // protocol_message is repeated twice -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] #[protocol_message(root = one)] #[protocol_message(root = two)] enum Msg4 { @@ -37,7 +37,7 @@ enum Msg4 { } // ", blah blah" is not permitted input -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] #[protocol_message(root = one, blah blah)] enum Msg5 { One(u32), @@ -45,7 +45,7 @@ enum Msg5 { } // `protocol_message` must not be empty -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] #[protocol_message()] enum Msg6 { One(u32), diff --git a/round-based/tests/derive/compile-pass/correct_usage.rs b/round-based/tests/derive/compile-pass/correct_usage.rs index e4be823..16e3c16 100644 --- a/round-based/tests/derive/compile-pass/correct_usage.rs +++ b/round-based/tests/derive/compile-pass/correct_usage.rs @@ -1,13 +1,13 @@ -use round_based::ProtocolMessage; +use round_based::ProtocolMsg; -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] enum Msg { VariantA(u16), VariantB(String), VariantC((u16, String)), VariantD(MyStruct), } -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] #[protocol_message(root = round_based)] enum Msg2 { VariantA(u16), From e2b27b5fb5be4d57e051de72ccdd9aef228427b6 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 5 May 2025 17:01:49 +0200 Subject: [PATCH 02/29] clippy fix Signed-off-by: Denis Varlakov --- round-based/src/delivery.rs | 40 ++---------------------- round-based/src/lib.rs | 2 +- round-based/src/mpc/rounds_router/mod.rs | 1 + round-based/src/sim/async_env.rs | 2 +- round-based/src/sim/mod.rs | 2 +- 5 files changed, 6 insertions(+), 41 deletions(-) diff --git a/round-based/src/delivery.rs b/round-based/src/delivery.rs index 25550d6..f5a562c 100644 --- a/round-based/src/delivery.rs +++ b/round-based/src/delivery.rs @@ -1,39 +1,3 @@ -use futures_util::{Sink, Stream}; - -/// Networking abstraction -/// -/// Basically, it's pair of channels: [`Stream`] for receiving messages, and [`Sink`] for sending -/// messages to other parties. -pub trait Delivery { - /// Outgoing delivery channel - type Send: Sink, Error = Self::SendError> + Unpin; - /// Incoming delivery channel - type Receive: Stream, Self::ReceiveError>> + Unpin; - /// Error of outgoing delivery channel - type SendError: core::error::Error + Send + Sync + 'static; - /// Error of incoming delivery channel - type ReceiveError: core::error::Error + Send + Sync + 'static; - /// Returns a pair of incoming and outgoing delivery channels - fn split(self) -> (Self::Receive, Self::Send); -} - -impl Delivery for (I, O) -where - I: Stream, IErr>> + Unpin, - O: Sink, Error = OErr> + Unpin, - IErr: core::error::Error + Send + Sync + 'static, - OErr: core::error::Error + Send + Sync + 'static, -{ - type Send = O; - type Receive = I; - type SendError = OErr; - type ReceiveError = IErr; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.0, self.1) - } -} - /// Incoming message #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct Incoming { @@ -106,7 +70,7 @@ impl Incoming { /// Checks whether it's broadcast message pub fn is_broadcast(&self) -> bool { - matches!(self.msg_type, MessageType::Broadcast { .. }) + matches!(self.msg_type, MessageType::Broadcast) } /// Checks whether it's p2p message @@ -187,6 +151,6 @@ impl MessageDestination { } /// Returns `true` if it's broadcast message pub fn is_broadcast(&self) -> bool { - matches!(self, MessageDestination::AllParties { .. }) + matches!(self, MessageDestination::AllParties) } } diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index 4cb649d..be04ee4 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -39,7 +39,7 @@ //! * `sim-async` enables protocol execution simulation with tokio runtime, see [`sim::async_env`] //! module //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync -//! API, see [`state_machine`] module +//! API, see [`state_machine`] module //! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro //! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](runtime) //! diff --git a/round-based/src/mpc/rounds_router/mod.rs b/round-based/src/mpc/rounds_router/mod.rs index 09b5a1b..5b310c7 100644 --- a/round-based/src/mpc/rounds_router/mod.rs +++ b/round-based/src/mpc/rounds_router/mod.rs @@ -122,6 +122,7 @@ where Ok(()) } + #[allow(clippy::type_complexity)] pub fn complete_round( &mut self, round: Round, diff --git a/round-based/src/sim/async_env.rs b/round-based/src/sim/async_env.rs index 127862f..851a442 100644 --- a/round-based/src/sim/async_env.rs +++ b/round-based/src/sim/async_env.rs @@ -313,7 +313,7 @@ where { run_with_capacity_and_setup( capacity, - core::iter::repeat(()).take(n.into()), + core::iter::repeat_n((), n.into()), |i, party, ()| party_start(i, party), ) .await diff --git a/round-based/src/sim/mod.rs b/round-based/src/sim/mod.rs index 7900052..ba1bd97 100644 --- a/round-based/src/sim/mod.rs +++ b/round-based/src/sim/mod.rs @@ -509,7 +509,7 @@ where M: ProtocolMsg + Clone + 'static, F: Future, { - run_with_setup(core::iter::repeat(()).take(n.into()), |i, party, ()| { + run_with_setup(core::iter::repeat_n((), n.into()), |i, party, ()| { party_start(i, party) }) } From 4e21e398790f4c3b497561deec0c134eb8c89b92 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 6 May 2025 11:56:44 +0200 Subject: [PATCH 03/29] Fix docs Signed-off-by: Denis Varlakov --- round-based/src/_docs.rs | 26 ++++--------- round-based/src/mpc/mod.rs | 25 +++++++----- round-based/src/mpc/rounds_router/mod.rs | 48 +----------------------- round-based/src/round/simple_store.rs | 12 ++++-- round-based/src/sim/async_env.rs | 15 +++++--- round-based/src/sim/mod.rs | 15 +++++--- round-based/src/state_machine/mod.rs | 8 ++-- 7 files changed, 55 insertions(+), 94 deletions(-) diff --git a/round-based/src/_docs.rs b/round-based/src/_docs.rs index 82c5a16..dc57265 100644 --- a/round-based/src/_docs.rs +++ b/round-based/src/_docs.rs @@ -1,21 +1,11 @@ -use core::convert::Infallible; +use futures_util::{Sink, SinkExt, Stream}; -use phantom_type::PhantomType; +use crate::{Incoming, Outgoing}; -use crate::{Delivery, Incoming, Outgoing}; - -pub fn fake_delivery() -> impl Delivery { - struct FakeDelivery(PhantomType); - impl Delivery for FakeDelivery { - type Send = futures_util::sink::Drain>; - type Receive = futures_util::stream::Pending, Infallible>>; - - type SendError = Infallible; - type ReceiveError = Infallible; - - fn split(self) -> (Self::Receive, Self::Send) { - (futures_util::stream::pending(), futures_util::sink::drain()) - } - } - FakeDelivery(PhantomType::new()) +pub fn fake_delivery( +) -> impl Stream, E>> + Sink, Error = E> + Unpin { + crate::mpc::Halves::new( + futures_util::stream::pending::, E>>(), + futures_util::sink::drain().sink_map_err(|e| match e {}), + ) } diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index a481c92..ae42452 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -3,27 +3,32 @@ //! [`MpcParty`] is party of MPC protocol, connected to network, ready to start carrying out the protocol. //! //! ```rust -//! use round_based::{Mpc, MpcParty, Delivery, PartyIndex}; +//! use round_based::{Incoming, Outgoing}; //! -//! # struct KeygenMsg; +//! # #[derive(round_based::ProtocolMsg)] +//! # enum KeygenMsg {} //! # struct KeyShare; //! # struct Error; //! # type Result = std::result::Result; //! # async fn doc() -> Result<()> { -//! async fn keygen(party: M, i: PartyIndex, n: u16) -> Result +//! async fn keygen(party: M, i: u16, n: u16) -> Result //! where -//! M: Mpc +//! M: round_based::Mpc //! { //! // ... //! # unimplemented!() //! } -//! async fn connect() -> impl Delivery { +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { //! // ... //! # round_based::_docs::fake_delivery() //! } //! //! let delivery = connect().await; -//! let party = MpcParty::connected(delivery); +//! let party = round_based::mpc::connected(delivery); //! //! # let (i, n) = (1, 3); //! let keyshare = keygen(party, i, n).await?; @@ -142,7 +147,7 @@ pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRo /// This desugars into: /// /// ```rust -/// use round_based::rounds_router::{ProtocolMsg, RoundMessage}; +/// use round_based::{ProtocolMsg, RoundMsg}; /// /// pub enum Message { /// Round1(Msg1), @@ -162,7 +167,7 @@ pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRo /// } /// } /// } -/// impl RoundMessage for Message { +/// impl RoundMsg for Message { /// const ROUND: u16 = 1; /// fn to_protocol_msg(round_msg: Msg1) -> Self { /// Message::Round1(round_msg) @@ -174,10 +179,10 @@ pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRo /// } /// } /// } -/// impl RoundMessage for Message { +/// impl RoundMsg for Message { /// const ROUND: u16 = 2; /// fn to_protocol_msg(round_msg: Msg2) -> Self { -/// Message::Round2(round_message) +/// Message::Round2(round_msg) /// } /// fn from_protocol_msg(protocol_msg: Self) -> Result { /// match protocol_msg { diff --git a/round-based/src/mpc/rounds_router/mod.rs b/round-based/src/mpc/rounds_router/mod.rs index 5b310c7..bd65054 100644 --- a/round-based/src/mpc/rounds_router/mod.rs +++ b/round-based/src/mpc/rounds_router/mod.rs @@ -1,51 +1,7 @@ //! Routes incoming MPC messages between rounds //! -//! [`RoundsRouter`] is an essential building block of MPC protocol, it processes incoming messages, groups -//! them by rounds, and provides convenient API for retrieving received messages at certain round. -//! -//! ## Example -//! -//! ```rust -//! use round_based::{Mpc, MpcParty, ProtocolMsg, Delivery, PartyIndex}; -//! use round_based::rounds_router::{RoundsRouter, simple_store::{RoundInput, RoundMsgs}}; -//! -//! #[derive(ProtocolMsg)] -//! pub enum Msg { -//! Round1(Msg1), -//! Round2(Msg2), -//! } -//! -//! pub struct Msg1 { /* ... */ } -//! pub struct Msg2 { /* ... */ } -//! -//! pub async fn some_mpc_protocol(party: M, i: PartyIndex, n: u16) -> Result -//! where -//! M: Mpc, -//! { -//! let MpcParty{ delivery, .. } = party.into_party(); -//! -//! let (incomings, _outgoings) = delivery.split(); -//! -//! // Build `Rounds` -//! let mut rounds = RoundsRouter::builder(); -//! let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); -//! let round2 = rounds.add_round(RoundInput::::p2p(i, n)); -//! let mut rounds = rounds.listen(incomings); -//! -//! // Receive messages from round 1 -//! let msgs: RoundMsgs = rounds.complete(round1).await?; -//! -//! // ... process received messages -//! -//! // Receive messages from round 2 -//! let msgs = rounds.complete(round2).await?; -//! -//! // ... -//! # todo!() -//! } -//! # type Output = (); -//! # type Error = Box; -//! ``` +//! Router is a building block, used in MpcParty to register rounds and route +//! incoming messages between them use alloc::{boxed::Box, collections::BTreeMap}; use core::{any::Any, mem}; diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index 7a3c4a9..01d7911 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -16,8 +16,9 @@ use super::RoundStore; /// /// ## Example /// ```rust -/// # use round_based::rounds_router::{MessagesStore, simple_store::RoundInput}; -/// # use round_based::{Incoming, MessageType}; +/// use round_based::{Incoming, MessageType}; +/// use round_based::round::{RoundStore, RoundInput}; +/// /// # fn main() -> Result<(), Box> { /// let mut input = RoundInput::<&'static str>::broadcast(1, 3); /// input.add_message(Incoming{ @@ -35,9 +36,12 @@ use super::RoundStore; /// assert!(!input.wants_more()); /// /// let output = input.output().unwrap(); -/// assert_eq!(output.clone().into_vec_without_me(), ["first party message", "third party message"]); /// assert_eq!( -/// output.clone().into_vec_including_me("my msg"), +/// output.clone().into_vec_without_me(), +/// ["first party message", "third party message"] +/// ); +/// assert_eq!( +/// output.into_vec_including_me("my msg"), /// ["first party message", "my msg", "third party message"] /// ); /// # Ok(()) } diff --git a/round-based/src/sim/async_env.rs b/round-based/src/sim/async_env.rs index 851a442..df4ef43 100644 --- a/round-based/src/sim/async_env.rs +++ b/round-based/src/sim/async_env.rs @@ -33,7 +33,8 @@ //! //! # type Result = std::result::Result; //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol you want to test //! pub async fn protocol_of_random_generation( //! party: M, @@ -41,7 +42,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -257,7 +258,8 @@ impl NextMessageId { /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// party: M, @@ -265,7 +267,7 @@ impl NextMessageId { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -334,7 +336,8 @@ where /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// rng: impl rand::RngCore, @@ -343,7 +346,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() diff --git a/round-based/src/sim/mod.rs b/round-based/src/sim/mod.rs index ba1bd97..92b845b 100644 --- a/round-based/src/sim/mod.rs +++ b/round-based/src/sim/mod.rs @@ -35,7 +35,8 @@ //! //! # type Result = std::result::Result; //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol you want to test //! pub async fn protocol_of_random_generation( //! party: M, @@ -43,7 +44,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -473,7 +474,8 @@ impl MessagesQueue { /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// party: M, @@ -481,7 +483,7 @@ impl MessagesQueue { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -527,7 +529,8 @@ where /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// rng: impl rand::RngCore, @@ -536,7 +539,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() diff --git a/round-based/src/state_machine/mod.rs b/round-based/src/state_machine/mod.rs index ff32c74..dded810 100644 --- a/round-based/src/state_machine/mod.rs +++ b/round-based/src/state_machine/mod.rs @@ -8,19 +8,19 @@ //! ## Example //! ```rust,no_run //! # fn main() -> anyhow::Result<()> { -//! use round_based::{Mpc, PartyIndex}; //! use anyhow::{Result, Error, Context as _}; //! //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol //! pub async fn protocol_of_random_generation( //! party: M, -//! i: PartyIndex, +//! i: u16, //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: round_based::Mpc //! { //! // ... //! # todo!() From f601e77ff2d3b849c6bbf8a2ee1283ac6666ba02 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 6 May 2025 12:02:44 +0200 Subject: [PATCH 04/29] Fix docs Signed-off-by: Denis Varlakov --- round-based/src/lib.rs | 6 +++--- round-based/src/mpc/mod.rs | 2 +- round-based/src/mpc/rounds_router/mod.rs | 2 +- round-based/src/mpc/runtime.rs | 4 ++-- round-based/src/round/mod.rs | 2 +- round-based/src/state_machine/delivery.rs | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index be04ee4..809707a 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -18,7 +18,7 @@ //! ## Networking //! //! In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to -//! implement [`Delivery`] trait which is basically a stream and a sink for receiving and sending messages. +//! provide a channel which implements a stream and a sink for receiving and sending messages. //! //! Message delivery should meet certain criterias that differ from protocol to protocol (refer to //! the documentation of the protocol you're using), but usually they are: @@ -41,7 +41,7 @@ //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync //! API, see [`state_machine`] module //! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro -//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](runtime) +//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](mpc::runtime) //! //! ## Join us in Discord! //! Feel free to reach out to us [in Discord](https://discordapp.com/channels/905194001349627914/1285268686147424388)! @@ -83,7 +83,7 @@ pub use self::mpc::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; #[doc(hidden)] pub mod _docs; -/// Derives [`ProtocolMsg`] and [`RoundMessage`] traits +/// Derives [`ProtocolMsg`] and [`RoundMsg`] traits /// /// See [`ProtocolMsg`] docs for more details #[cfg(feature = "derive")] diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index ae42452..8986423 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -56,7 +56,7 @@ pub trait Mpc { /// Protocol message type Msg; - /// Returned in [`Self::complete`] + /// Returned in [`Self::finish`] type Exec: MpcExecution; /// Error indicating that sending a message has failed type SendErr; diff --git a/round-based/src/mpc/rounds_router/mod.rs b/round-based/src/mpc/rounds_router/mod.rs index bd65054..4dcdf41 100644 --- a/round-based/src/mpc/rounds_router/mod.rs +++ b/round-based/src/mpc/rounds_router/mod.rs @@ -351,7 +351,7 @@ pub mod errors { /// Router error /// - /// Refer to [`CompleteRound::Router`] docs + /// Refer to [`CompleteRoundError::Router`] docs #[derive(Debug, thiserror::Error)] #[error(transparent)] pub struct RouterError(Reason); diff --git a/round-based/src/mpc/runtime.rs b/round-based/src/mpc/runtime.rs index 1852186..3331419 100644 --- a/round-based/src/mpc/runtime.rs +++ b/round-based/src/mpc/runtime.rs @@ -41,7 +41,7 @@ pub type DefaultRuntime = TokioRuntime; pub type DefaultRuntime = UnknownRuntime; /// Unknown async runtime -pub mod unknown_runtime { +mod unknown_runtime { /// Unknown async runtime /// /// Tries to implement runtime features using generic futures code. It's better to use @@ -56,7 +56,7 @@ pub mod unknown_runtime { } /// Future for the `yield_now` function. - pub struct YieldNow(bool); + struct YieldNow(bool); impl core::future::Future for YieldNow { type Output = (); diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index a4e8d93..71b7915 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -22,7 +22,7 @@ mod simple_store; /// [`.output()`]: Self::output /// /// ## Example -/// [`RoundInput`](super::simple_store::RoundInput) is an simple messages store. Refer to its docs to see usage examples. +/// [`RoundInput`] is an simple messages store. Refer to its docs to see usage examples. pub trait RoundStore: Sized + 'static { /// Message type type Msg; diff --git a/round-based/src/state_machine/delivery.rs b/round-based/src/state_machine/delivery.rs index 50826ef..ef92fcb 100644 --- a/round-based/src/state_machine/delivery.rs +++ b/round-based/src/state_machine/delivery.rs @@ -62,7 +62,7 @@ impl crate::Sink> for Delivery { } } -/// Error returned by [`Outgoings`] sink +/// Error returned by [`Delivery`] #[derive(Debug, thiserror::Error)] #[error(transparent)] pub struct DeliveryErr(Reason); From 502085de1b6fba2d387f3400057a10510c358abc Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 6 May 2025 12:32:48 +0200 Subject: [PATCH 05/29] Restructure the lib a bit Signed-off-by: Denis Varlakov --- round-based-tests/tests/rounds.rs | 31 ++--- round-based/src/lib.rs | 2 +- round-based/src/mpc/mod.rs | 17 +-- .../src/mpc/{party.rs => party/mod.rs} | 107 +++++++++++++----- .../{rounds_router/mod.rs => party/router.rs} | 49 ++------ round-based/src/mpc/{ => party}/runtime.rs | 0 round-based/src/mpc/rounds_router/store.rs | 1 - round-based/src/state_machine/runtime.rs | 2 +- 8 files changed, 107 insertions(+), 102 deletions(-) rename round-based/src/mpc/{party.rs => party/mod.rs} (60%) rename round-based/src/mpc/{rounds_router/mod.rs => party/router.rs} (86%) rename round-based/src/mpc/{ => party}/runtime.rs (100%) delete mode 100644 round-based/src/mpc/rounds_router/store.rs diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/rounds.rs index 2fe81e3..84cc5de 100644 --- a/round-based-tests/tests/rounds.rs +++ b/round-based-tests/tests/rounds.rs @@ -8,10 +8,7 @@ use rand_chacha::rand_core::SeedableRng; use random_generation_protocol::{ protocol_of_random_generation, CommitMsg, DecommitMsg, Error, Msg, }; -use round_based::{ - mpc::errors::{CompleteRoundError, WithIo}, - Incoming, MessageType, -}; +use round_based::{mpc::party::CompleteRoundError, Incoming, MessageType}; const PARTY0_SEED: [u8; 32] = hex!("6772d079d5c984b3936a291e36b0d3dc6c474e36ed4afdfc973ef79a431ca870"); @@ -94,9 +91,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round1Receive(WithIo::Other( - CompleteRoundError::ProcessMsg(_) - ))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) ) } @@ -140,9 +135,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round2Receive(WithIo::Other( - CompleteRoundError::ProcessMsg(_) - ))) + Err(Error::Round2Receive(CompleteRoundError::ProcessMsg(_))) ) } @@ -160,9 +153,7 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() assert_matches!( output, - Err(Error::Round1Receive(WithIo::Other( - CompleteRoundError::ProcessMsg(_) - ))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) ) } @@ -298,7 +289,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { ]) .await; - assert_matches!(output, Err(Error::Round2Receive(WithIo::Io(_)))); + assert_matches!(output, Err(Error::Round2Receive(CompleteRoundError::Io(_)))); } #[tokio::test] @@ -340,7 +331,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { ]) .await; - assert_matches!(output, Err(Error::Round1Receive(WithIo::Io(_)))); + assert_matches!(output, Err(Error::Round1Receive(CompleteRoundError::Io(_)))); } #[tokio::test] @@ -373,7 +364,10 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { ]) .await; - assert_matches!(output, Err(Error::Round2Receive(WithIo::UnexpectedEof))); + assert_matches!( + output, + Err(Error::Round2Receive(CompleteRoundError::UnexpectedEof)) + ); } async fn run_protocol( @@ -381,10 +375,7 @@ async fn run_protocol( ) -> Result< [u8; 32], random_generation_protocol::Error< - round_based::mpc::errors::WithIo< - E, - round_based::mpc::errors::CompleteRoundError, - >, + round_based::mpc::party::CompleteRoundError, E, >, > diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index 809707a..f23752b 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -41,7 +41,7 @@ //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync //! API, see [`state_machine`] module //! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro -//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](mpc::runtime) +//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](mpc::party::runtime) //! //! ## Join us in Discord! //! Feel free to reach out to us [in Discord](https://discordapp.com/channels/905194001349627914/1285268686147424388)! diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index 8986423..db9f1a7 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -37,19 +37,10 @@ use crate::{round::RoundStore, Outgoing, PartyIndex}; -mod party; -mod rounds_router; -pub mod runtime; - -pub use self::{ - party::{Halves, MpcParty}, - rounds_router::Round, -}; - -/// When something goes wrong -pub mod errors { - pub use super::{party::WithIo, rounds_router::errors::*}; -} +pub mod party; + +#[doc(no_inline)] +pub use self::party::{Halves, MpcParty}; /// Abstracts functionalities needed for MPC protocol execution pub trait Mpc { diff --git a/round-based/src/mpc/party.rs b/round-based/src/mpc/party/mod.rs similarity index 60% rename from round-based/src/mpc/party.rs rename to round-based/src/mpc/party/mod.rs index 20b1a3d..cdcf35a 100644 --- a/round-based/src/mpc/party.rs +++ b/round-based/src/mpc/party/mod.rs @@ -1,8 +1,17 @@ +//! Provides [`MpcParty`], default engine for MPC protocol execution that implements [`Mpc`] and [`MpcExecution`] traits + use futures_util::{Sink, SinkExt, Stream, StreamExt}; use crate::{round::RoundStore, Incoming, Outgoing}; -use super::{rounds_router, runtime, Mpc, MpcExecution, ProtocolMsg, RoundMsg}; +use super::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; + +mod router; +pub mod runtime; + +pub use self::router::{errors::RouterError, Round}; +#[doc(no_inline)] +pub use self::runtime::AsyncRuntime; /// MPC engine, carries out the protocol /// @@ -13,7 +22,7 @@ use super::{rounds_router, runtime, Mpc, MpcExecution, ProtocolMsg, RoundMsg}; /// /// Implements [`Mpc`] and [`MpcExecution`]. pub struct MpcParty { - router: rounds_router::RoundsRouter, + router: router::RoundsRouter, io: D, runtime: R, } @@ -27,7 +36,7 @@ where /// Constructs [`MpcParty`] pub fn connected(delivery: D) -> Self { Self { - router: rounds_router::RoundsRouter::new(), + router: router::RoundsRouter::new(), io: delivery, runtime: runtime::DefaultRuntime::default(), } @@ -94,11 +103,11 @@ where D: Sink, Error = IoErr> + Unpin, AsyncR: runtime::AsyncRuntime, { - type Round = rounds_router::Round; + type Round = router::Round; type Msg = M; - type CompleteRoundErr = WithIo>; + type CompleteRoundErr = CompleteRoundError; type SendErr = IoErr; @@ -112,7 +121,7 @@ where { // Check if round is already completed round = match self.router.complete_round(round) { - Ok(output) => return output.map_err(WithIo::Other), + Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})), Err(w) => w, }; @@ -122,15 +131,13 @@ where .io .next() .await - .ok_or(WithIo::UnexpectedEof)? - .map_err(WithIo::Io)?; - self.router - .received_msg(incoming) - .map_err(|err| WithIo::Other(err.into()))?; + .ok_or(CompleteRoundError::UnexpectedEof)? + .map_err(CompleteRoundError::Io)?; + self.router.received_msg(incoming)?; // Check if round was just completed round = match self.router.complete_round(round) { - Ok(output) => return output.map_err(WithIo::Other), + Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})), Err(w) => w, }; } @@ -145,20 +152,6 @@ where } } -/// Error indicating that either `IoErr` occurred, or `OtherErr` -#[derive(Debug, thiserror::Error)] -pub enum WithIo { - /// IO error - #[error(transparent)] - Io(IoErr), - /// Unexpected EOF - #[error("unexpected eof")] - UnexpectedEof, - /// Other error - #[error(transparent)] - Other(OtherErr), -} - pin_project_lite::pin_project! { /// Merges a stream and a sink into one structure that implements both [`Stream`] and [`Sink`] pub struct Halves { @@ -231,3 +224,65 @@ where this.outgoings.poll_close(cx) } } + +/// Error returned by [`MpcParty::complete`] +/// +/// May indicate malicious behavior (e.g. adversary sent a message that aborts protocol execution) +/// or some misconfiguration of the protocol network (e.g. received a message from the round that +/// was not registered via [`Mpc::add_round`]). +#[derive(Debug, thiserror::Error)] +pub enum CompleteRoundError { + /// [`RoundStore`] returned an error + /// + /// Refer to this rounds store documentation to understand why it could fail + #[error(transparent)] + ProcessMsg(ProcessErr), + + /// Router error + /// + /// Indicates that for some reason router was not able to process a message. This can be the case of: + /// - Router API misuse \ + /// E.g. when received a message from the round that was not registered in the router + /// - Improper [`RoundStore`] implementation \ + /// Indicates that round store is not properly implemented and contains a flaw. \ + /// For instance, this error is returned when round store indicates that it doesn't need + /// any more messages ([`RoundStore::wants_more`] + /// returns `false`), but then it didn't output anything ([`RoundStore::output`] + /// returns `Err(_)`) + /// - Bug in the router + /// + /// This error is always related to some implementation flaw or bug: either in the code that uses + /// the router, or in the round store implementation, or in the router itself. When implementation + /// is correct, this error never appears. Thus, it should not be possible for the adversary to "make + /// this error happen." + Router(router::errors::RouterError), + + /// Receiving the next message resulted into I/O error + Io(IoErr), + /// Channel of incoming messages was closed before protocol completion + UnexpectedEof, +} + +impl CompleteRoundError { + /// Maps I/O error + pub fn map_io_err(self, f: impl FnOnce(IoErr) -> E) -> CompleteRoundError { + match self { + CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(e), + CompleteRoundError::Router(e) => CompleteRoundError::Router(e), + CompleteRoundError::Io(e) => CompleteRoundError::Io(f(e)), + CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof, + } + } + /// Maps [`CompleteRoundError::ProcessMsg`] + pub fn map_process_err( + self, + f: impl FnOnce(ProcessErr) -> E, + ) -> CompleteRoundError { + match self { + CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(f(e)), + CompleteRoundError::Router(e) => CompleteRoundError::Router(e), + CompleteRoundError::Io(e) => CompleteRoundError::Io(e), + CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof, + } + } +} diff --git a/round-based/src/mpc/rounds_router/mod.rs b/round-based/src/mpc/party/router.rs similarity index 86% rename from round-based/src/mpc/rounds_router/mod.rs rename to round-based/src/mpc/party/router.rs index 4dcdf41..e2a85c6 100644 --- a/round-based/src/mpc/rounds_router/mod.rs +++ b/round-based/src/mpc/party/router.rs @@ -4,7 +4,7 @@ //! incoming messages between them use alloc::{boxed::Box, collections::BTreeMap}; -use core::{any::Any, mem}; +use core::{any::Any, convert::Infallible, mem}; use phantom_type::PhantomType; use tracing::{error, trace_span, warn}; @@ -82,7 +82,7 @@ where pub fn complete_round( &mut self, round: Round, - ) -> Result>, Round> + ) -> Result>, Round> where R: RoundStore, M: RoundMsg, @@ -110,7 +110,7 @@ where fn retrieve_round_output( round: &mut Box>, - ) -> Result> + ) -> Result> where R: RoundStore, M: RoundMsg, @@ -120,7 +120,7 @@ where .downcast::() .or(Err(errors::Bug::MismatchedOutputType))?), Ok(Err(any)) => Err(*any - .downcast::>() + .downcast::>() .or(Err(errors::Bug::MismatchedErrorType))?), Err(err) => Err(errors::Bug::TakeRoundResult(err).into()), } @@ -177,7 +177,7 @@ enum TakeOutputError { enum ProcessRoundMessageImpl> { InProgress { store: S, _ph: PhantomType }, - Completed(Result>), + Completed(Result>), Gone, } @@ -206,7 +206,7 @@ where fn _process_message( store: &mut S, msg: Incoming, - ) -> Result<(), errors::CompleteRoundError> { + ) -> Result<(), errors::CompleteRoundError> { let msg = msg.try_map(M::from_protocol_msg).map_err(|msg| { errors::Bug::MessageFromAnotherRound { actual_number: msg.round(), @@ -307,6 +307,8 @@ impl NeedsMoreMessages { /// When something goes wrong pub mod errors { + pub use crate::mpc::party::CompleteRoundError; + use super::TakeOutputError; #[derive(Debug, thiserror::Error)] @@ -316,39 +318,6 @@ pub mod errors { pub(super) witness_provided: bool, } - /// Error returned when processing incoming messages at certain round - /// - /// May indicate malicious behavior (e.g. adversary sent a message that aborts protocol execution) - /// or some misconfiguration of the protocol network (e.g. received a message from the round that - /// was not registered via [`Mpc::add_round`](crate::Mpc::add_round)). - #[derive(Debug, thiserror::Error)] - pub enum CompleteRoundError { - /// [`RoundStore`](crate::round::RoundStore) returned an error - /// - /// Refer to this rounds store documentation to understand why it could fail - #[error(transparent)] - ProcessMsg(ProcessErr), - - /// Router error - /// - /// Indicates that for some reason router was not able to process a message. This can be the case of: - /// - Router API misuse \ - /// E.g. when received a message from the round that was not registered in the router - /// - Improper [`RoundStore`](crate::round::RoundStore) implementation \ - /// Indicates that round store is not properly implemented and contains a flaw. \ - /// For instance, this error is returned when round store indicates that it doesn't need - /// any more messages ([`RoundStore::wants_more`](crate::round::RoundStore::wants_more) - /// returns `false`), but then it didn't output anything ([`RoundStore::output`](crate::round::RoundStore::output) - /// returns `Err(_)`) - /// - Bug in the router - /// - /// This error is always related to some implementation flaw or bug: either in the code that uses - /// the router, or in the round store implementation, or in the router itself. When implementation - /// is correct, this error never appears. Thus, it should not be possible for the adversary to "make - /// this error happen." - Router(RouterError), - } - /// Router error /// /// Refer to [`CompleteRoundError::Router`] docs @@ -421,7 +390,7 @@ pub mod errors { macro_rules! impl_round_complete_from { ($(|$err:ident: $err_ty:ty| $err_fn:expr),+$(,)?) => {$( - impl From<$err_ty> for CompleteRoundError { + impl From<$err_ty> for CompleteRoundError { fn from($err: $err_ty) -> Self { $err_fn } diff --git a/round-based/src/mpc/runtime.rs b/round-based/src/mpc/party/runtime.rs similarity index 100% rename from round-based/src/mpc/runtime.rs rename to round-based/src/mpc/party/runtime.rs diff --git a/round-based/src/mpc/rounds_router/store.rs b/round-based/src/mpc/rounds_router/store.rs deleted file mode 100644 index 8b13789..0000000 --- a/round-based/src/mpc/rounds_router/store.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/round-based/src/state_machine/runtime.rs b/round-based/src/state_machine/runtime.rs index 81229bd..6d64577 100644 --- a/round-based/src/state_machine/runtime.rs +++ b/round-based/src/state_machine/runtime.rs @@ -11,7 +11,7 @@ impl Runtime { } } -impl crate::mpc::runtime::AsyncRuntime for Runtime { +impl crate::mpc::party::AsyncRuntime for Runtime { async fn yield_now(&self) { YieldNow { shared_state: self.shared_state.clone(), From 6855664068b4d36cccd583c8bac3ec02fb90fde0 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 6 May 2025 15:18:19 +0200 Subject: [PATCH 06/29] Update example Signed-off-by: Denis Varlakov --- examples/random-generation-protocol/src/lib.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index d1649e2..e2ae374 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -52,7 +52,7 @@ pub async fn protocol_of_random_generation( i: u16, n: u16, mut rng: R, -) -> Result<[u8; 32], Error, M::SendErr>> +) -> Result<[u8; 32], ErrorM> where M: Mpc, R: rand_core::RngCore, @@ -140,9 +140,11 @@ pub enum Error { }, } -/// Error indicating that receiving message at certain round failed -pub type CompleteRoundErr = - round_based::mpc::CompleteRoundErr; +/// Error type deduced from `M: Mpc` +pub type ErrorM = Error< + round_based::mpc::CompleteRoundErr, + ::SendErr, +>; /// Blames a party in cheating during the protocol #[derive(Debug)] From 81a7e516bd2c7774d1cad45a89ea91c03bb18a2a Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 6 May 2025 16:29:03 +0200 Subject: [PATCH 07/29] Make messages aware of reliable broadcast Signed-off-by: Denis Varlakov --- .../random-generation-protocol/src/lib.rs | 15 +-- round-based-tests/tests/rounds.rs | 80 +++++++++------- round-based/src/delivery.rs | 42 +++++++-- round-based/src/mpc/mod.rs | 23 ++++- round-based/src/round/mod.rs | 92 +++++++++++++++++++ round-based/src/round/simple_store.rs | 81 ++++++++++------ round-based/src/sim/async_env.rs | 2 +- round-based/src/sim/mod.rs | 4 +- round-based/src/state_machine/shared_state.rs | 8 +- 9 files changed, 266 insertions(+), 81 deletions(-) diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index e2ae374..55b1995 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -69,8 +69,9 @@ where rng.fill_bytes(&mut local_randomness); // 2. Commit local randomness (broadcast m=sha256(randomness)) + // This message must be reliably broadcasted to guarantee protocol security let commitment = Sha256::digest(local_randomness); - mpc.send_broadcast(Msg::CommitMsg(CommitMsg { commitment })) + mpc.reliably_broadcast(Msg::CommitMsg(CommitMsg { commitment })) .await .map_err(Error::Round1Send)?; @@ -78,7 +79,9 @@ where let commitments = mpc.complete(round1).await.map_err(Error::Round1Receive)?; // 4. Open local randomness - mpc.send_broadcast(Msg::DecommitMsg(DecommitMsg { + // This message will be sent to all other participants, but it doesn't require + // a reliable broadcast + mpc.send_to_all(Msg::DecommitMsg(DecommitMsg { randomness: local_randomness, })) .await @@ -242,7 +245,7 @@ mod tests { .received_msg(Incoming { id: 0, sender: 1, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: party1_com, }), @@ -256,7 +259,7 @@ mod tests { .received_msg(Incoming { id: 1, sender: 2, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: party2_com, }), @@ -289,7 +292,7 @@ mod tests { .received_msg(Incoming { id: 3, sender: 1, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: party1_rng, }), @@ -303,7 +306,7 @@ mod tests { .received_msg(Incoming { id: 3, sender: 2, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: party2_rng, }), diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/rounds.rs index 84cc5de..9bc1087 100644 --- a/round-based-tests/tests/rounds.rs +++ b/round-based-tests/tests/rounds.rs @@ -31,7 +31,7 @@ async fn random_generation_completes() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -39,7 +39,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -47,7 +47,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -55,7 +55,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -67,13 +67,31 @@ async fn random_generation_completes() { assert_eq!(output, PROTOCOL_OUTPUT); } +#[tokio::test] +async fn protocol_terminates_with_error_if_party_broadcasts_msg_unreliably_at_round1() { + let output = run_protocol([Ok::<_, Infallible>(Incoming { + id: 0, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + }), + })]) + .await; + + assert_matches!( + output, + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) + ) +} + #[tokio::test] async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_round1() { let output = run_protocol([ Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -81,7 +99,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 1, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY_OVERWRITES.into(), }), @@ -101,7 +119,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -109,7 +127,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -117,7 +135,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -125,7 +143,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 3, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY_OVERWRITES, }), @@ -144,7 +162,7 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() let output = run_protocol([Ok::<_, Infallible>(Incoming { id: 0, sender: 3, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -163,7 +181,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -171,7 +189,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -179,7 +197,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY_OVERWRITES.into(), }), @@ -187,7 +205,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 3, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -195,7 +213,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 4, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -213,7 +231,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -221,7 +239,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -229,7 +247,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -237,7 +255,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -256,7 +274,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -264,7 +282,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -272,7 +290,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -281,7 +299,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -299,7 +317,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -307,7 +325,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -315,7 +333,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -323,7 +341,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -340,7 +358,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -348,7 +366,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -356,7 +374,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), diff --git a/round-based/src/delivery.rs b/round-based/src/delivery.rs index f5a562c..ea9c856 100644 --- a/round-based/src/delivery.rs +++ b/round-based/src/delivery.rs @@ -16,7 +16,11 @@ pub struct Incoming { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MessageType { /// Message was broadcasted - Broadcast, + Broadcast { + /// Indicates that message was reliably broadcasted, meaning that it's guaranteed (cryptographically or through + /// other trust assumptions) that all honest participants of the protocol received the same message + reliable: bool, + }, /// P2P message P2P, } @@ -68,9 +72,14 @@ impl Incoming { } } - /// Checks whether it's broadcast message + /// Checks whether it's broadcast message (regardless if it's reliable or not) pub fn is_broadcast(&self) -> bool { - matches!(self.msg_type, MessageType::Broadcast) + matches!(self.msg_type, MessageType::Broadcast { .. }) + } + + /// Checks if message was reliably broadcasted + pub fn is_reliably_broadcasted(&self) -> bool { + matches!(self.msg_type, MessageType::Broadcast { reliable: true }) } /// Checks whether it's p2p message @@ -90,9 +99,17 @@ pub struct Outgoing { impl Outgoing { /// Constructs an outgoing message addressed to all parties - pub fn broadcast(msg: M) -> Self { + pub fn all_parties(msg: M) -> Self { Self { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: false }, + msg, + } + } + + /// Constructs an outgoing message addressed to all parties via reliable broadcast channel + pub fn reliable_broadcast(msg: M) -> Self { + Self { + recipient: MessageDestination::AllParties { reliable: true }, msg, } } @@ -139,7 +156,12 @@ impl Outgoing { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MessageDestination { /// Broadcast message - AllParties, + AllParties { + /// Indicates that message needs to be reliably broadcasted, meaning that when recipient receives this message, + /// it must be assured (cryptographically or through other trust assumptions) that all honest participants of the + /// protocol received the same message + reliable: bool, + }, /// P2P message OneParty(PartyIndex), } @@ -149,8 +171,12 @@ impl MessageDestination { pub fn is_p2p(&self) -> bool { matches!(self, MessageDestination::OneParty(_)) } - /// Returns `true` if it's broadcast message + /// Returns `true` if it's broadcast message (regardless if it's reliable or not) pub fn is_broadcast(&self) -> bool { - matches!(self, MessageDestination::AllParties) + matches!(self, MessageDestination::AllParties { .. }) + } + /// Returns `true` if it's reliable broadcast message + pub fn is_reliable_broadcast(&self) -> bool { + matches!(self, MessageDestination::AllParties { reliable: true }) } } diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index db9f1a7..014212b 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -101,9 +101,26 @@ pub trait MpcExecution { self.send(Outgoing::p2p(recipient, msg)).await } - /// Sends a broadcast message - async fn send_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { - self.send(Outgoing::broadcast(msg)).await + /// Sends a message that will be received by all parties + /// + /// Message will be broadcasted, but not reliably. If you need a reliable broadcast, use + /// [`reliably_broadcast`] method. + async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::all_parties(msg)).await + } + + /// Reliably broadcasts a message + /// + /// Message will be received by all participants of the protocol. Moreover, when recipient receives a + /// message, it will be assured (cryptographically or through other trust assumptions) that all honest + /// participants of the protocol received the same message. + /// + /// It's a responsibility of a message delivery layer to provide the reliable broadcast mechanism. If + /// it's not supported, this method returns an error. Note that not every MPC protocol requires the + /// reliable broadcast, so it's totally normal to have a message delivery implementation that does + /// not support it. + async fn reliably_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::reliable_broadcast(msg)).await } /// Yields execution diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index 71b7915..52ff937 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -1,5 +1,7 @@ //! Primitives that process and collect messages received at certain round +use core::any::Any; + use crate::Incoming; pub use self::simple_store::{broadcast, p2p, RoundInput, RoundInputError, RoundMsgs}; @@ -44,4 +46,94 @@ pub trait RoundStore: Sized + 'static { /// If store indicated that it needs no more messages (ie `store.wants_more() == false`), then /// this function must return `Ok(_)`. fn output(self) -> Result; + + /// Interface that exposes ability to retrieve generic information about the round store + /// + /// For reading store properties, it's recommended to use [`RoundStoreExt::read_prop`] method which + /// uses this function internally. + /// + /// When implementing `RoundStore` trait, if you wish to expose no extra information, leave the default + /// implementation of this method. If you do wish to expose certain properties that will be accessible + /// through [`RoundStoreExt::read_prop`], follow this example: + /// + /// ```rust + /// todo!() + /// ``` + fn read_any_prop(&self, property: &mut dyn Any) { + let _ = property; + } +} + +/// Extra functionalities defined for any [`RoundStore`] +pub trait RoundStoreExt: RoundStore { + /// Reads a property `P` of the store + /// + /// Returns `Some(property_value)` if this store exposes property `P`, otherwise returns `None` + fn read_prop(&self) -> Option

; + + /// Constructs a new store that exposes property `P` with provided value + /// + /// If store already provides a property `P`, it will be overwritten + fn set_prop(self, value: P) -> WithProp; +} + +impl RoundStoreExt for S { + fn read_prop(&self) -> Option

{ + let mut p: Option

= None; + self.read_any_prop(&mut p); + p + } + + fn set_prop(self, value: P) -> WithProp { + WithProp { + prop: value, + store: self, + } + } +} + +/// Returned by [`RoundStoreExt::set_prop`] +pub struct WithProp { + prop: P, + store: S, +} + +impl RoundStore for WithProp +where + S: RoundStore, + P: Clone + 'static, +{ + type Msg = S::Msg; + type Output = S::Output; + type Error = S::Error; + + #[inline(always)] + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.store.add_message(msg) + } + #[inline(always)] + fn wants_more(&self) -> bool { + self.store.wants_more() + } + #[inline(always)] + fn output(self) -> Result { + self.store.output().map_err(|store| Self { + prop: self.prop, + store, + }) + } + + fn read_any_prop(&self, property: &mut dyn Any) { + if let Some(p) = property.downcast_mut::>() { + *p = Some(self.prop.clone()) + } else { + self.store.read_any_prop(property); + } + } +} + +/// Properties that may be exposed by [`RoundStore`] +pub mod props { + /// Indicates whether the round requires messages to be reliably broadcasted + pub struct RequiresReliableBroadcast(pub bool); } diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index 01d7911..0608187 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -89,9 +89,16 @@ impl RoundInput { /// Construct a new store for broadcast messages /// - /// The same as `RoundInput::new(i, n, MessageType::Broadcast)` + /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: false })` pub fn broadcast(i: PartyIndex, n: u16) -> Self { - Self::new(i, n, MessageType::Broadcast) + Self::new(i, n, MessageType::Broadcast { reliable: false }) + } + + /// Construct a new store for reliable broadcast messages + /// + /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: true })` + pub fn reliable_broadcast(i: PartyIndex, n: u16) -> Self { + Self::new(i, n, MessageType::Broadcast { reliable: true }) } /// Construct a new store for p2p messages @@ -101,8 +108,17 @@ impl RoundInput { Self::new(i, n, MessageType::P2P) } - fn is_expected_type_of_msg(&self, msg_type: MessageType) -> bool { - self.expected_msg_type == msg_type + fn is_expected_type_of_msg(&self, actual_msg_type: MessageType) -> bool { + // self.expected_msg_type == actual_msg_type + match (self.expected_msg_type, actual_msg_type) { + (MessageType::P2P, MessageType::P2P) + | (MessageType::Broadcast { reliable: false }, MessageType::Broadcast { .. }) + | ( + MessageType::Broadcast { reliable: true }, + MessageType::Broadcast { reliable: true }, + ) => true, + _ => false, + } } } @@ -431,11 +447,31 @@ mod tests { #[test] fn store_returns_error_if_message_type_mismatched() { let mut store = RoundInput::::p2p(3, 5); + for reliable in [true, false] { + let err = store + .add_message(Incoming { + id: 0, + sender: 0, + msg_type: MessageType::Broadcast { reliable }, + msg: Msg(1), + }) + .unwrap_err(); + assert_matches!( + err, + RoundInputError::MismatchedMessageType { + msg_id: 0, + expected: MessageType::P2P, + actual: MessageType::Broadcast { reliable: r } + } if r == reliable + ); + } + + let mut store = RoundInput::::broadcast(3, 5); let err = store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::Broadcast, + msg_type: MessageType::P2P, msg: Msg(1), }) .unwrap_err(); @@ -443,12 +479,12 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::P2P, - actual: MessageType::Broadcast + expected: MessageType::Broadcast { reliable: false }, + actual: MessageType::P2P, } ); - let mut store = RoundInput::::broadcast(3, 5); + let mut store = RoundInput::::reliable_broadcast(3, 5); let err = store .add_message(Incoming { id: 0, @@ -461,27 +497,15 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::Broadcast, + expected: MessageType::Broadcast { reliable: true }, actual: MessageType::P2P, } ); - for sender in 0u16..5 { - store - .add_message(Incoming { - id: 0, - sender, - msg_type: MessageType::Broadcast, - msg: Msg(1), - }) - .unwrap(); - } - - let mut store = RoundInput::::broadcast(3, 5); let err = store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::P2P, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg(1), }) .unwrap_err(); @@ -489,15 +513,20 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::Broadcast, - actual, - } if actual == MessageType::P2P + expected: MessageType::Broadcast { reliable: true }, + actual: MessageType::Broadcast { reliable: false }, + } ); + } + + #[test] + fn non_reliable_broadcast_round_accepts_reliable_broadcast_messages() { + let mut store = RoundInput::::broadcast(3, 5); store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg(1), }) .unwrap(); diff --git a/round-based/src/sim/async_env.rs b/round-based/src/sim/async_env.rs index df4ef43..727d093 100644 --- a/round-based/src/sim/async_env.rs +++ b/round-based/src/sim/async_env.rs @@ -212,7 +212,7 @@ impl Sink> for MockedOutgoing { fn start_send(self: Pin<&mut Self>, msg: Outgoing) -> Result<(), Self::Error> { let msg_type = match msg.recipient { - MessageDestination::AllParties => MessageType::Broadcast, + MessageDestination::AllParties { reliable } => MessageType::Broadcast { reliable }, MessageDestination::OneParty(_) => MessageType::P2P, }; self.sender diff --git a/round-based/src/sim/mod.rs b/round-based/src/sim/mod.rs index 92b845b..e1974f5 100644 --- a/round-based/src/sim/mod.rs +++ b/round-based/src/sim/mod.rs @@ -418,7 +418,7 @@ impl MessagesQueue { fn send_message(&mut self, sender: u16, msg: Outgoing) -> Result<(), SimError> { match msg.recipient { - MessageDestination::AllParties => { + MessageDestination::AllParties { reliable } => { let mut msg_ids = self.next_id..; for (destination, msg_id) in (0..) .zip(&mut self.queue) @@ -429,7 +429,7 @@ impl MessagesQueue { destination.push_back(Incoming { id: msg_id, sender, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable }, msg: msg.msg.clone(), }) } diff --git a/round-based/src/state_machine/shared_state.rs b/round-based/src/state_machine/shared_state.rs index cd8ddf8..86bc622 100644 --- a/round-based/src/state_machine/shared_state.rs +++ b/round-based/src/state_machine/shared_state.rs @@ -29,7 +29,7 @@ impl SharedStateRef { ))) } - /// Any protocol-initated work (like flushing message to be sent, receiving message, etc.) can + /// Any protocol-initiated work (like flushing message to be sent, receiving message, etc.) can /// only be scheduled when there was no other task scheduled. /// /// This method checks whether a task can be scheduled, and returns [`CanSchedule`] which @@ -185,7 +185,7 @@ mod test { let executor_state = shared_state; let msg = Outgoing { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: false }, msg: 1, }; outgoings_state @@ -235,7 +235,7 @@ mod test { let incoming_msg = Incoming { id: 0, sender: 1, - msg_type: crate::MessageType::Broadcast, + msg_type: crate::MessageType::Broadcast { reliable: false }, msg: "hello", }; executor_state.executor_received_msg(incoming_msg).unwrap(); @@ -293,7 +293,7 @@ mod test { let shared_state = SharedStateRef::new(); shared_state .protocol_saves_msg_to_be_sent(Outgoing { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: false }, msg: 1, }) .expect("msg slot isn't empty"); From 823bb3d535c386021bdc11fe32212f1a5e5aa630 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Wed, 7 May 2025 11:16:09 +0200 Subject: [PATCH 08/29] fix readme & clippy Signed-off-by: Denis Varlakov --- README.md | 4 ++-- round-based/src/round/simple_store.rs | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index bec638e..94500f0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ multiparty protocols (e.g. threshold signing, random beacons, etc.). ## Networking In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to -implement `Delivery` trait which is basically a stream and a sink for receiving and sending messages. +provide a channel which implements a stream and a sink for receiving and sending messages. Message delivery should meet certain criterias that differ from protocol to protocol (refer to the documentation of the protocol you're using), but usually they are: @@ -39,7 +39,7 @@ the documentation of the protocol you're using), but usually they are: * `sim-async` enables protocol execution simulation with tokio runtime, see `sim::async_env` module * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync - API, see `state_machine` module + API, see `state_machine` module * `derive` is needed to use `ProtocolMsg` proc macro * `runtime-tokio` enables tokio-specific implementation of async runtime diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index 0608187..9b62c90 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -109,16 +109,18 @@ impl RoundInput { } fn is_expected_type_of_msg(&self, actual_msg_type: MessageType) -> bool { - // self.expected_msg_type == actual_msg_type - match (self.expected_msg_type, actual_msg_type) { + matches!( + (self.expected_msg_type, actual_msg_type), (MessageType::P2P, MessageType::P2P) - | (MessageType::Broadcast { reliable: false }, MessageType::Broadcast { .. }) - | ( - MessageType::Broadcast { reliable: true }, - MessageType::Broadcast { reliable: true }, - ) => true, - _ => false, - } + | ( + MessageType::Broadcast { reliable: false }, + MessageType::Broadcast { .. } + ) + | ( + MessageType::Broadcast { reliable: true }, + MessageType::Broadcast { reliable: true }, + ) + ) } } From 22a6312b6dc1ab946a687aa15defad5f0617637d Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Wed, 7 May 2025 11:43:59 +0200 Subject: [PATCH 09/29] Update tests & docs Signed-off-by: Denis Varlakov --- .../random-generation-protocol/src/lib.rs | 2 +- round-based-tests/tests/rounds.rs | 42 +++++++++++++++--- round-based/src/mpc/mod.rs | 2 +- round-based/src/round/mod.rs | 44 +++++++++++++++++-- round-based/src/round/simple_store.rs | 14 ++++-- .../tests/derive/compile-fail/wrong_usage.rs | 10 ++--- .../derive/compile-fail/wrong_usage.stderr | 30 ++++++------- .../derive/compile-pass/correct_usage.rs | 2 +- 8 files changed, 109 insertions(+), 37 deletions(-) diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index 55b1995..a73077c 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -58,7 +58,7 @@ where R: rand_core::RngCore, { // Define rounds - let round1 = mpc.add_round(round_based::round::broadcast::(i, n)); + let round1 = mpc.add_round(round_based::round::reliable_broadcast::(i, n)); let round2 = mpc.add_round(round_based::round::broadcast::(i, n)); let mut mpc = mpc.finish(); diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/rounds.rs index 9bc1087..c2790f0 100644 --- a/round-based-tests/tests/rounds.rs +++ b/round-based-tests/tests/rounds.rs @@ -71,7 +71,7 @@ async fn random_generation_completes() { async fn protocol_terminates_with_error_if_party_broadcasts_msg_unreliably_at_round1() { let output = run_protocol([Ok::<_, Infallible>(Incoming { id: 0, - sender: 1, + sender: 2, msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), @@ -81,7 +81,13 @@ async fn protocol_terminates_with_error_if_party_broadcasts_msg_unreliably_at_ro assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::MismatchedMessageType { + msg_id: 0, + expected: round_based::MessageType::Broadcast { reliable: true }, + actual: round_based::MessageType::Broadcast { reliable: false } + } + ))) ) } @@ -109,7 +115,12 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::AttemptToOverwriteReceivedMsg { + msgs_ids: [0, 1], + sender: 1 + } + ))) ) } @@ -153,7 +164,12 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round2Receive(CompleteRoundError::ProcessMsg(_))) + Err(Error::Round2Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::AttemptToOverwriteReceivedMsg { + msgs_ids: [2, 3], + sender: 1 + } + ))) ) } @@ -171,7 +187,13 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMsg(_))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::SenderIndexOutOfRange { + msg_id: 0, + sender: 3, + n: 3 + } + ))) ) } @@ -307,7 +329,10 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { ]) .await; - assert_matches!(output, Err(Error::Round2Receive(CompleteRoundError::Io(_)))); + assert_matches!( + output, + Err(Error::Round2Receive(CompleteRoundError::Io(DummyError))) + ); } #[tokio::test] @@ -349,7 +374,10 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { ]) .await; - assert_matches!(output, Err(Error::Round1Receive(CompleteRoundError::Io(_)))); + assert_matches!( + output, + Err(Error::Round1Receive(CompleteRoundError::Io(DummyError))) + ); } #[tokio::test] diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index 014212b..ce9e2be 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -104,7 +104,7 @@ pub trait MpcExecution { /// Sends a message that will be received by all parties /// /// Message will be broadcasted, but not reliably. If you need a reliable broadcast, use - /// [`reliably_broadcast`] method. + /// [`MpcExecution::reliably_broadcast`] method. async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { self.send(Outgoing::all_parties(msg)).await } diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index 52ff937..8a9b313 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -4,7 +4,9 @@ use core::any::Any; use crate::Incoming; -pub use self::simple_store::{broadcast, p2p, RoundInput, RoundInputError, RoundMsgs}; +pub use self::simple_store::{ + broadcast, p2p, reliable_broadcast, RoundInput, RoundInputError, RoundMsgs, +}; mod simple_store; @@ -53,11 +55,47 @@ pub trait RoundStore: Sized + 'static { /// uses this function internally. /// /// When implementing `RoundStore` trait, if you wish to expose no extra information, leave the default - /// implementation of this method. If you do wish to expose certain properties that will be accessible + /// implementation of this method. If you do want to expose certain properties that will be accessible /// through [`RoundStoreExt::read_prop`], follow this example: /// /// ```rust - /// todo!() + /// pub struct MyStore { /* ... */ } + /// + /// #[derive(Debug, PartialEq, Eq)] + /// pub struct SomePropertyWeWantToExpose { value: u64 } + /// #[derive(Debug, PartialEq, Eq)] + /// pub struct AnotherProperty(String); + /// + /// # type Msg = (); + /// impl round_based::round::RoundStore for MyStore { + /// # type Msg = Msg; + /// # type Output = Vec; + /// # type Error = core::convert::Infallible; + /// # fn add_message(&mut self, msg: round_based::Incoming) -> Result<(), Self::Error> { unimplemented!() } + /// # fn wants_more(&self) -> bool { unimplemented!() } + /// # fn output(self) -> Result { unimplemented!() } + /// // ... + /// + /// fn read_any_prop(&self, property: &mut dyn core::any::Any) { + /// if let Some(p) = property.downcast_mut::>() { + /// *p = Some(SomePropertyWeWantToExpose { value: 42 }) + /// } else if let Some(p) = property.downcast_mut::>() { + /// *p = Some(AnotherProperty("here we return a string".to_owned())) + /// } + /// } + /// } + /// + /// // Which then can be accessed via `.read_prop()` method: + /// use round_based::round::RoundStoreExt; + /// let store = MyStore { /* ... */ }; + /// assert_eq!( + /// store.read_prop::(), + /// Some(SomePropertyWeWantToExpose { value: 42 }), + /// ); + /// assert_eq!( + /// store.read_prop::(), + /// Some(AnotherProperty("here we return a string".to_owned())), + /// ); /// ``` fn read_any_prop(&self, property: &mut dyn Any) { let _ = property; diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index 9b62c90..0c4153b 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -24,13 +24,13 @@ use super::RoundStore; /// input.add_message(Incoming{ /// id: 0, /// sender: 0, -/// msg_type: MessageType::Broadcast, +/// msg_type: MessageType::Broadcast { reliable: false }, /// msg: "first party message", /// })?; /// input.add_message(Incoming{ /// id: 1, /// sender: 2, -/// msg_type: MessageType::Broadcast, +/// msg_type: MessageType::Broadcast { reliable: false }, /// msg: "third party message", /// })?; /// assert!(!input.wants_more()); @@ -314,18 +314,24 @@ pub enum RoundInputError { }, } -/// Round messages store for p2p round +/// p2p round /// /// Alias to [`RoundInput::p2p`] pub fn p2p(i: u16, n: u16) -> RoundInput { RoundInput::p2p(i, n) } -/// Round messages store for broadcast round +/// Broadcast round /// /// Alias to [`RoundInput::broadcast`] pub fn broadcast(i: u16, n: u16) -> RoundInput { RoundInput::broadcast(i, n) } +/// Reliable broadcast round +/// +/// Alias to [`RoundInput::broadcast`] +pub fn reliable_broadcast(i: u16, n: u16) -> RoundInput { + RoundInput::reliable_broadcast(i, n) +} #[cfg(test)] mod tests { diff --git a/round-based/tests/derive/compile-fail/wrong_usage.rs b/round-based/tests/derive/compile-fail/wrong_usage.rs index 93cc030..fb35bcc 100644 --- a/round-based/tests/derive/compile-fail/wrong_usage.rs +++ b/round-based/tests/derive/compile-fail/wrong_usage.rs @@ -29,8 +29,8 @@ union Msg3 { // protocol_message is repeated twice #[derive(ProtocolMsg)] -#[protocol_message(root = one)] -#[protocol_message(root = two)] +#[protocol_msg(root = one)] +#[protocol_msg(root = two)] enum Msg4 { One(u32), Two(u16), @@ -38,15 +38,15 @@ enum Msg4 { // ", blah blah" is not permitted input #[derive(ProtocolMsg)] -#[protocol_message(root = one, blah blah)] +#[protocol_msg(root = one, blah blah)] enum Msg5 { One(u32), Two(u16), } -// `protocol_message` must not be empty +// `protocol_msh` must not be empty #[derive(ProtocolMsg)] -#[protocol_message()] +#[protocol_msg()] enum Msg6 { One(u32), Two(u16), diff --git a/round-based/tests/derive/compile-fail/wrong_usage.stderr b/round-based/tests/derive/compile-fail/wrong_usage.stderr index f3d8341..335630c 100644 --- a/round-based/tests/derive/compile-fail/wrong_usage.stderr +++ b/round-based/tests/derive/compile-fail/wrong_usage.stderr @@ -1,53 +1,53 @@ -error: named variants are not allowed in ProtocolMessage +error: named variants are not allowed in ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:9:5 | 9 | VariantB { n: u32 }, | ^^^^^^^^ -error: this variant must contain exactly one field to be valid ProtocolMessage +error: this variant must contain exactly one field to be valid ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:11:5 | 11 | VariantC(u32, String), | ^^^^^^^^ -error: this variant must contain exactly one field to be valid ProtocolMessage +error: this variant must contain exactly one field to be valid ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:13:5 | 13 | VariantD(), | ^^^^^^^^ -error: unit variants are not allowed in ProtocolMessage +error: unit variants are not allowed in ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:15:5 | 15 | VariantE, | ^^^^^^^^ -error: only enum may implement ProtocolMessage +error: only enum may implement ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:20:1 | 20 | struct Msg2 { | ^^^^^^ -error: only enum may implement ProtocolMessage +error: only enum may implement ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:26:1 | 26 | union Msg3 { | ^^^^^ -error: #[protocol_message] attribute appears more than once +error: #[protocol_msg] attribute appears more than once --> tests/derive/compile-fail/wrong_usage.rs:33:3 | -33 | #[protocol_message(root = two)] - | ^^^^^^^^^^^^^^^^ +33 | #[protocol_msg(root = two)] + | ^^^^^^^^^^^^ error: unexpected token - --> tests/derive/compile-fail/wrong_usage.rs:41:30 + --> tests/derive/compile-fail/wrong_usage.rs:41:26 | -41 | #[protocol_message(root = one, blah blah)] - | ^ +41 | #[protocol_msg(root = one, blah blah)] + | ^ error: unexpected end of input, expected `root` - --> tests/derive/compile-fail/wrong_usage.rs:49:20 + --> tests/derive/compile-fail/wrong_usage.rs:49:16 | -49 | #[protocol_message()] - | ^ +49 | #[protocol_msg()] + | ^ diff --git a/round-based/tests/derive/compile-pass/correct_usage.rs b/round-based/tests/derive/compile-pass/correct_usage.rs index 16e3c16..21ee289 100644 --- a/round-based/tests/derive/compile-pass/correct_usage.rs +++ b/round-based/tests/derive/compile-pass/correct_usage.rs @@ -8,7 +8,7 @@ enum Msg { VariantD(MyStruct), } #[derive(ProtocolMsg)] -#[protocol_message(root = round_based)] +#[protocol_msg(root = round_based)] enum Msg2 { VariantA(u16), VariantB(String), From dc3b186eb791b3a8d9a37faab476c11b23e6a0d3 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Wed, 7 May 2025 11:48:02 +0200 Subject: [PATCH 10/29] Expose RequiresReliableBroadcast prop in simple store Signed-off-by: Denis Varlakov --- round-based/src/round/simple_store.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index 0c4153b..cbe478d 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -185,6 +185,17 @@ where }) } } + + fn read_any_prop(&self, property: &mut dyn core::any::Any) { + if let Some(p) = + property.downcast_mut::>() + { + *p = Some(crate::round::props::RequiresReliableBroadcast(matches!( + self.expected_msg_type, + MessageType::Broadcast { reliable: true } + ))); + } + } } impl RoundMsgs { From 9e568192c93f186d093b47bc32f204bd6645ddb1 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Wed, 7 May 2025 12:36:03 +0200 Subject: [PATCH 11/29] Add send_many Signed-off-by: Denis Varlakov --- round-based/src/mpc/mod.rs | 74 ++++++++++++++++++++++++++++++++ round-based/src/mpc/party/mod.rs | 34 +++++++++++++-- 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index ce9e2be..dddb42f 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -80,6 +80,9 @@ pub trait MpcExecution { /// Error indicating that sending a message has failed type SendErr; + /// Returned by [`.send_many()`](Self::send_many) + type SendMany: SendMany; + /// Completes the round async fn complete( &mut self, @@ -90,9 +93,15 @@ pub trait MpcExecution { Self::Msg: RoundMsg; /// Sends a message + /// + /// This method awaits until the message is sent, which might be not the best method to use if you + /// need to send many messages at once. If it's the case, prefer using [`.send_many()`](Self::send_many). async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; /// Sends a p2p message to another party + /// + /// Note: when you send many messages at once (it's most likely the case when you send a p2p message), this method + /// is not efficient, prefer using [`.send_many()`](Self::send_many). async fn send_p2p( &mut self, recipient: PartyIndex, @@ -123,10 +132,75 @@ pub trait MpcExecution { self.send(Outgoing::reliable_broadcast(msg)).await } + /// Creates a buffer of outgoing messages so they can be sent all at once + /// + /// When you have many messages that you want to send at once, using [`.send()`](Self::send) is not efficient, + /// as it will accumulate message delivery cost. Use this method to optimize sending many messages at once. + fn send_many(self) -> Self::SendMany; + /// Yields execution async fn yield_now(&self); } +/// Buffer, optimized for sending many messages at once +pub trait SendMany { + /// MPC executor, returned after successful [`.flush()`](Self::flush) + type Exec: MpcExecution; + /// Protocol message + type Msg; + /// Error indicating that sending a message has failed + type SendErr; + + /// Adds a message to the sending queue + /// + /// Similar to [`MpcExecution::send`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed + /// until [`.flush()`](Self::flush) is called. + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; + + /// Adds a p2p message to the sending queue + /// + /// Similar to [`MpcExecution::send_p2p`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed + /// until [`.flush()`](Self::flush) is called. + async fn send_p2p( + &mut self, + recipient: PartyIndex, + msg: Self::Msg, + ) -> Result<(), Self::SendErr> { + self.send(Outgoing::p2p(recipient, msg)).await + } + + /// Adds a broadcast message to the sending queue + /// + /// Similar to [`MpcExecution::send_to_all`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed + /// until [`.flush()`](Self::flush) is called. + async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::all_parties(msg)).await + } + + /// Adds a reliable broadcast message to the sending queue + /// + /// Similar to [`MpcExecution::reliably_broadcast`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed + /// until [`.flush()`](Self::flush) is called. + async fn reliably_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::reliable_broadcast(msg)).await + } + + /// Flushes internal buffer by sending all messages in the queue + async fn flush(self) -> Result; +} + /// Alias to `<::Exec as MpcExecution>::CompleteRoundErr` pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRoundErr; diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs index cdcf35a..85076da 100644 --- a/round-based/src/mpc/party/mod.rs +++ b/round-based/src/mpc/party/mod.rs @@ -104,12 +104,10 @@ where AsyncR: runtime::AsyncRuntime, { type Round = router::Round; - type Msg = M; - type CompleteRoundErr = CompleteRoundError; - type SendErr = IoErr; + type SendMany = SendMany; async fn complete( &mut self, @@ -147,11 +145,41 @@ where self.io.send(msg).await } + fn send_many(self) -> Self::SendMany { + SendMany { party: self } + } + async fn yield_now(&self) { self.runtime.yield_now().await } } +/// Returned by [`MpcParty::send_many()`] +pub struct SendMany { + party: MpcParty, +} + +impl super::SendMany for SendMany +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Exec = MpcParty; + type Msg = as Mpc>::Msg; + type SendErr = as Mpc>::SendErr; + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.party.io.feed(msg).await + } + + async fn flush(mut self) -> Result { + self.party.io.flush().await?; + Ok(self.party) + } +} + pin_project_lite::pin_project! { /// Merges a stream and a sink into one structure that implements both [`Stream`] and [`Sink`] pub struct Halves { From 8f70c7e770b04353110344b9eef97016dfba7048 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 19 May 2025 16:12:15 +0200 Subject: [PATCH 12/29] sync Signed-off-by: Denis Varlakov --- Cargo.lock | 11 + round-based/Cargo.toml | 6 + round-based/src/echo_broadcast/error.rs | 52 +++++ round-based/src/echo_broadcast/mod.rs | 299 ++++++++++++++++++++++++ round-based/src/echo_broadcast/state.rs | 271 +++++++++++++++++++++ round-based/src/echo_broadcast/store.rs | 233 ++++++++++++++++++ round-based/src/lib.rs | 2 + round-based/src/mpc/mod.rs | 12 + round-based/src/mpc/party/mod.rs | 21 +- round-based/src/mpc/party/router.rs | 1 - 10 files changed, 901 insertions(+), 7 deletions(-) create mode 100644 round-based/src/echo_broadcast/error.rs create mode 100644 round-based/src/echo_broadcast/mod.rs create mode 100644 round-based/src/echo_broadcast/state.rs create mode 100644 round-based/src/echo_broadcast/store.rs diff --git a/Cargo.lock b/Cargo.lock index 9c5ab72..4ce8fc9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -466,6 +466,7 @@ name = "round-based" version = "0.4.1" dependencies = [ "anyhow", + "digest", "futures", "futures-util", "hex", @@ -480,6 +481,7 @@ dependencies = [ "tokio-stream", "tracing", "trybuild", + "udigest", ] [[package]] @@ -745,6 +747,15 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "udigest" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cd61fa9fb78569e9fe34acf0048fd8cb9ebdbacc47af740745487287043ff0" +dependencies = [ + "digest", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/round-based/Cargo.toml b/round-based/Cargo.toml index be42da1..53489cc 100644 --- a/round-based/Cargo.toml +++ b/round-based/Cargo.toml @@ -28,6 +28,10 @@ tokio-stream = { version = "0.1", features = ["sync"], optional = true } pin-project-lite = "0.2" +# echo broadcast +digest = { version = "0.10", default-features = false, optional = true } +udigest = { version = "0.2", default-features = false, features = ["alloc", "digest", "inline-struct"], optional = true } + [dev-dependencies] trybuild = "1" matches = "0.1" @@ -48,6 +52,8 @@ sim-async = ["sim", "tokio/sync", "tokio-stream", "futures-util/alloc"] derive = ["round-based-derive"] runtime-tokio = ["tokio"] +echo-broadcast = ["dep:digest", "dep:udigest"] + [[test]] name = "derive" required-features = ["derive"] diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs new file mode 100644 index 0000000..04bc07e --- /dev/null +++ b/round-based/src/echo_broadcast/error.rs @@ -0,0 +1,52 @@ +#[derive(thiserror::Error, Debug)] +#[error(transparent)] +pub struct EchoError(#[from] Reason); + +#[derive(thiserror::Error, Debug)] +pub(super) enum Reason { + #[error( + "round store received two msgs from same party and \ + didn't return an error" + )] + StoreReceivedTwoMsgsFromSameParty, + #[error("received a msg from principal protocol when round is over")] + ReceivedPrincipalMsgWhenRoundOver, + #[error("received an echo msg when round is over")] + ReceivedEchoMsgWhenRoundOver, + #[error("unknown sender i={i} (n={n})")] + UnknownSender { i: u16, n: usize }, + #[error("local party index is out of bounds, probably indicates a bug")] + OwnIndexOutOfBounds { i: u16, n: usize }, + #[error("principal round is finished, but store doesn't output")] + PrincipalRoundFinishedButStoreDoesntOutput, + #[error("echo round is finished, but store doesn't output")] + EchoRoundFinishedButStoreDoesntOutput, + + #[error("handle incoming echo msg")] + HandleEcho(#[source] crate::round::RoundInputError), + + #[error("reliability check error: messages were not reliably broadcasted")] + MismatchedHash, + + #[error("round has already returned output or error")] + StateFinished, + #[error("impossible state (it's a bug)")] + StateGone, + + #[error("main round is in unexpected state (it's a bug)")] + UnexpectedMainRoundState, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("error originated in principal protocol")] + Principal(#[source] E), + #[error("echo broadcast")] + Echo(#[from] EchoError), +} + +impl From for Error { + fn from(value: Reason) -> Self { + Error::Echo(value.into()) + } +} diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs new file mode 100644 index 0000000..786e1a0 --- /dev/null +++ b/round-based/src/echo_broadcast/mod.rs @@ -0,0 +1,299 @@ +use core::marker::PhantomData; + +use alloc::vec::Vec; +use digest::Digest; + +use crate::{ + round::RoundStore, Incoming, Mpc, MpcExecution, MsgId, Outgoing, PartyIndex, ProtocolMsg, + RoundMsg, +}; + +mod error; +mod store; + +// TODO: remove +// mod state; + +pub enum Msg { + /// Message from echo broadcast sub-protocol + Echo { + /// Indicates for which round of main protocol this echo message is transmitted + round: u16, + /// Hash of all messages received in `round` + hash: digest::Output, + }, + /// Message from the main protocol + Main(M), +} + +struct EchoMsg { + hash: digest::Output, + _round: PhantomData, +} + +impl Clone for Msg { + fn clone(&self) -> Self { + match self { + Self::Echo { round, hash } => Self::Echo { + round: *round, + hash: hash.clone(), + }, + Self::Main(msg) => Self::Main(msg.clone()), + } + } +} + +impl Clone for EchoMsg { + fn clone(&self) -> Self { + Self { + hash: self.hash.clone(), + _round: PhantomData, + } + } +} + +impl ProtocolMsg for Msg { + fn round(&self) -> u16 { + match self { + Self::Echo { round, .. } => 2 * round + 1, + Self::Main(m) => 2 * m.round(), + } + } +} + +impl RoundMsg> for Msg +where + M: RoundMsg, +{ + const ROUND: u16 = 2 * M::ROUND + 1; + fn to_protocol_msg(round_msg: EchoMsg) -> Self { + Self::Echo { + round: M::ROUND, + hash: round_msg.hash, + } + } + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + match protocol_msg { + Self::Echo { round, hash } if round == M::ROUND => Ok(EchoMsg { + hash, + _round: PhantomData, + }), + _ => Err(protocol_msg), + } + } +} + +struct Principal(M); + +impl RoundMsg> for Msg +where + ProtoM: ProtocolMsg + RoundMsg, +{ + const ROUND: u16 = 2 * >::ROUND; + fn to_protocol_msg(round_msg: Principal) -> Self { + Self::Main(ProtoM::to_protocol_msg(round_msg.0)) + } + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + if let Self::Main(msg) = protocol_msg { + ProtoM::from_protocol_msg(msg) + .map(Principal) + .map_err(|m| Self::Main(m)) + } else { + Err(protocol_msg) + } + } +} + +pub fn wrap(party: M, i: u16) -> WithReliableBroadcast +where + D: Digest, + M: Mpc>, + PrincipalMsg: udigest::Digestable, +{ + todo!() +} + +pub struct WithReliableBroadcast { + party: M, + i: u16, + n: u16, + _ph: PhantomData, +} + +impl Mpc for WithReliableBroadcast +where + D: Digest, + M: Mpc>, + PrincipalMsg: Clone, +{ + type Msg = PrincipalMsg; + + type Exec = WithReliableBroadcast; + + type SendErr = M::SendErr; + + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg, + { + // let reliable_broadcast_required = round + // .read_prop::() + // .map(|x| x.0); + // if reliable_broadcast_required == Some(true) { + // let round_state = state::State::::init(round, self.i, self.n); + // let round_state = core::cell::RefCell::new(round_state); + // let round_state = alloc::rc::Rc::new(round_state); + + // let store = store::RoundWithEcho(round_state.clone()); + + // todo!() + // } + // Round(self.party.add_round(Principal(round))) + todo!() + } + + fn finish(self) -> Self::Exec { + // WithReliableBroadcast { + // party: self.party.finish(), + // i: self.i, + // _ph: PhantomData, + // } + todo!() + } +} + +impl MpcExecution for WithReliableBroadcast +where + D: Digest, + M: MpcExecution>, + PrincipalMsg: Clone, +{ + type Round = Round; + type Msg = PrincipalMsg; + type CompleteRoundErr = M::CompleteRoundErr; + type SendErr = M::SendErr; + type SendMany = SendMany; + + async fn complete( + &mut self, + round: Self::Round, + ) -> Result> + where + R: RoundStore, + Self::Msg: RoundMsg, + { + match round.0 { + Inner::Unmodified(round) => self.party.complete(round).await, + Inner::WithReliabilityCheck { + main_round, + echo_round, + } => { + todo!() + } + } + } + + async fn receive_and_process_one_message( + &mut self, + ) -> Result<(), Self::CompleteRoundErr> { + self.party.receive_and_process_one_message().await + } + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.party.send(msg.map(Msg::Main)).await + } + + fn send_many(self) -> Self::SendMany { + SendMany { + sender: self.party.send_many(), + i: self.i, + _ph: PhantomData, + } + } + + async fn yield_now(&self) { + self.party.yield_now().await + } +} + +pub struct Round(Inner); + +enum Inner { + /// Round that we do not modify (round that doesn't require reliable broadcast) + Unmodified(M::Round>), + WithReliabilityCheck { + main_round: M::Round>, + echo_round: M::Round>, + }, +} + +impl RoundStore for Principal { + type Msg = Principal; + type Output = R::Output; + type Error = R::Error; + + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.0.add_message(msg.map(|m| m.0)) + } + fn wants_more(&self) -> bool { + self.0.wants_more() + } + fn output(self) -> Result { + self.0.output().map_err(|s| Principal(s)) + } +} + +pub struct SendMany { + sender: M, + i: u16, + _ph: PhantomData, +} + +impl crate::mpc::SendMany for SendMany +where + D: Digest, + M: crate::mpc::SendMany>, + PrincipalMsg: Clone, +{ + type Exec = WithReliableBroadcast; + type Msg = PrincipalMsg; + type SendErr = M::SendErr; + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.sender.send(msg.map(Msg::Main)).await + } + + async fn flush(self) -> Result { + let party = self.sender.flush().await?; + // Ok(WithReliableBroadcast { + // party, + // i: self.i, + // _ph: PhantomData, + // }) + todo!() + } +} + +type SharedRoundState = + alloc::rc::Rc, RoundStateError>>>; + +struct RoundState { + my_msg: Option, + received_msgs: Vec>, + received_echoes: Vec>>, +} + +#[derive(Debug, thiserror::Error)] +enum RoundStateError { + /// Party sent two messages in one round + /// + /// `msgs_ids` are ids of conflicting messages + #[error("party {sender} tried to overwrite message")] + AttemptToOverwriteReceivedMsg { + /// IDs of conflicting messages + msgs_ids: [MsgId; 2], + /// Index of party who sent two messages in one round + sender: PartyIndex, + }, +} diff --git a/round-based/src/echo_broadcast/state.rs b/round-based/src/echo_broadcast/state.rs new file mode 100644 index 0000000..8f66b07 --- /dev/null +++ b/round-based/src/echo_broadcast/state.rs @@ -0,0 +1,271 @@ +#![allow(dead_code)] // TODO: remove + +use alloc::vec::Vec; + +use digest::Digest; + +use crate::{ + round::{RoundInput, RoundStore}, + Incoming, +}; + +use super::error; + +const TAG: &[u8] = b"dfns.round_based.echo_broadcast"; + +pub enum State { + /// A round from principal protocol is ongoing + PrincipalRound(PrincipalRound), + /// Principal round is completed, we need to send an echo message + SendEchoMsg(digest::Output, EchoRound), + /// Principal round is completed, echo round is ongoing + EchoRound(EchoRound), + + /// Principal and echo rounds are finished, reliability check has passed + Output(S::Output), + + /// Indicates that the round has previously already returned the output. + /// Calling any methods when state is finished results into an error. + Finished, + + /// Indicates that the state is temporarily moved. Calling methods when state + /// is gone results into an error and indicates a bug. + Gone, +} + +struct PrincipalRound { + my_msg: Option, + received_msgs: Vec>, + store: S, + received_echoes: RoundInput>, +} + +struct EchoRound { + store_output: O, + received_echoes: RoundInput>, + expected_hash: digest::Output, +} + +impl State { + pub fn init(store: S, i: u16, n: u16) -> Self { + State::PrincipalRound(PrincipalRound { + my_msg: None, + received_msgs: core::iter::repeat_with(|| None).take(n.into()).collect(), + store, + received_echoes: crate::round::broadcast(i, n), + }) + } + + /// Takes the state by value, replaces `self` with `State::Gone`. + /// + /// `self` must be overwritten. Not overwriting a `State::Gone` is a bug + fn take(&mut self) -> Result { + match core::mem::replace(self, State::Gone) { + State::Gone => Err(error::Reason::StateGone), + state => Ok(state), + } + } + + /// Indicates that [`take_output`] method will return `Some(_)` + pub fn ready_to_output(&self) -> bool { + matches!(self, State::Output(_)) + } + + /// Retrieves the round output if reliability check has passed + pub fn take_output(&mut self) -> Option { + match self.take().ok()? { + State::Output(out) => { + *self = State::Finished; + Some(out) + } + state => { + *self = state; + None + } + } + } +} +impl State +where + D: Digest, + S: RoundStore, + S::Msg: udigest::Digestable + Clone, +{ + pub fn received_principal_msg( + &mut self, + incoming: Incoming, + ) -> Result<(), error::Error> { + let state = self.take()?; + match state.received_principal_msg_inner(incoming) { + Ok(next_state) => { + *self = next_state; + Ok(()) + } + Err(err) => { + *self = State::Finished; + Err(err) + } + } + } + fn received_principal_msg_inner( + self, + incoming: Incoming, + ) -> Result> { + match self { + State::PrincipalRound(mut principal_round) => { + principal_round + .store + .add_message(incoming.clone()) + .map_err(error::Error::Principal)?; + let n = principal_round.received_msgs.len(); + let slot = principal_round + .received_msgs + .get_mut(usize::from(incoming.sender)) + .ok_or(error::Reason::UnknownSender { + i: incoming.sender, + n, + })?; + if slot.is_some() { + return Err(error::Reason::StoreReceivedTwoMsgsFromSameParty.into()); + } + *slot = Some(incoming.msg); + + principal_round.advance_if_possible().map_err(Into::into) + } + State::SendEchoMsg(..) | State::EchoRound(_) | State::Output(_) => { + Err(error::Reason::ReceivedPrincipalMsgWhenRoundOver.into()) + } + State::Finished => Err(error::Reason::StateFinished.into()), + State::Gone => Err(error::Reason::StateGone.into()), + } + } + + pub fn received_echo_msg( + &mut self, + incoming: Incoming>, + ) -> Result<(), error::EchoError> { + let state = self.take()?; + match state.received_echo_msg_inner(incoming) { + Ok(next_state) => { + *self = next_state; + Ok(()) + } + Err(err) => { + *self = State::Finished; + Err(err) + } + } + } + fn received_echo_msg_inner( + self, + incoming: Incoming>, + ) -> Result { + match self { + State::PrincipalRound(mut round) => { + round + .received_echoes + .add_message(incoming) + .map_err(error::Reason::HandleEcho)?; + Ok(State::PrincipalRound(round)) + } + State::SendEchoMsg(msg, mut round) => { + round + .received_echoes + .add_message(incoming) + .map_err(error::Reason::HandleEcho)?; + Ok(State::SendEchoMsg(msg, round)) + } + State::EchoRound(mut round) => { + round + .received_echoes + .add_message(incoming) + .map_err(error::Reason::HandleEcho)?; + round.advance_if_possible().map_err(Into::into) + } + State::Output(_output) => Err(error::Reason::ReceivedEchoMsgWhenRoundOver.into()), + State::Finished => Err(error::Reason::StateFinished.into()), + State::Gone => Err(error::Reason::StateGone.into()), + } + } + + /// Retrieves an echo msg and marks it as sent + /// + /// Returns `None` if there's no echo msg to be sent (yet or already) + pub fn take_echo_msg(&mut self) -> Option> { + let state = self.take().ok()?; + let (next_state, msg) = state.take_echo_msg_inner(); + *self = next_state; + msg + } + fn take_echo_msg_inner(self) -> (Self, Option>) { + match self { + State::SendEchoMsg(echo_msg, echo_round) => { + (State::EchoRound(echo_round), Some(echo_msg)) + } + state => (state, None), + } + } +} + +impl PrincipalRound +where + D: Digest, + S: RoundStore, + S::Msg: udigest::Digestable, +{ + fn advance_if_possible(self) -> Result, error::Reason> { + if !self.store.wants_more() { + // Principal round is over, we can start the echo round + let output = self + .store + .output() + .map_err(|_| error::Reason::PrincipalRoundFinishedButStoreDoesntOutput)?; + let echo_msg = udigest::hash::(&udigest::inline_struct!(TAG { + received_msgs: &self.received_msgs, + })); + Ok(State::SendEchoMsg( + echo_msg.clone(), + EchoRound { + store_output: output, + received_echoes: self.received_echoes, + expected_hash: echo_msg, + }, + )) + } else { + Ok(State::PrincipalRound(self)) + } + } +} + +impl EchoRound +where + D: Digest, +{ + fn advance_if_possible>(self) -> Result, error::Reason> { + if !self.received_echoes.wants_more() { + // Echo round is over, now handle the received messages + let echoes = self + .received_echoes + .output() + .map_err(|_| error::Reason::EchoRoundFinishedButStoreDoesntOutput)?; + if echoes.iter().any(|m| *m != self.expected_hash) { + // echo check failed, abort! + Err(error::Reason::MismatchedHash) + } else { + Ok(State::Output(self.store_output)) + } + } else { + Ok(State::EchoRound(self)) + } + } +} + +pub trait IsFinished { + fn is_finished(&self) -> bool; +} + +impl IsFinished for core::cell::RefCell> { + fn is_finished(&self) -> bool { + self.borrow().ready_to_output() + } +} diff --git a/round-based/src/echo_broadcast/store.rs b/round-based/src/echo_broadcast/store.rs new file mode 100644 index 0000000..a1ea2ee --- /dev/null +++ b/round-based/src/echo_broadcast/store.rs @@ -0,0 +1,233 @@ +use alloc::vec::Vec; +use core::marker::PhantomData; +use digest::Digest; + +use crate::{ + round::{RoundInput, RoundMsgs, RoundStore}, + Incoming, +}; + +use super::{error, EchoMsg}; + +const TAG: &[u8] = b"dfns.round_based.echo_broadcast"; + +enum MainRoundState { + Ongoing { store: S }, + Output { output: S::Output }, + Finished, + Gone, +} + +#[derive(Clone, Copy, Debug)] +struct Params { + i: u16, + n: u16, +} + +pub struct MainRound { + params: Params, + state: MainRoundState, + received_msgs: Vec>, + _digest: PhantomData, +} + +impl RoundStore for MainRound +where + S::Msg: Clone, + D: 'static, +{ + type Msg = S::Msg; + /// When a main round is finished, we output a builder that can be used to + /// calculate a hash of messages received by all parties in the reliable + /// broadcast round. Then this hash needs to be re-sent to all participants. + /// + /// Only if we receive the same hash from all parties in [`EchoRound`], only + /// then we can obtain a main round output. + type Output = NeedsOwnMsg; + type Error = error::Error; + + fn add_message(&mut self, incoming: Incoming) -> Result<(), Self::Error> { + let wants_more = match &mut self.state { + MainRoundState::Ongoing { store, .. } => { + store + .add_message(incoming.clone()) + .map_err(error::Error::Principal)?; + let n = self.received_msgs.len(); + let slot = self + .received_msgs + .get_mut(usize::from(incoming.sender)) + .ok_or(error::Reason::UnknownSender { + i: incoming.sender, + n, + })?; + if slot.is_some() { + return Err(error::Reason::StoreReceivedTwoMsgsFromSameParty.into()); + } + *slot = Some(incoming.msg); + store.wants_more() + } + MainRoundState::Gone => return Err(error::Reason::StateGone.into()), + MainRoundState::Output { .. } | MainRoundState::Finished => { + return Err(error::Reason::ReceivedPrincipalMsgWhenRoundOver.into()) + } + }; + + if !wants_more { + let store = core::mem::replace(&mut self.state, MainRoundState::Gone); + let MainRoundState::Ongoing { store } = store else { + return Err(error::Reason::UnexpectedMainRoundState.into()); + }; + let Ok(output) = store.output() else { + self.state = MainRoundState::Finished; + return Err(error::Reason::PrincipalRoundFinishedButStoreDoesntOutput.into()); + }; + self.state = MainRoundState::Output { output }; + } + + Ok(()) + } + + fn wants_more(&self) -> bool { + match &self.state { + MainRoundState::Ongoing { .. } => { + // Note that on each `add_message` we check if `store.wants_more()`, + // and if it doesn't we change the state to MainRoundState::Output + true + } + _ => false, + } + } + + fn output(self) -> Result { + match self.state { + MainRoundState::Output { output } => Ok(NeedsOwnMsg { + params: self.params, + main_round_output: output, + received_msgs: self.received_msgs, + _hash: PhantomData, + }), + state => { + return Err(Self { + params: self.params, + state, + received_msgs: self.received_msgs, + _digest: PhantomData, + }) + } + } + } +} + +/// An output of [`MainRound`] which needs an own message sent by local party +/// in this round. Once provided in [`NeedsOwnMsg::with_own_msg`], it outputs +/// a hash of messages received by all parties in this round (that needs to be +/// re-sent to all participants), and [`WithReliabilityCheck`] that takes +/// messages received in echo round and outputs main round result only if reliability +/// check passes. +pub struct NeedsOwnMsg { + params: Params, + main_round_output: S::Output, + received_msgs: Vec>, + _hash: PhantomData, +} + +pub struct WithReliabilityCheck { + expected_hash: digest::Output, + main_round_output: S::Output, +} + +impl NeedsOwnMsg +where + D: Digest, + S: RoundStore, + S::Msg: udigest::Digestable, +{ + pub fn with_my_msg( + mut self, + msg: S::Msg, + ) -> Result<(WithReliabilityCheck, digest::Output), error::EchoError> { + let n = self.received_msgs.len(); + *self + .received_msgs + .get_mut(usize::from(self.params.i)) + .ok_or(error::Reason::OwnIndexOutOfBounds { + i: self.params.i, + n, + })? = Some(msg); + + let hash = udigest::hash::(&udigest::inline_struct!(TAG { + msgs: &self.received_msgs, + n: self.params.n, + })); + let with_reliability_check = WithReliabilityCheck { + expected_hash: hash.clone(), + main_round_output: self.main_round_output, + }; + Ok((with_reliability_check, hash)) + } +} + +impl WithReliabilityCheck +where + D: Digest, + S: RoundStore, +{ + pub fn with_echo_output( + self, + echo_output: EchoRoundOutput, + ) -> Result { + if echo_output + .received_echoes + .iter() + .any(|h| *h != self.expected_hash) + { + return Err(error::Reason::MismatchedHash.into()); + } + + Ok(self.main_round_output) + } +} + +pub struct EchoRound { + echo_round: RoundInput>, + _round: PhantomData, +} + +pub struct EchoRoundOutput { + received_echoes: RoundMsgs>, + _round: PhantomData, +} + +impl RoundStore for EchoRound +where + D: Digest + 'static, + S: RoundStore, +{ + type Msg = EchoMsg; + type Output = EchoRoundOutput; + type Error = error::EchoError; + + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.echo_round + .add_message(msg.map(|m| m.hash)) + .map_err(error::Reason::HandleEcho)?; + Ok(()) + } + + fn wants_more(&self) -> bool { + self.echo_round.wants_more() + } + + fn output(self) -> Result { + self.echo_round + .output() + .map(|received_echoes| EchoRoundOutput { + received_echoes, + _round: PhantomData, + }) + .map_err(|echo_round| Self { + echo_round, + _round: PhantomData, + }) + } +} diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index f23752b..e34a1e2 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -67,6 +67,8 @@ mod false_positives { } mod delivery; +// #[cfg(feature = "echo-broadcast")] +// pub mod echo_broadcast; pub mod mpc; pub mod round; #[cfg(feature = "state-machine")] diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index dddb42f..3814314 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -35,6 +35,8 @@ //! # Ok(()) } //! ``` +use core::convert::Infallible; + use crate::{round::RoundStore, Outgoing, PartyIndex}; pub mod party; @@ -84,6 +86,8 @@ pub trait MpcExecution { type SendMany: SendMany; /// Completes the round + /// + /// Waits until all messages in the round `R` are received, returns the received messages. async fn complete( &mut self, round: Self::Round, @@ -92,6 +96,14 @@ pub trait MpcExecution { R: RoundStore, Self::Msg: RoundMsg; + /// Instructs the MPC driver to receive exactly one message and route it to its appropriate round store + /// + /// This is a low-level function, normally you don't need to use it. Use [`.complete()`](Self::complete) + /// to receive messages until round is completed. + async fn receive_and_process_one_message( + &mut self, + ) -> Result<(), Self::CompleteRoundErr>; + /// Sends a message /// /// This method awaits until the message is sent, which might be not the best method to use if you diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs index 85076da..05a5612 100644 --- a/round-based/src/mpc/party/mod.rs +++ b/round-based/src/mpc/party/mod.rs @@ -125,13 +125,9 @@ where // Round is not completed - we need more messages loop { - let incoming = self - .io - .next() + self.receive_and_process_one_message() .await - .ok_or(CompleteRoundError::UnexpectedEof)? - .map_err(CompleteRoundError::Io)?; - self.router.received_msg(incoming)?; + .map_err(|e| e.map_process_err(|e| match e {}))?; // Check if round was just completed round = match self.router.complete_round(round) { @@ -141,6 +137,19 @@ where } } + async fn receive_and_process_one_message( + &mut self, + ) -> Result<(), Self::CompleteRoundErr> { + let incoming = self + .io + .next() + .await + .ok_or(CompleteRoundError::UnexpectedEof)? + .map_err(CompleteRoundError::Io)?; + self.router.received_msg(incoming)?; + Ok(()) + } + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { self.io.send(msg).await } diff --git a/round-based/src/mpc/party/router.rs b/round-based/src/mpc/party/router.rs index e2a85c6..cc95813 100644 --- a/round-based/src/mpc/party/router.rs +++ b/round-based/src/mpc/party/router.rs @@ -113,7 +113,6 @@ where ) -> Result> where R: RoundStore, - M: RoundMsg, { match round.take_output() { Ok(Ok(any)) => Ok(*any From 0743bcd155607be7530ea6a95be5e10e212f4592 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 19 May 2025 16:16:33 +0200 Subject: [PATCH 13/29] Introduce RoundInfo Signed-off-by: Denis Varlakov --- round-based/src/mpc/party/router.rs | 5 +++-- round-based/src/round/mod.rs | 31 +++++++++++++++++---------- round-based/src/round/simple_store.rs | 10 ++++++--- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/round-based/src/mpc/party/router.rs b/round-based/src/mpc/party/router.rs index cc95813..eba3e01 100644 --- a/round-based/src/mpc/party/router.rs +++ b/round-based/src/mpc/party/router.rs @@ -416,11 +416,12 @@ mod tests { } struct Msg1; - impl super::RoundStore for Store { + impl crate::round::RoundInfo for Store { type Msg = Msg1; type Output = (); type Error = core::convert::Infallible; - + } + impl crate::round::RoundStore for Store { fn add_message(&mut self, _msg: crate::Incoming) -> Result<(), Self::Error> { Ok(()) } diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index 8a9b313..8455120 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -10,6 +10,16 @@ pub use self::simple_store::{ mod simple_store; +/// Common information about a round +pub trait RoundInfo: Sized + 'static { + /// Message type + type Msg; + /// Store output (e.g. `Vec<_>` of received messages) + type Output; + /// Store error + type Error: core::error::Error; +} + /// Stores messages received at particular round /// /// In MPC protocol, party at every round usually needs to receive up to `n` messages. `RoundsStore` @@ -27,14 +37,7 @@ mod simple_store; /// /// ## Example /// [`RoundInput`] is an simple messages store. Refer to its docs to see usage examples. -pub trait RoundStore: Sized + 'static { - /// Message type - type Msg; - /// Store output (e.g. `Vec<_>` of received messages) - type Output; - /// Store error - type Error: core::error::Error; - +pub trait RoundStore: RoundInfo { /// Adds received message to the store /// /// Returns error if message cannot be processed. Usually it means that sender behaves maliciously. @@ -136,15 +139,21 @@ pub struct WithProp { store: S, } -impl RoundStore for WithProp +impl RoundInfo for WithProp where - S: RoundStore, - P: Clone + 'static, + S: RoundInfo, + P: 'static, { type Msg = S::Msg; type Output = S::Output; type Error = S::Error; +} +impl RoundStore for WithProp +where + S: RoundStore, + P: Clone + 'static, +{ #[inline(always)] fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { self.store.add_message(msg) diff --git a/round-based/src/round/simple_store.rs b/round-based/src/round/simple_store.rs index cbe478d..cd550b1 100644 --- a/round-based/src/round/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -5,7 +5,7 @@ use core::iter; use crate::{Incoming, MessageType, MsgId, PartyIndex}; -use super::RoundStore; +use super::{RoundInfo, RoundStore}; /// Simple implementation of [`RoundStore`] that waits for all parties to send a message /// @@ -124,14 +124,18 @@ impl RoundInput { } } -impl RoundStore for RoundInput +impl RoundInfo for RoundInput where M: 'static, { type Msg = M; type Output = RoundMsgs; type Error = RoundInputError; - +} +impl RoundStore for RoundInput +where + M: 'static, +{ fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { if !self.is_expected_type_of_msg(msg.msg_type) { return Err(RoundInputError::MismatchedMessageType { From 9795efe9bbbbbf1e866232d2ea50d101bb48fb1d Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 19 May 2025 16:19:31 +0200 Subject: [PATCH 14/29] MpcExecution::complete requires RoundInfo Signed-off-by: Denis Varlakov --- round-based/src/mpc/mod.rs | 9 ++++++--- round-based/src/mpc/party/mod.rs | 9 ++++++--- round-based/src/mpc/party/router.rs | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index 3814314..10e227f 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -37,7 +37,10 @@ use core::convert::Infallible; -use crate::{round::RoundStore, Outgoing, PartyIndex}; +use crate::{ + round::{RoundInfo, RoundStore}, + Outgoing, PartyIndex, +}; pub mod party; @@ -72,7 +75,7 @@ pub trait MpcExecution { /// Witness that round was registered /// /// It is used to retrieve messages in [`MpcExecution::complete`]. - type Round; + type Round; /// Protocol message type Msg; @@ -93,7 +96,7 @@ pub trait MpcExecution { round: Self::Round, ) -> Result> where - R: RoundStore, + R: RoundInfo, Self::Msg: RoundMsg; /// Instructs the MPC driver to receive exactly one message and route it to its appropriate round store diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs index 05a5612..3d47b2c 100644 --- a/round-based/src/mpc/party/mod.rs +++ b/round-based/src/mpc/party/mod.rs @@ -2,7 +2,10 @@ use futures_util::{Sink, SinkExt, Stream, StreamExt}; -use crate::{round::RoundStore, Incoming, Outgoing}; +use crate::{ + round::{RoundInfo, RoundStore}, + Incoming, Outgoing, +}; use super::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; @@ -103,7 +106,7 @@ where D: Sink, Error = IoErr> + Unpin, AsyncR: runtime::AsyncRuntime, { - type Round = router::Round; + type Round = router::Round; type Msg = M; type CompleteRoundErr = CompleteRoundError; type SendErr = IoErr; @@ -114,7 +117,7 @@ where mut round: Self::Round, ) -> Result> where - R: RoundStore, + R: RoundInfo, Self::Msg: RoundMsg, { // Check if round is already completed diff --git a/round-based/src/mpc/party/router.rs b/round-based/src/mpc/party/router.rs index eba3e01..f1e5d73 100644 --- a/round-based/src/mpc/party/router.rs +++ b/round-based/src/mpc/party/router.rs @@ -9,7 +9,10 @@ use core::{any::Any, convert::Infallible, mem}; use phantom_type::PhantomType; use tracing::{error, trace_span, warn}; -use crate::{round::RoundStore, Incoming, ProtocolMsg, RoundMsg}; +use crate::{ + round::{RoundInfo, RoundStore}, + Incoming, ProtocolMsg, RoundMsg, +}; /// Routes received messages between protocol rounds pub struct RoundsRouter { @@ -84,7 +87,7 @@ where round: Round, ) -> Result>, Round> where - R: RoundStore, + R: RoundInfo, M: RoundMsg, { let message_round = match self.rounds.get_mut(&M::ROUND) { @@ -112,7 +115,7 @@ where round: &mut Box>, ) -> Result> where - R: RoundStore, + R: RoundInfo, { match round.take_output() { Ok(Ok(any)) => Ok(*any From 688ebba98ad154ad3b64ce5eb18132d1ea3311f0 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Tue, 20 May 2025 15:57:36 +0200 Subject: [PATCH 15/29] Finalize echo broadcast Signed-off-by: Denis Varlakov --- round-based/src/echo_broadcast/error.rs | 51 +++- round-based/src/echo_broadcast/mod.rs | 373 ++++++++++++++---------- round-based/src/echo_broadcast/state.rs | 271 ----------------- round-based/src/echo_broadcast/store.rs | 246 +++++++++++++--- round-based/src/lib.rs | 4 +- round-based/src/mpc/mod.rs | 10 - round-based/src/mpc/party/mod.rs | 21 +- 7 files changed, 478 insertions(+), 498 deletions(-) delete mode 100644 round-based/src/echo_broadcast/state.rs diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index 04bc07e..ac5b369 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -10,17 +10,13 @@ pub(super) enum Reason { )] StoreReceivedTwoMsgsFromSameParty, #[error("received a msg from principal protocol when round is over")] - ReceivedPrincipalMsgWhenRoundOver, - #[error("received an echo msg when round is over")] - ReceivedEchoMsgWhenRoundOver, + ReceivedMainMsgWhenRoundOver, #[error("unknown sender i={i} (n={n})")] UnknownSender { i: u16, n: usize }, #[error("local party index is out of bounds, probably indicates a bug")] OwnIndexOutOfBounds { i: u16, n: usize }, #[error("principal round is finished, but store doesn't output")] - PrincipalRoundFinishedButStoreDoesntOutput, - #[error("echo round is finished, but store doesn't output")] - EchoRoundFinishedButStoreDoesntOutput, + MainRoundFinishedButStoreDoesntOutput, #[error("handle incoming echo msg")] HandleEcho(#[source] crate::round::RoundInputError), @@ -28,19 +24,28 @@ pub(super) enum Reason { #[error("reliability check error: messages were not reliably broadcasted")] MismatchedHash, - #[error("round has already returned output or error")] - StateFinished, #[error("impossible state (it's a bug)")] StateGone, #[error("main round is in unexpected state (it's a bug)")] UnexpectedMainRoundState, + + #[error("clone error msg: RoundMsg implementation is incorrect")] + RoundMsgClone, + + #[error( + "protocol attempts to send a broadcast msg twice within the same round, it's unsupported" + )] + SendTwice, + + #[error("cannot convert a sent round msg back from proto msg (it's a bug)")] + SentMsgFromProto, } #[derive(thiserror::Error, Debug)] pub enum Error { #[error("error originated in principal protocol")] - Principal(#[source] E), + Main(#[source] E), #[error("echo broadcast")] Echo(#[from] EchoError), } @@ -50,3 +55,31 @@ impl From for Error { Error::Echo(value.into()) } } + +#[derive(thiserror::Error, Debug)] +#[error(transparent)] +pub struct CompleteRoundError( + #[from] CompleteRoundReason, +); + +#[derive(thiserror::Error, Debug)] +pub(super) enum CompleteRoundReason { + #[error(transparent)] + CompleteRound(CompleteErr), + #[error(transparent)] + Send(SendErr), + #[error(transparent)] + Echo(Reason), +} + +impl From for CompleteRoundError { + fn from(err: Reason) -> Self { + CompleteRoundError(CompleteRoundReason::Echo(err.into())) + } +} + +impl From for CompleteRoundError { + fn from(err: EchoError) -> Self { + err.0.into() + } +} diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index 786e1a0..8164c81 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -1,23 +1,41 @@ +//! Reliable broadcast for any protocol via echo messages +//! +//! Broadcast message is a message meant to be received by all participants of the protocol. +//! +//! We say that message is reliably broadcasted if, upon reception, it is guaranteed that all +//! honest participants of the protocol has received the same message. +//! +//! One way to achieve the reliable broadcast is by adding an echo round: when we receive +//! messages in a reliable broadcast round, we hash all messages, and send the hash to all +//! other participants. If party receives a the same hash from everyone else, we can be +//! assured that messages in the round were reliably broadcasted. +//! +//! This module provides a mechanism that automatically add an echo round per each +//! round of the protocol that requires a reliable broadcast. + use core::marker::PhantomData; -use alloc::vec::Vec; +use alloc::collections::btree_map::BTreeMap; use digest::Digest; use crate::{ - round::RoundStore, Incoming, Mpc, MpcExecution, MsgId, Outgoing, PartyIndex, ProtocolMsg, - RoundMsg, + round::{RoundInfo, RoundStore, RoundStoreExt}, + Mpc, MpcExecution, Outgoing, ProtocolMsg, RoundMsg, }; mod error; mod store; -// TODO: remove -// mod state; - +/// Message of the protocol with echo broadcast round(s) pub enum Msg { /// Message from echo broadcast sub-protocol Echo { /// Indicates for which round of main protocol this echo message is transmitted + /// + /// Note that this field is controlled by potential malicious party. If it sets it + /// to the round that doesn't exist, the protocol will likely be aborted with an error + /// that we received a message from unregistered round, which may appear as implementation + /// error (i.e. API misuse), but in fact it's a malicious abort. round: u16, /// Hash of all messages received in `round` hash: digest::Output, @@ -26,9 +44,25 @@ pub enum Msg { Main(M), } -struct EchoMsg { - hash: digest::Output, - _round: PhantomData, +/// Sub-messages of [`Msg`] +/// +/// Sub-messages implement [`RoundMsg`] trait for [`Msg`] +mod sub_msg { + pub struct EchoMsg { + pub hash: digest::Output, + pub _round: core::marker::PhantomData, + } + #[derive(Debug, Clone)] + pub struct Main(pub M); + + impl Clone for EchoMsg { + fn clone(&self) -> Self { + Self { + hash: self.hash.clone(), + _round: core::marker::PhantomData, + } + } + } } impl Clone for Msg { @@ -43,15 +77,6 @@ impl Clone for Msg { } } -impl Clone for EchoMsg { - fn clone(&self) -> Self { - Self { - hash: self.hash.clone(), - _round: PhantomData, - } - } -} - impl ProtocolMsg for Msg { fn round(&self) -> u16 { match self { @@ -61,20 +86,20 @@ impl ProtocolMsg for Msg { } } -impl RoundMsg> for Msg +impl RoundMsg> for Msg where M: RoundMsg, { const ROUND: u16 = 2 * M::ROUND + 1; - fn to_protocol_msg(round_msg: EchoMsg) -> Self { + fn to_protocol_msg(round_msg: sub_msg::EchoMsg) -> Self { Self::Echo { round: M::ROUND, hash: round_msg.hash, } } - fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { match protocol_msg { - Self::Echo { round, hash } if round == M::ROUND => Ok(EchoMsg { + Self::Echo { round, hash } if round == M::ROUND => Ok(sub_msg::EchoMsg { hash, _round: PhantomData, }), @@ -83,20 +108,18 @@ where } } -struct Principal(M); - -impl RoundMsg> for Msg +impl RoundMsg> for Msg where ProtoM: ProtocolMsg + RoundMsg, { const ROUND: u16 = 2 * >::ROUND; - fn to_protocol_msg(round_msg: Principal) -> Self { + fn to_protocol_msg(round_msg: sub_msg::Main) -> Self { Self::Main(ProtoM::to_protocol_msg(round_msg.0)) } - fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { if let Self::Main(msg) = protocol_msg { ProtoM::from_protocol_msg(msg) - .map(Principal) + .map(sub_msg::Main) .map_err(|m| Self::Main(m)) } else { Err(protocol_msg) @@ -104,112 +127,188 @@ where } } -pub fn wrap(party: M, i: u16) -> WithReliableBroadcast +/// Wraps an [`Mpc`] engine and provides echo broadcast capabilities +pub fn wrap(party: M, i: u16, n: u16) -> WithReliableBroadcast where D: Digest, - M: Mpc>, - PrincipalMsg: udigest::Digestable, + M: Mpc>, + MainMsg: udigest::Digestable, { - todo!() + WithReliableBroadcast { + party, + i, + n, + sent_reliable_msgs: Default::default(), + _ph: PhantomData, + } } -pub struct WithReliableBroadcast { +/// [`Mpc`] engine with echo-broadcast capabilities +pub struct WithReliableBroadcast { party: M, i: u16, n: u16, + sent_reliable_msgs: BTreeMap>, _ph: PhantomData, } -impl Mpc for WithReliableBroadcast +impl WithReliableBroadcast { + fn map_party

(self, f: impl FnOnce(M) -> P) -> WithReliableBroadcast { + let party = f(self.party); + WithReliableBroadcast { + party, + i: self.i, + n: self.n, + sent_reliable_msgs: self.sent_reliable_msgs, + _ph: PhantomData, + } + } +} + +impl Mpc for WithReliableBroadcast where - D: Digest, - M: Mpc>, - PrincipalMsg: Clone, + D: Digest + 'static, + M: Mpc>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, { - type Msg = PrincipalMsg; + type Msg = MainMsg; - type Exec = WithReliableBroadcast; + type Exec = WithReliableBroadcast; - type SendErr = M::SendErr; + type SendErr = error::Error; fn add_round(&mut self, round: R) -> ::Round where R: RoundStore, Self::Msg: RoundMsg, { - // let reliable_broadcast_required = round - // .read_prop::() - // .map(|x| x.0); - // if reliable_broadcast_required == Some(true) { - // let round_state = state::State::::init(round, self.i, self.n); - // let round_state = core::cell::RefCell::new(round_state); - // let round_state = alloc::rc::Rc::new(round_state); - - // let store = store::RoundWithEcho(round_state.clone()); - - // todo!() - // } - // Round(self.party.add_round(Principal(round))) - todo!() + let reliable_broadcast_required = round + .read_prop::() + .map(|x| x.0); + if reliable_broadcast_required == Some(true) { + let (main_round, echo_round) = store::new::(self.i, self.n, round); + let main_round = self.party.add_round(store::WithMainMsg(main_round)); + let echo_round = self.party.add_round(store::WithEchoError::from(echo_round)); + + self.sent_reliable_msgs.insert(Self::Msg::ROUND, None); + + Round(Inner::WithReliabilityCheck { + main_round, + echo_round, + }) + } else { + let round = self + .party + .add_round(store::WithError(store::WithMainMsg(round))); + Round(Inner::Unmodified(round)) + } } fn finish(self) -> Self::Exec { - // WithReliableBroadcast { - // party: self.party.finish(), - // i: self.i, - // _ph: PhantomData, - // } - todo!() + self.map_party(|p| p.finish()) } } -impl MpcExecution for WithReliableBroadcast +impl WithReliableBroadcast where D: Digest, - M: MpcExecution>, - PrincipalMsg: Clone, + MainMsg: ProtocolMsg + Clone, { - type Round = Round; - type Msg = PrincipalMsg; - type CompleteRoundErr = M::CompleteRoundErr; - type SendErr = M::SendErr; - type SendMany = SendMany; + fn on_send(&mut self, outgoing: &Outgoing) -> Result<(), error::EchoError> { + if let Some(slot) = self.sent_reliable_msgs.get_mut(&outgoing.msg.round()) { + if slot.is_some() { + return Err(error::Reason::SendTwice.into()); + } + *slot = Some(outgoing.msg.clone()) + } + Ok(()) + } +} + +impl MpcExecution for WithReliableBroadcast +where + D: Digest + 'static, + M: MpcExecution>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, +{ + type Round = Round; + type Msg = MainMsg; + type CompleteRoundErr = + error::CompleteRoundError>, M::SendErr>; + type SendErr = error::Error; + type SendMany = WithReliableBroadcast; async fn complete( &mut self, round: Self::Round, ) -> Result> where - R: RoundStore, + R: RoundInfo, Self::Msg: RoundMsg, { match round.0 { - Inner::Unmodified(round) => self.party.complete(round).await, + Inner::Unmodified(round) => { + // regular round that doesn't need reliable broadcast + let output = self + .party + .complete(round) + .await + .map_err(error::CompleteRoundReason::CompleteRound)?; + Ok(output) + } Inner::WithReliabilityCheck { main_round, echo_round, } => { - todo!() + // receive all messages in the main round + let main_output = self + .party + .complete(main_round) + .await + .map_err(error::CompleteRoundReason::CompleteRound)?; + // retrieve a msg that we sent in this round + let sent_msg = + if let Some(Some(msg)) = self.sent_reliable_msgs.remove(&Self::Msg::ROUND) { + let msg: R::Msg = Self::Msg::from_protocol_msg(msg) + .map_err(|_| error::Reason::SentMsgFromProto)?; + Some(msg) + } else { + None + }; + // calculate a hash and send it to all other parties + let (main_output, hash) = main_output.with_my_msg(sent_msg)?; + self.party + .send_to_all(Msg::Echo { + round: Self::Msg::ROUND, + hash, + }) + .await + .map_err(error::CompleteRoundReason::Send)?; + // receive echoes from other parties + let echoes = self + .party + .complete(echo_round) + .await + .map_err(|e| error::CompleteRoundReason::CompleteRound(e))?; + // check that everyone sent the same hash + let main_output = main_output.with_echo_output(echoes)?; + + Ok(main_output) } } } - async fn receive_and_process_one_message( - &mut self, - ) -> Result<(), Self::CompleteRoundErr> { - self.party.receive_and_process_one_message().await - } + async fn send(&mut self, outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&outgoing)?; - async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { - self.party.send(msg.map(Msg::Main)).await + self.party + .send(outgoing.map(Msg::Main)) + .await + .map_err(error::Error::Main) } fn send_many(self) -> Self::SendMany { - SendMany { - sender: self.party.send_many(), - i: self.i, - _ph: PhantomData, - } + self.map_party(|p| p.send_many()) } async fn yield_now(&self) { @@ -217,83 +316,57 @@ where } } -pub struct Round(Inner); +/// Round registration witness +/// +/// Returned by [`WithReliableBroadcast::add_round()`] +pub struct Round(Inner) +where + M: MpcExecution, + D: Digest + 'static, + ProtoMsg: 'static, + R: RoundInfo; -enum Inner { +enum Inner +where + M: MpcExecution, + D: Digest + 'static, + ProtoMsg: 'static, + R: RoundInfo, +{ /// Round that we do not modify (round that doesn't require reliable broadcast) - Unmodified(M::Round>), + Unmodified(M::Round>>), WithReliabilityCheck { - main_round: M::Round>, - echo_round: M::Round>, + main_round: M::Round>>, + echo_round: M::Round, R::Error>>, }, } -impl RoundStore for Principal { - type Msg = Principal; - type Output = R::Output; - type Error = R::Error; - - fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { - self.0.add_message(msg.map(|m| m.0)) - } - fn wants_more(&self) -> bool { - self.0.wants_more() - } - fn output(self) -> Result { - self.0.output().map_err(|s| Principal(s)) - } -} - -pub struct SendMany { - sender: M, - i: u16, - _ph: PhantomData, -} - -impl crate::mpc::SendMany for SendMany +impl crate::mpc::SendMany for WithReliableBroadcast where - D: Digest, - M: crate::mpc::SendMany>, - PrincipalMsg: Clone, + D: Digest + 'static, + M: crate::mpc::SendMany>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, { - type Exec = WithReliableBroadcast; - type Msg = PrincipalMsg; - type SendErr = M::SendErr; - - async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { - self.sender.send(msg.map(Msg::Main)).await + type Exec = WithReliableBroadcast; + type Msg = MainMsg; + type SendErr = error::Error; + + async fn send(&mut self, outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&outgoing)?; + self.party + .send(outgoing.map(Msg::Main)) + .await + .map_err(error::Error::Main) } async fn flush(self) -> Result { - let party = self.sender.flush().await?; - // Ok(WithReliableBroadcast { - // party, - // i: self.i, - // _ph: PhantomData, - // }) - todo!() + let party = self.party.flush().await.map_err(error::Error::Main)?; + Ok(WithReliableBroadcast { + party, + i: self.i, + n: self.n, + sent_reliable_msgs: self.sent_reliable_msgs, + _ph: PhantomData, + }) } } - -type SharedRoundState = - alloc::rc::Rc, RoundStateError>>>; - -struct RoundState { - my_msg: Option, - received_msgs: Vec>, - received_echoes: Vec>>, -} - -#[derive(Debug, thiserror::Error)] -enum RoundStateError { - /// Party sent two messages in one round - /// - /// `msgs_ids` are ids of conflicting messages - #[error("party {sender} tried to overwrite message")] - AttemptToOverwriteReceivedMsg { - /// IDs of conflicting messages - msgs_ids: [MsgId; 2], - /// Index of party who sent two messages in one round - sender: PartyIndex, - }, -} diff --git a/round-based/src/echo_broadcast/state.rs b/round-based/src/echo_broadcast/state.rs deleted file mode 100644 index 8f66b07..0000000 --- a/round-based/src/echo_broadcast/state.rs +++ /dev/null @@ -1,271 +0,0 @@ -#![allow(dead_code)] // TODO: remove - -use alloc::vec::Vec; - -use digest::Digest; - -use crate::{ - round::{RoundInput, RoundStore}, - Incoming, -}; - -use super::error; - -const TAG: &[u8] = b"dfns.round_based.echo_broadcast"; - -pub enum State { - /// A round from principal protocol is ongoing - PrincipalRound(PrincipalRound), - /// Principal round is completed, we need to send an echo message - SendEchoMsg(digest::Output, EchoRound), - /// Principal round is completed, echo round is ongoing - EchoRound(EchoRound), - - /// Principal and echo rounds are finished, reliability check has passed - Output(S::Output), - - /// Indicates that the round has previously already returned the output. - /// Calling any methods when state is finished results into an error. - Finished, - - /// Indicates that the state is temporarily moved. Calling methods when state - /// is gone results into an error and indicates a bug. - Gone, -} - -struct PrincipalRound { - my_msg: Option, - received_msgs: Vec>, - store: S, - received_echoes: RoundInput>, -} - -struct EchoRound { - store_output: O, - received_echoes: RoundInput>, - expected_hash: digest::Output, -} - -impl State { - pub fn init(store: S, i: u16, n: u16) -> Self { - State::PrincipalRound(PrincipalRound { - my_msg: None, - received_msgs: core::iter::repeat_with(|| None).take(n.into()).collect(), - store, - received_echoes: crate::round::broadcast(i, n), - }) - } - - /// Takes the state by value, replaces `self` with `State::Gone`. - /// - /// `self` must be overwritten. Not overwriting a `State::Gone` is a bug - fn take(&mut self) -> Result { - match core::mem::replace(self, State::Gone) { - State::Gone => Err(error::Reason::StateGone), - state => Ok(state), - } - } - - /// Indicates that [`take_output`] method will return `Some(_)` - pub fn ready_to_output(&self) -> bool { - matches!(self, State::Output(_)) - } - - /// Retrieves the round output if reliability check has passed - pub fn take_output(&mut self) -> Option { - match self.take().ok()? { - State::Output(out) => { - *self = State::Finished; - Some(out) - } - state => { - *self = state; - None - } - } - } -} -impl State -where - D: Digest, - S: RoundStore, - S::Msg: udigest::Digestable + Clone, -{ - pub fn received_principal_msg( - &mut self, - incoming: Incoming, - ) -> Result<(), error::Error> { - let state = self.take()?; - match state.received_principal_msg_inner(incoming) { - Ok(next_state) => { - *self = next_state; - Ok(()) - } - Err(err) => { - *self = State::Finished; - Err(err) - } - } - } - fn received_principal_msg_inner( - self, - incoming: Incoming, - ) -> Result> { - match self { - State::PrincipalRound(mut principal_round) => { - principal_round - .store - .add_message(incoming.clone()) - .map_err(error::Error::Principal)?; - let n = principal_round.received_msgs.len(); - let slot = principal_round - .received_msgs - .get_mut(usize::from(incoming.sender)) - .ok_or(error::Reason::UnknownSender { - i: incoming.sender, - n, - })?; - if slot.is_some() { - return Err(error::Reason::StoreReceivedTwoMsgsFromSameParty.into()); - } - *slot = Some(incoming.msg); - - principal_round.advance_if_possible().map_err(Into::into) - } - State::SendEchoMsg(..) | State::EchoRound(_) | State::Output(_) => { - Err(error::Reason::ReceivedPrincipalMsgWhenRoundOver.into()) - } - State::Finished => Err(error::Reason::StateFinished.into()), - State::Gone => Err(error::Reason::StateGone.into()), - } - } - - pub fn received_echo_msg( - &mut self, - incoming: Incoming>, - ) -> Result<(), error::EchoError> { - let state = self.take()?; - match state.received_echo_msg_inner(incoming) { - Ok(next_state) => { - *self = next_state; - Ok(()) - } - Err(err) => { - *self = State::Finished; - Err(err) - } - } - } - fn received_echo_msg_inner( - self, - incoming: Incoming>, - ) -> Result { - match self { - State::PrincipalRound(mut round) => { - round - .received_echoes - .add_message(incoming) - .map_err(error::Reason::HandleEcho)?; - Ok(State::PrincipalRound(round)) - } - State::SendEchoMsg(msg, mut round) => { - round - .received_echoes - .add_message(incoming) - .map_err(error::Reason::HandleEcho)?; - Ok(State::SendEchoMsg(msg, round)) - } - State::EchoRound(mut round) => { - round - .received_echoes - .add_message(incoming) - .map_err(error::Reason::HandleEcho)?; - round.advance_if_possible().map_err(Into::into) - } - State::Output(_output) => Err(error::Reason::ReceivedEchoMsgWhenRoundOver.into()), - State::Finished => Err(error::Reason::StateFinished.into()), - State::Gone => Err(error::Reason::StateGone.into()), - } - } - - /// Retrieves an echo msg and marks it as sent - /// - /// Returns `None` if there's no echo msg to be sent (yet or already) - pub fn take_echo_msg(&mut self) -> Option> { - let state = self.take().ok()?; - let (next_state, msg) = state.take_echo_msg_inner(); - *self = next_state; - msg - } - fn take_echo_msg_inner(self) -> (Self, Option>) { - match self { - State::SendEchoMsg(echo_msg, echo_round) => { - (State::EchoRound(echo_round), Some(echo_msg)) - } - state => (state, None), - } - } -} - -impl PrincipalRound -where - D: Digest, - S: RoundStore, - S::Msg: udigest::Digestable, -{ - fn advance_if_possible(self) -> Result, error::Reason> { - if !self.store.wants_more() { - // Principal round is over, we can start the echo round - let output = self - .store - .output() - .map_err(|_| error::Reason::PrincipalRoundFinishedButStoreDoesntOutput)?; - let echo_msg = udigest::hash::(&udigest::inline_struct!(TAG { - received_msgs: &self.received_msgs, - })); - Ok(State::SendEchoMsg( - echo_msg.clone(), - EchoRound { - store_output: output, - received_echoes: self.received_echoes, - expected_hash: echo_msg, - }, - )) - } else { - Ok(State::PrincipalRound(self)) - } - } -} - -impl EchoRound -where - D: Digest, -{ - fn advance_if_possible>(self) -> Result, error::Reason> { - if !self.received_echoes.wants_more() { - // Echo round is over, now handle the received messages - let echoes = self - .received_echoes - .output() - .map_err(|_| error::Reason::EchoRoundFinishedButStoreDoesntOutput)?; - if echoes.iter().any(|m| *m != self.expected_hash) { - // echo check failed, abort! - Err(error::Reason::MismatchedHash) - } else { - Ok(State::Output(self.store_output)) - } - } else { - Ok(State::EchoRound(self)) - } - } -} - -pub trait IsFinished { - fn is_finished(&self) -> bool; -} - -impl IsFinished for core::cell::RefCell> { - fn is_finished(&self) -> bool { - self.borrow().ready_to_output() - } -} diff --git a/round-based/src/echo_broadcast/store.rs b/round-based/src/echo_broadcast/store.rs index a1ea2ee..e9af8b3 100644 --- a/round-based/src/echo_broadcast/store.rs +++ b/round-based/src/echo_broadcast/store.rs @@ -3,15 +3,38 @@ use core::marker::PhantomData; use digest::Digest; use crate::{ - round::{RoundInput, RoundMsgs, RoundStore}, - Incoming, + round::{RoundInfo, RoundInput, RoundMsgs, RoundStore}, + Incoming, RoundMsg, }; -use super::{error, EchoMsg}; +use super::{error, sub_msg}; const TAG: &[u8] = b"dfns.round_based.echo_broadcast"; -enum MainRoundState { +pub fn new( + i: u16, + n: u16, + main_round: S, +) -> (MainRound, EchoRound) { + let params = Params { i, n }; + let state = match main_round.output() { + Ok(output) => MainRoundState::Output { output }, + Err(store) => MainRoundState::Ongoing { store }, + }; + let main_round = MainRound { + params, + state, + received_msgs: core::iter::repeat_with(|| None).take(n.into()).collect(), + _ph: PhantomData, + }; + let echo_round = EchoRound { + echo_round: RoundInput::broadcast(i, n), + _round: PhantomData, + }; + (main_round, echo_round) +} + +enum MainRoundState { Ongoing { store: S }, Output { output: S::Output }, Finished, @@ -24,18 +47,14 @@ struct Params { n: u16, } -pub struct MainRound { +pub struct MainRound { params: Params, state: MainRoundState, - received_msgs: Vec>, - _digest: PhantomData, + received_msgs: Vec>, + _ph: PhantomData<(D, ProtoMsg)>, } -impl RoundStore for MainRound -where - S::Msg: Clone, - D: 'static, -{ +impl RoundInfo for MainRound { type Msg = S::Msg; /// When a main round is finished, we output a builder that can be used to /// calculate a hash of messages received by all parties in the reliable @@ -43,32 +62,44 @@ where /// /// Only if we receive the same hash from all parties in [`EchoRound`], only /// then we can obtain a main round output. - type Output = NeedsOwnMsg; + type Output = NeedsOwnMsg; type Error = error::Error; - - fn add_message(&mut self, incoming: Incoming) -> Result<(), Self::Error> { +} +impl RoundStore for MainRound +where + ProtoMsg: RoundMsg + Clone + 'static, + D: 'static, +{ + fn add_message(&mut self, mut incoming: Incoming) -> Result<(), Self::Error> { let wants_more = match &mut self.state { MainRoundState::Ongoing { store, .. } => { - store - .add_message(incoming.clone()) - .map_err(error::Error::Principal)?; + // We pretend that msg was reliably broadcasted even though the reliability check is + // not yet enforced, however, we do not expose the output of the round unless + // reliability check has passed. + incoming.msg_type = crate::MessageType::Broadcast { reliable: true }; + + // Note: round msg doesn't implement `Clone`, but ProtoMsg does, so we + // use a trick to create a clone of incoming msg + let (incoming1, incoming2) = clone_incoming_round_msg::(incoming) + .ok_or(error::Reason::RoundMsgClone)?; + store.add_message(incoming1).map_err(error::Error::Main)?; let n = self.received_msgs.len(); let slot = self .received_msgs - .get_mut(usize::from(incoming.sender)) + .get_mut(usize::from(incoming2.sender)) .ok_or(error::Reason::UnknownSender { - i: incoming.sender, + i: incoming2.sender, n, })?; if slot.is_some() { return Err(error::Reason::StoreReceivedTwoMsgsFromSameParty.into()); } - *slot = Some(incoming.msg); + *slot = Some(ProtoMsg::to_protocol_msg(incoming2.msg)); store.wants_more() } MainRoundState::Gone => return Err(error::Reason::StateGone.into()), MainRoundState::Output { .. } | MainRoundState::Finished => { - return Err(error::Reason::ReceivedPrincipalMsgWhenRoundOver.into()) + return Err(error::Reason::ReceivedMainMsgWhenRoundOver.into()) } }; @@ -79,7 +110,7 @@ where }; let Ok(output) = store.output() else { self.state = MainRoundState::Finished; - return Err(error::Reason::PrincipalRoundFinishedButStoreDoesntOutput.into()); + return Err(error::Reason::MainRoundFinishedButStoreDoesntOutput.into()); }; self.state = MainRoundState::Output { output }; } @@ -111,55 +142,97 @@ where params: self.params, state, received_msgs: self.received_msgs, - _digest: PhantomData, + _ph: PhantomData, }) } } } } +/// Duplicates a round msg +/// +/// This function doesn't require that round msg implements `Clone`, instead it only requires +/// that protocol msg is cloneable. It works by converting round msg into protocol msg, creating +/// two clones, and converting them back to round msg. +/// +/// We need this function in places where we know that protocol msg is cloneable, but we can't +/// prove to the compiler that round msg is cloneable as well. +/// +/// Function returns `None` only if [`RoundMsg`] implementation is not correct. +fn clone_round_msg(round_msg: R) -> Option<(R, R)> +where + M: RoundMsg + Clone, +{ + let proto_msg = M::to_protocol_msg(round_msg); + + let round_msg1 = M::from_protocol_msg(proto_msg.clone()).ok()?; + let round_msg2 = M::from_protocol_msg(proto_msg).ok()?; + + Some((round_msg1, round_msg2)) +} + +/// Similar to [`clone_round_msg`] but accepts [`Incoming`](Incoming) +fn clone_incoming_round_msg( + incoming_round_msg: Incoming, +) -> Option<(Incoming, Incoming)> +where + M: RoundMsg + Clone, +{ + let (msg1, msg2) = clone_round_msg::(incoming_round_msg.msg)?; + + let incoming = |msg| Incoming { + id: incoming_round_msg.id, + sender: incoming_round_msg.sender, + msg_type: incoming_round_msg.msg_type, + msg, + }; + Some((incoming(msg1), incoming(msg2))) +} + /// An output of [`MainRound`] which needs an own message sent by local party /// in this round. Once provided in [`NeedsOwnMsg::with_own_msg`], it outputs /// a hash of messages received by all parties in this round (that needs to be /// re-sent to all participants), and [`WithReliabilityCheck`] that takes /// messages received in echo round and outputs main round result only if reliability /// check passes. -pub struct NeedsOwnMsg { +pub struct NeedsOwnMsg { params: Params, main_round_output: S::Output, - received_msgs: Vec>, + received_msgs: Vec>, _hash: PhantomData, } -pub struct WithReliabilityCheck { +pub struct ReliabilityCheck { expected_hash: digest::Output, main_round_output: S::Output, } -impl NeedsOwnMsg +impl NeedsOwnMsg where D: Digest, - S: RoundStore, - S::Msg: udigest::Digestable, + S: RoundInfo, + ProtoMsg: RoundMsg + udigest::Digestable, { pub fn with_my_msg( mut self, - msg: S::Msg, - ) -> Result<(WithReliabilityCheck, digest::Output), error::EchoError> { + msg: Option, + ) -> Result<(ReliabilityCheck, digest::Output), error::EchoError> { let n = self.received_msgs.len(); + let msg = msg.map(ProtoMsg::to_protocol_msg); *self .received_msgs .get_mut(usize::from(self.params.i)) .ok_or(error::Reason::OwnIndexOutOfBounds { i: self.params.i, n, - })? = Some(msg); + })? = msg; let hash = udigest::hash::(&udigest::inline_struct!(TAG { msgs: &self.received_msgs, + round: ProtoMsg::ROUND, n: self.params.n, })); - let with_reliability_check = WithReliabilityCheck { + let with_reliability_check = ReliabilityCheck { expected_hash: hash.clone(), main_round_output: self.main_round_output, }; @@ -167,10 +240,10 @@ where } } -impl WithReliabilityCheck +impl ReliabilityCheck where D: Digest, - S: RoundStore, + S: RoundInfo, { pub fn with_echo_output( self, @@ -188,7 +261,7 @@ where } } -pub struct EchoRound { +pub(super) struct EchoRound { echo_round: RoundInput>, _round: PhantomData, } @@ -198,15 +271,20 @@ pub struct EchoRoundOutput { _round: PhantomData, } -impl RoundStore for EchoRound +impl RoundInfo for EchoRound where D: Digest + 'static, - S: RoundStore, + S: RoundInfo, { - type Msg = EchoMsg; + type Msg = sub_msg::EchoMsg; type Output = EchoRoundOutput; type Error = error::EchoError; - +} +impl RoundStore for EchoRound +where + D: Digest + 'static, + S: RoundStore, +{ fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { self.echo_round .add_message(msg.map(|m| m.hash)) @@ -231,3 +309,89 @@ where }) } } + +/// Wraps a round store `S` and changes its msg type to `sub_msg::Main` +pub struct WithMainMsg(pub S); + +impl RoundInfo for WithMainMsg { + type Msg = sub_msg::Main; + type Output = S::Output; + type Error = S::Error; +} + +impl RoundStore for WithMainMsg { + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.0.add_message(msg.map(|m| m.0)) + } + fn wants_more(&self) -> bool { + self.0.wants_more() + } + fn output(self) -> Result { + self.0.output().map_err(Self) + } +} + +/// Wraps a round store `S` and changes its error to `Error` +pub struct WithError(pub S); + +impl RoundInfo for WithError { + type Msg = S::Msg; + type Output = S::Output; + type Error = error::Error; +} + +impl RoundStore for WithError { + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.0.add_message(msg).map_err(error::Error::Main) + } + fn wants_more(&self) -> bool { + self.0.wants_more() + } + fn output(self) -> Result { + self.0.output().map_err(Self) + } +} + +/// Wraps a round store `S` with `Error = error::EchoError` and changes it to `Error = error::Error` +pub struct WithEchoError { + pub store: S, + _ph: PhantomData, +} + +impl From for WithEchoError { + fn from(store: S) -> Self { + Self { + store, + _ph: PhantomData, + } + } +} + +impl RoundInfo for WithEchoError +where + S: RoundInfo, + E: core::error::Error + 'static, +{ + type Msg = S::Msg; + type Output = S::Output; + type Error = error::Error; +} + +impl RoundStore for WithEchoError +where + S: RoundStore, + E: core::error::Error + 'static, +{ + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.store.add_message(msg).map_err(error::Error::Echo) + } + fn wants_more(&self) -> bool { + self.store.wants_more() + } + fn output(self) -> Result { + self.store.output().map_err(|store| Self { + store, + _ph: PhantomData, + }) + } +} diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index e34a1e2..8991bd8 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -67,8 +67,8 @@ mod false_positives { } mod delivery; -// #[cfg(feature = "echo-broadcast")] -// pub mod echo_broadcast; +#[cfg(feature = "echo-broadcast")] +pub mod echo_broadcast; pub mod mpc; pub mod round; #[cfg(feature = "state-machine")] diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index 10e227f..ae3b5fd 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -35,8 +35,6 @@ //! # Ok(()) } //! ``` -use core::convert::Infallible; - use crate::{ round::{RoundInfo, RoundStore}, Outgoing, PartyIndex, @@ -99,14 +97,6 @@ pub trait MpcExecution { R: RoundInfo, Self::Msg: RoundMsg; - /// Instructs the MPC driver to receive exactly one message and route it to its appropriate round store - /// - /// This is a low-level function, normally you don't need to use it. Use [`.complete()`](Self::complete) - /// to receive messages until round is completed. - async fn receive_and_process_one_message( - &mut self, - ) -> Result<(), Self::CompleteRoundErr>; - /// Sends a message /// /// This method awaits until the message is sent, which might be not the best method to use if you diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs index 3d47b2c..c270610 100644 --- a/round-based/src/mpc/party/mod.rs +++ b/round-based/src/mpc/party/mod.rs @@ -128,9 +128,13 @@ where // Round is not completed - we need more messages loop { - self.receive_and_process_one_message() + let incoming = self + .io + .next() .await - .map_err(|e| e.map_process_err(|e| match e {}))?; + .ok_or(CompleteRoundError::UnexpectedEof)? + .map_err(CompleteRoundError::Io)?; + self.router.received_msg(incoming)?; // Check if round was just completed round = match self.router.complete_round(round) { @@ -140,19 +144,6 @@ where } } - async fn receive_and_process_one_message( - &mut self, - ) -> Result<(), Self::CompleteRoundErr> { - let incoming = self - .io - .next() - .await - .ok_or(CompleteRoundError::UnexpectedEof)? - .map_err(CompleteRoundError::Io)?; - self.router.received_msg(incoming)?; - Ok(()) - } - async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { self.io.send(msg).await } From 0d3da0bea8ae5b9ce1ebc718f5aeb5cd70bbab46 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Thu, 29 May 2025 16:52:14 +0200 Subject: [PATCH 16/29] Fix doc examples Signed-off-by: Denis Varlakov --- round-based/src/round/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index 8455120..b7821b7 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -70,10 +70,12 @@ pub trait RoundStore: RoundInfo { /// pub struct AnotherProperty(String); /// /// # type Msg = (); - /// impl round_based::round::RoundStore for MyStore { + /// # impl round_based::round::RoundInfo for MyStore { /// # type Msg = Msg; /// # type Output = Vec; /// # type Error = core::convert::Infallible; + /// # } + /// impl round_based::round::RoundStore for MyStore { /// # fn add_message(&mut self, msg: round_based::Incoming) -> Result<(), Self::Error> { unimplemented!() } /// # fn wants_more(&self) -> bool { unimplemented!() } /// # fn output(self) -> Result { unimplemented!() } From 0c8679cc89545b98fdd95e2079c163d32b65c981 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Thu, 29 May 2025 16:55:30 +0200 Subject: [PATCH 17/29] clippy fix Signed-off-by: Denis Varlakov --- round-based/src/echo_broadcast/error.rs | 2 +- round-based/src/echo_broadcast/mod.rs | 2 +- round-based/src/echo_broadcast/store.rs | 14 ++++++-------- round-based/src/round/mod.rs | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index ac5b369..3c09b11 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -74,7 +74,7 @@ pub(super) enum CompleteRoundReason { impl From for CompleteRoundError { fn from(err: Reason) -> Self { - CompleteRoundError(CompleteRoundReason::Echo(err.into())) + CompleteRoundError(CompleteRoundReason::Echo(err)) } } diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index 8164c81..2c79d03 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -289,7 +289,7 @@ where .party .complete(echo_round) .await - .map_err(|e| error::CompleteRoundReason::CompleteRound(e))?; + .map_err(error::CompleteRoundReason::CompleteRound)?; // check that everyone sent the same hash let main_output = main_output.with_echo_output(echoes)?; diff --git a/round-based/src/echo_broadcast/store.rs b/round-based/src/echo_broadcast/store.rs index e9af8b3..507cc7f 100644 --- a/round-based/src/echo_broadcast/store.rs +++ b/round-based/src/echo_broadcast/store.rs @@ -137,14 +137,12 @@ where received_msgs: self.received_msgs, _hash: PhantomData, }), - state => { - return Err(Self { - params: self.params, - state, - received_msgs: self.received_msgs, - _ph: PhantomData, - }) - } + state => Err(Self { + params: self.params, + state, + received_msgs: self.received_msgs, + _ph: PhantomData, + }), } } } diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs index b7821b7..1684edd 100644 --- a/round-based/src/round/mod.rs +++ b/round-based/src/round/mod.rs @@ -27,7 +27,7 @@ pub trait RoundInfo: Sized + 'static { /// and should implement extra measures against malicious parties (e.g. prohibit message overwrite). /// /// ## Flow -/// `RoundStore` stores received messages. Once enough messages are received, it outputs [`RoundStore::Output`]. +/// `RoundStore` stores received messages. Once enough messages are received, it outputs [`RoundInfo::Output`]. /// In order to save received messages, [`.add_message(msg)`] is called. Then, [`.wants_more()`] tells whether more /// messages are needed to be received. If it returned `false`, then output can be retrieved by calling [`.output()`]. /// From afa13e902ee6b256b714db1ad9fc7c2ec8466dab Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 30 May 2025 11:02:16 +0200 Subject: [PATCH 18/29] Add a test that echo broadcast works Signed-off-by: Denis Varlakov --- Cargo.lock | 16 +++ .../random-generation-protocol/Cargo.toml | 5 + .../random-generation-protocol/src/lib.rs | 11 +- round-based-tests/Cargo.toml | 8 +- round-based-tests/src/lib.rs | 74 ++++++++++ .../tests/{rounds.rs => random_beacon.rs} | 0 .../tests/random_beacon_with_echo.rs | 126 ++++++++++++++++++ round-based/src/delivery.rs | 2 +- round-based/src/echo_broadcast/error.rs | 17 +++ round-based/src/echo_broadcast/mod.rs | 56 +++++++- 10 files changed, 304 insertions(+), 11 deletions(-) rename round-based-tests/tests/{rounds.rs => random_beacon.rs} (100%) create mode 100644 round-based-tests/tests/random_beacon_with_echo.rs diff --git a/Cargo.lock b/Cargo.lock index 4ce8fc9..909ee6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -459,6 +459,7 @@ dependencies = [ "sha2", "thiserror", "tokio", + "udigest", ] [[package]] @@ -497,12 +498,15 @@ dependencies = [ name = "round-based-tests" version = "0.1.0" dependencies = [ + "anyhow", "futures", + "hex", "hex-literal", "matches", "rand_chacha", "random-generation-protocol", "round-based", + "sha2", "tokio", ] @@ -754,6 +758,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cd61fa9fb78569e9fe34acf0048fd8cb9ebdbacc47af740745487287043ff0" dependencies = [ "digest", + "udigest-derive", +] + +[[package]] +name = "udigest-derive" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "603329303137e0d59238ee4d6b9c085eada8e2a9d20666f3abd9dadf8f8543f4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", ] [[package]] diff --git a/examples/random-generation-protocol/Cargo.toml b/examples/random-generation-protocol/Cargo.toml index ddcec5e..edd0188 100644 --- a/examples/random-generation-protocol/Cargo.toml +++ b/examples/random-generation-protocol/Cargo.toml @@ -18,6 +18,11 @@ thiserror = { version = "2", default-features = false } # We don't use it directy, but we need to enable `serde` feature generic-array = { version = "0.14", features = ["serde"] } +udigest = { version = "0.2", default-features = false, features = ["derive"], optional = true } + +[features] +udigest = ["dep:udigest"] + [dev-dependencies] round-based = { path = "../../round-based", features = ["derive", "sim", "state-machine"] } tokio = { version = "1.15", features = ["macros", "rt"] } diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index a73077c..33a52df 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -24,7 +24,8 @@ use round_based::{ }; /// Protocol message -#[derive(round_based::ProtocolMsg, Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(round_based::ProtocolMsg, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub enum Msg { /// Round 1 CommitMsg(CommitMsg), @@ -33,16 +34,20 @@ pub enum Msg { } /// Message from round 1 -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub struct CommitMsg { /// Party commitment + #[cfg_attr(feature = "udigest", udigest(as_bytes))] pub commitment: Output, } /// Message from round 2 -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub struct DecommitMsg { /// Randomness generated by party + #[cfg_attr(feature = "udigest", udigest(as_bytes))] pub randomness: [u8; 32], } diff --git a/round-based-tests/Cargo.toml b/round-based-tests/Cargo.toml index dfb0ac6..0e7e65e 100644 --- a/round-based-tests/Cargo.toml +++ b/round-based-tests/Cargo.toml @@ -7,13 +7,17 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +round-based = { path = "../round-based", features = ["echo-broadcast", "state-machine"] } + +anyhow = "1" [dev-dependencies] tokio = { version = "1", features = ["rt", "macros"] } hex-literal = "0.3" +hex = "0.4" futures = "0.3" matches = "0.1" rand_chacha = "0.3" +sha2 = "0.10" -round-based = { path = "../round-based" } -random-generation-protocol = { path = "../examples/random-generation-protocol" } +random-generation-protocol = { path = "../examples/random-generation-protocol", features = ["udigest"] } diff --git a/round-based-tests/src/lib.rs b/round-based-tests/src/lib.rs index 8b13789..237836b 100644 --- a/round-based-tests/src/lib.rs +++ b/round-based-tests/src/lib.rs @@ -1 +1,75 @@ +use round_based::{state_machine::ProceedResult, Incoming, Outgoing}; +pub struct PartySim(S); + +pub fn new_one_party_sim<'a, M, F>( + protocol: impl FnOnce(round_based::state_machine::MpcParty) -> F, +) -> PartySim + 'a> +where + M: round_based::ProtocolMsg + 'static, + F: core::future::Future + 'a, +{ + PartySim(round_based::state_machine::wrap_protocol(protocol)) +} + +impl PartySim { + #[track_caller] + pub fn receives(&mut self, msg: Incoming) { + loop { + match self.0.proceed() { + ProceedResult::NeedsOneMoreMessage => { + // that's exactly what we want + break; + } + ProceedResult::Yielded => { + // Protocol yielded, we ignore that + continue; + } + // everything else is unexpected + r => panic!("state machine proceed: expected NeedsOneMoreMessage, got {r:?}"), + } + } + self.0 + .received_msg(msg) + .ok() + .expect("couldn't feed a message into simulation") + } + + #[track_caller] + pub fn sends(&mut self) -> Expect> { + loop { + match self.0.proceed() { + ProceedResult::SendMsg(m) => break Expect(m), + ProceedResult::Yielded => continue, + r => panic!("state machine proceed: expected SendMsg, got {r:?}"), + } + } + } + + #[track_caller] + pub fn outputs(&mut self) -> Expect { + loop { + match self.0.proceed() { + ProceedResult::Output(r) => break Expect(r), + ProceedResult::Yielded => continue, + r => panic!("state machine proceed: expected Output, got {r:?}"), + } + } + } +} + +pub struct Expect(pub T); + +impl Expect { + #[track_caller] + pub fn expect_eq(&self, expected: &T) { + assert_eq!(self.0, *expected) + } +} + +impl Expect> { + #[track_caller] + pub fn unwrap(self) -> Expect { + Expect(self.0.unwrap()) + } +} diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/random_beacon.rs similarity index 100% rename from round-based-tests/tests/rounds.rs rename to round-based-tests/tests/random_beacon.rs diff --git a/round-based-tests/tests/random_beacon_with_echo.rs b/round-based-tests/tests/random_beacon_with_echo.rs new file mode 100644 index 0000000..bcc515b --- /dev/null +++ b/round-based-tests/tests/random_beacon_with_echo.rs @@ -0,0 +1,126 @@ +use hex_literal::hex; +use rand_chacha::rand_core::SeedableRng; + +use random_generation_protocol::{protocol_of_random_generation, CommitMsg, DecommitMsg, Msg}; +use round_based::{echo_broadcast as echo, Incoming, MessageType, Outgoing}; + +const PARTY0_SEED: [u8; 32] = + hex!("6772d079d5c984b3936a291e36b0d3dc6c474e36ed4afdfc973ef79a431ca870"); +const PARTY0_COMMITMENT: [u8; 32] = + hex!("6ac69d1b2082536de5d4f5092f807873173fac86117cb7b5e179b191e97fc8d7"); +const PARTY0_RANDOMNESS: [u8; 32] = + hex!("15f88064c7daeb863f37c3a41466a827f5ca07b7f7dd8e58c510ff612e641ea2"); +const PARTY1_COMMITMENT: [u8; 32] = + hex!("2a8c585d9a80cb78bc226f4ab35a75c8e5834ff77a83f41cf6c893ea0f3b2aed"); +const ECHO_MSG: [u8; 32] = hex!("b9c51267eae8dc5ea988111e933709fa8496b5cab4aa11778f61271527d3760a"); +const PARTY1_RANDOMNESS: [u8; 32] = + hex!("12a595f4893fdb4ab9cc38caeec5f7456acb3002ca58457c5056977ce59136a6"); +const PARTY2_COMMITMENT: [u8; 32] = + hex!("01274ef40aece8aa039587cc05620a19b80a5c93fbfb24a9f8e1b77b7936e47d"); +const PARTY2_RANDOMNESS: [u8; 32] = + hex!("6fc78a926c7eebfad4e98e796cd53b771ac5947b460567c7ea441abb957c89c7"); +const PROTOCOL_OUTPUT: [u8; 32] = + hex!("689a9f02229bdb36521275179676641585c4a3ce7b80ace37f0272a65e89a1c3"); +const PARTY_OVERWRITES: [u8; 32] = + hex!("00aa11bb22cc33dd44ee55ff6677889900aa11bb22cc33dd44ee55ff66778899"); + +#[tokio::test] +async fn random_generation_completes() { + let mut sim = simulation(); + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY0_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 0, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 1, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY2_COMMITMENT.into(), + })), + }); + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 2, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 3, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY0_RANDOMNESS.into(), + })), + }); + sim.receives(Incoming { + id: 4, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY1_RANDOMNESS, + })), + }); + sim.receives(Incoming { + id: 5, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY2_RANDOMNESS, + })), + }); + + sim.outputs().unwrap().expect_eq(&PROTOCOL_OUTPUT); +} + +fn simulation() -> round_based_tests::PartySim< + impl round_based::state_machine::StateMachine< + Msg = echo::Msg, + Output = Result< + [u8; 32], + random_generation_protocol::Error< + round_based::echo_broadcast::CompleteRoundError< + round_based::mpc::party::CompleteRoundError< + round_based::echo_broadcast::Error, + round_based::state_machine::DeliveryErr, + >, + round_based::state_machine::DeliveryErr, + >, + round_based::echo_broadcast::Error, + >, + >, + >, +> { + let rng = rand_chacha::ChaCha8Rng::from_seed(PARTY0_SEED); + round_based_tests::new_one_party_sim(|party| async { + let party = round_based::echo_broadcast::wrap(party, 0, 3); + protocol_of_random_generation(party, 0, 3, rng).await + }) +} diff --git a/round-based/src/delivery.rs b/round-based/src/delivery.rs index ea9c856..aa230f9 100644 --- a/round-based/src/delivery.rs +++ b/round-based/src/delivery.rs @@ -89,7 +89,7 @@ impl Incoming { } /// Outgoing message -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Outgoing { /// Message destination: either one party (p2p message) or all parties (broadcast message) pub recipient: MessageDestination, diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index 3c09b11..107fc97 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -40,6 +40,23 @@ pub(super) enum Reason { #[error("cannot convert a sent round msg back from proto msg (it's a bug)")] SentMsgFromProto, + + #[error( + "sent a message that doesn't require reliable broadcast in reliable broadcast \ + round (round: {round}, dest: {dest:?})" + )] + SentNonReliableMsgInReliableRound { + dest: crate::MessageDestination, + round: u16, + }, + + #[error( + "sent a reliable broadcast message in a regular round that doesn't require \ + reliable broadcast: you might have forgotten to register a reliable broadcast \ + round, or round store doesn't expose a property required to identify a reliable \ + broadcast round (round: {round})" + )] + SentReliableMsgInNonReliableRound { round: u16 }, } #[derive(thiserror::Error, Debug)] diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index 2c79d03..36840f8 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -26,6 +26,8 @@ use crate::{ mod error; mod store; +pub use self::error::{CompleteRoundError, EchoError, Error}; + /// Message of the protocol with echo broadcast round(s) pub enum Msg { /// Message from echo broadcast sub-protocol @@ -77,6 +79,32 @@ impl Clone for Msg { } } +impl PartialEq for Msg { + fn eq(&self, other: &Self) -> bool { + match self { + Self::Echo { round, hash } => { + matches!(other, Self::Echo { round: r2, hash: h2 } if round == r2 && hash == h2) + } + Self::Main(msg) => matches!(other, Self::Main(m2) if msg == m2), + } + } +} + +impl Eq for Msg {} + +impl core::fmt::Debug for Msg { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Echo { round, hash } => f + .debug_struct("Msg::Echo") + .field("round", round) + .field("hash", hash) + .finish(), + Self::Main(msg) => f.debug_tuple("Msg::Main").field(msg).finish(), + } + } +} + impl ProtocolMsg for Msg { fn round(&self) -> u16 { match self { @@ -214,13 +242,31 @@ where D: Digest, MainMsg: ProtocolMsg + Clone, { - fn on_send(&mut self, outgoing: &Outgoing) -> Result<(), error::EchoError> { + fn on_send(&mut self, outgoing: &mut Outgoing) -> Result<(), error::EchoError> { if let Some(slot) = self.sent_reliable_msgs.get_mut(&outgoing.msg.round()) { + if !outgoing.recipient.is_reliable_broadcast() { + // it's reliable broadcast round, but message is not reliable broadcast + return Err(error::Reason::SentNonReliableMsgInReliableRound { + dest: outgoing.recipient, + round: outgoing.msg.round(), + } + .into()); + } + // Message delivery layer doesn't need to know that protocol wants this message to be + // reliably broadcasted - echo broadcast takes care of it + outgoing.recipient = crate::MessageDestination::AllParties { reliable: false }; if slot.is_some() { return Err(error::Reason::SendTwice.into()); } *slot = Some(outgoing.msg.clone()) + } else if outgoing.recipient.is_reliable_broadcast() { + // it's not a reliable broadcast round, but message is a reliable broadcast + return Err(error::Reason::SentReliableMsgInNonReliableRound { + round: outgoing.msg.round(), + } + .into()); } + Ok(()) } } @@ -298,8 +344,8 @@ where } } - async fn send(&mut self, outgoing: Outgoing) -> Result<(), Self::SendErr> { - self.on_send(&outgoing)?; + async fn send(&mut self, mut outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&mut outgoing)?; self.party .send(outgoing.map(Msg::Main)) @@ -351,8 +397,8 @@ where type Msg = MainMsg; type SendErr = error::Error; - async fn send(&mut self, outgoing: Outgoing) -> Result<(), Self::SendErr> { - self.on_send(&outgoing)?; + async fn send(&mut self, mut outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&mut outgoing)?; self.party .send(outgoing.map(Msg::Main)) .await From 49cad7456f9314c90cb8cb915836a8c7a00409f7 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 30 May 2025 12:06:26 +0200 Subject: [PATCH 19/29] Add a test with failed reliability check Signed-off-by: Denis Varlakov --- round-based-tests/src/lib.rs | 7 ++ .../tests/random_beacon_with_echo.rs | 67 ++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/round-based-tests/src/lib.rs b/round-based-tests/src/lib.rs index 237836b..46540a8 100644 --- a/round-based-tests/src/lib.rs +++ b/round-based-tests/src/lib.rs @@ -58,6 +58,7 @@ impl PartySim { } } +#[must_use = "you need to make sure the output meets tests expectations"] pub struct Expect(pub T); impl Expect { @@ -73,3 +74,9 @@ impl Expect> { Expect(self.0.unwrap()) } } +impl Expect> { + #[track_caller] + pub fn unwrap_err(self) -> Expect { + Expect(self.0.unwrap_err()) + } +} diff --git a/round-based-tests/tests/random_beacon_with_echo.rs b/round-based-tests/tests/random_beacon_with_echo.rs index bcc515b..205f225 100644 --- a/round-based-tests/tests/random_beacon_with_echo.rs +++ b/round-based-tests/tests/random_beacon_with_echo.rs @@ -1,4 +1,5 @@ use hex_literal::hex; +use matches::assert_matches; use rand_chacha::rand_core::SeedableRng; use random_generation_protocol::{protocol_of_random_generation, CommitMsg, DecommitMsg, Msg}; @@ -24,9 +25,10 @@ const PROTOCOL_OUTPUT: [u8; 32] = const PARTY_OVERWRITES: [u8; 32] = hex!("00aa11bb22cc33dd44ee55ff6677889900aa11bb22cc33dd44ee55ff66778899"); -#[tokio::test] -async fn random_generation_completes() { +#[test] +fn random_generation_completes() { let mut sim = simulation(); + // Round 0 - commitment sim.sends().expect_eq(&Outgoing { recipient: round_based::MessageDestination::AllParties { reliable: false }, msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { @@ -49,6 +51,7 @@ async fn random_generation_completes() { commitment: PARTY2_COMMITMENT.into(), })), }); + // Round 1 - echo round sim.sends().expect_eq(&Outgoing { recipient: round_based::MessageDestination::AllParties { reliable: false }, msg: echo::Msg::Echo { @@ -74,6 +77,7 @@ async fn random_generation_completes() { hash: ECHO_MSG.into(), }, }); + // Round 2 - decommitment sim.sends().expect_eq(&Outgoing { recipient: round_based::MessageDestination::AllParties { reliable: false }, msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { @@ -100,6 +104,65 @@ async fn random_generation_completes() { sim.outputs().unwrap().expect_eq(&PROTOCOL_OUTPUT); } +#[test] +fn detects_unreliable_broadcast() { + let mut sim = simulation(); + // Round 0 - commitment + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY0_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 0, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 1, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY2_COMMITMENT.into(), + })), + }); + // Round 1 - echo round + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 2, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 3, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: PARTY_OVERWRITES.into(), + }, + }); + + assert_matches!( + sim.outputs().unwrap_err().0, + random_generation_protocol::Error::Round1Receive(_) + ); +} + fn simulation() -> round_based_tests::PartySim< impl round_based::state_machine::StateMachine< Msg = echo::Msg, From f7bdac45f38eb22cecb8b3add7cd90801bf583c3 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 30 May 2025 13:17:58 +0200 Subject: [PATCH 20/29] Change errors a bit Signed-off-by: Denis Varlakov --- .../tests/random_beacon_with_echo.rs | 3 +- round-based/src/echo_broadcast/error.rs | 33 ++++++++++++++----- round-based/src/echo_broadcast/mod.rs | 8 ++--- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/round-based-tests/tests/random_beacon_with_echo.rs b/round-based-tests/tests/random_beacon_with_echo.rs index 205f225..d6f2ec8 100644 --- a/round-based-tests/tests/random_beacon_with_echo.rs +++ b/round-based-tests/tests/random_beacon_with_echo.rs @@ -159,7 +159,8 @@ fn detects_unreliable_broadcast() { assert_matches!( sim.outputs().unwrap_err().0, - random_generation_protocol::Error::Round1Receive(_) + random_generation_protocol::Error::Round1Receive(echo::CompleteRoundError::Echo(err)) + if err.reliability_check_failed() ); } diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index 107fc97..4b3fe59 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -1,7 +1,19 @@ +/// An error in echo-broadcast sub-protocol +/// +/// It may indicate that reliability check was not successful, or some other error, for instance, that +/// a round that requires reliable broadcast behaves unexpectedly (e.g. if we received two messages +/// from the same party and the main round didn't return an error). #[derive(thiserror::Error, Debug)] #[error(transparent)] pub struct EchoError(#[from] Reason); +impl EchoError { + /// Indicates that an error was caused by failed reliability check + pub fn reliability_check_failed(&self) -> bool { + matches!(self.0, Reason::MismatchedHash) + } +} + #[derive(thiserror::Error, Debug)] pub(super) enum Reason { #[error( @@ -59,10 +71,13 @@ pub(super) enum Reason { SentReliableMsgInNonReliableRound { round: u16 }, } +/// An error originated either from main protocol or echo-broadcast sub-protocol #[derive(thiserror::Error, Debug)] pub enum Error { + /// Error originated from main protocol #[error("error originated in principal protocol")] Main(#[source] E), + /// Error originated from echo-broadcast sub-protocol #[error("echo broadcast")] Echo(#[from] EchoError), } @@ -73,25 +88,25 @@ impl From for Error { } } +/// An error returned in [round completion](super::WithEchoBroadcast::complete) #[derive(thiserror::Error, Debug)] -#[error(transparent)] -pub struct CompleteRoundError( - #[from] CompleteRoundReason, -); - -#[derive(thiserror::Error, Debug)] -pub(super) enum CompleteRoundReason { +pub enum CompleteRoundError { + /// Error occurred while handling received message(s) #[error(transparent)] CompleteRound(CompleteErr), + /// Error occurred while sending a message to another party + /// + /// The only message we send during round completion is echo message #[error(transparent)] Send(SendErr), + /// Echo broadcast sub-protocol error #[error(transparent)] - Echo(Reason), + Echo(EchoError), } impl From for CompleteRoundError { fn from(err: Reason) -> Self { - CompleteRoundError(CompleteRoundReason::Echo(err)) + CompleteRoundError::Echo(err.into()) } } diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index 36840f8..c709414 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -299,7 +299,7 @@ where .party .complete(round) .await - .map_err(error::CompleteRoundReason::CompleteRound)?; + .map_err(error::CompleteRoundError::CompleteRound)?; Ok(output) } Inner::WithReliabilityCheck { @@ -311,7 +311,7 @@ where .party .complete(main_round) .await - .map_err(error::CompleteRoundReason::CompleteRound)?; + .map_err(error::CompleteRoundError::CompleteRound)?; // retrieve a msg that we sent in this round let sent_msg = if let Some(Some(msg)) = self.sent_reliable_msgs.remove(&Self::Msg::ROUND) { @@ -329,13 +329,13 @@ where hash, }) .await - .map_err(error::CompleteRoundReason::Send)?; + .map_err(error::CompleteRoundError::Send)?; // receive echoes from other parties let echoes = self .party .complete(echo_round) .await - .map_err(error::CompleteRoundReason::CompleteRound)?; + .map_err(error::CompleteRoundError::CompleteRound)?; // check that everyone sent the same hash let main_output = main_output.with_echo_output(echoes)?; From 6a6654e3db6e5c5d42e289f09c790b8a36eac2f4 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 30 May 2025 13:24:47 +0200 Subject: [PATCH 21/29] Do a small renaming Signed-off-by: Denis Varlakov --- round-based/src/echo_broadcast/error.rs | 2 +- round-based/src/echo_broadcast/mod.rs | 30 ++++++++++++------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index 4b3fe59..9b68dfa 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -88,7 +88,7 @@ impl From for Error { } } -/// An error returned in [round completion](super::WithEchoBroadcast::complete) +/// An error returned in [round completion](::complete) #[derive(thiserror::Error, Debug)] pub enum CompleteRoundError { /// Error occurred while handling received message(s) diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index c709414..a8b89b0 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -156,13 +156,13 @@ where } /// Wraps an [`Mpc`] engine and provides echo broadcast capabilities -pub fn wrap(party: M, i: u16, n: u16) -> WithReliableBroadcast +pub fn wrap(party: M, i: u16, n: u16) -> WithEchoBroadcast where D: Digest, M: Mpc>, MainMsg: udigest::Digestable, { - WithReliableBroadcast { + WithEchoBroadcast { party, i, n, @@ -172,7 +172,7 @@ where } /// [`Mpc`] engine with echo-broadcast capabilities -pub struct WithReliableBroadcast { +pub struct WithEchoBroadcast { party: M, i: u16, n: u16, @@ -180,10 +180,10 @@ pub struct WithReliableBroadcast { _ph: PhantomData, } -impl WithReliableBroadcast { - fn map_party

(self, f: impl FnOnce(M) -> P) -> WithReliableBroadcast { +impl WithEchoBroadcast { + fn map_party

(self, f: impl FnOnce(M) -> P) -> WithEchoBroadcast { let party = f(self.party); - WithReliableBroadcast { + WithEchoBroadcast { party, i: self.i, n: self.n, @@ -193,7 +193,7 @@ impl WithReliableBroadcast { } } -impl Mpc for WithReliableBroadcast +impl Mpc for WithEchoBroadcast where D: Digest + 'static, M: Mpc>, @@ -201,7 +201,7 @@ where { type Msg = MainMsg; - type Exec = WithReliableBroadcast; + type Exec = WithEchoBroadcast; type SendErr = error::Error; @@ -237,7 +237,7 @@ where } } -impl WithReliableBroadcast +impl WithEchoBroadcast where D: Digest, MainMsg: ProtocolMsg + Clone, @@ -271,7 +271,7 @@ where } } -impl MpcExecution for WithReliableBroadcast +impl MpcExecution for WithEchoBroadcast where D: Digest + 'static, M: MpcExecution>, @@ -282,7 +282,7 @@ where type CompleteRoundErr = error::CompleteRoundError>, M::SendErr>; type SendErr = error::Error; - type SendMany = WithReliableBroadcast; + type SendMany = WithEchoBroadcast; async fn complete( &mut self, @@ -364,7 +364,7 @@ where /// Round registration witness /// -/// Returned by [`WithReliableBroadcast::add_round()`] +/// Returned by [`WithEchoBroadcast::add_round()`] pub struct Round(Inner) where M: MpcExecution, @@ -387,13 +387,13 @@ where }, } -impl crate::mpc::SendMany for WithReliableBroadcast +impl crate::mpc::SendMany for WithEchoBroadcast where D: Digest + 'static, M: crate::mpc::SendMany>, MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, { - type Exec = WithReliableBroadcast; + type Exec = WithEchoBroadcast; type Msg = MainMsg; type SendErr = error::Error; @@ -407,7 +407,7 @@ where async fn flush(self) -> Result { let party = self.party.flush().await.map_err(error::Error::Main)?; - Ok(WithReliableBroadcast { + Ok(WithEchoBroadcast { party, i: self.i, n: self.n, From e7f07ac6bb6dcc776e19b02a23d01b1e2ab47234 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Mon, 2 Jun 2025 11:44:00 +0200 Subject: [PATCH 22/29] Update docs Signed-off-by: Denis Varlakov --- README.md | 86 +++++++++++++++--- round-based/src/echo_broadcast/mod.rs | 46 +++++++++- round-based/src/lib.rs | 106 +++++++++++++++++++--- round-based/src/state_machine/delivery.rs | 2 +- 4 files changed, 208 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 94500f0..f481d16 100644 --- a/README.md +++ b/README.md @@ -13,25 +13,86 @@ multiparty protocols (e.g. threshold signing, random beacons, etc.). * Simple, configurable \ Protocol can be carried out in a few lines of code: check out examples. * Independent of networking layer \ - We use abstractions `Stream` and `Sink` to receive and send messages. + You may define your own networking layer and you don't need to change anything + in protocol implementation: it's agnostic of networking by default! So you can + use central delivery server, distributed redis nodes, postgres database, + p2p channels, or a public blockchain, or whatever else fits your needs. + +## Example +MPC protocol execution typically looks like this: + +```rust +// protocol to be executed, takes MPC engine `M`, index of party `i`, +// and number of participants `n` +async fn keygen(mpc: M, i: u16, n: u16) -> Result +where + M: round_based::Mpc +{ + // ... +} +// establishes network connection(s) to other parties so they may communicate +async fn connect() -> + impl futures::Stream>> + + futures::Sink, Error = Error> + + Unpin +{ + // ... +} +let delivery = connect().await; + +// constructs an MPC engine, which, primarily, is used to communicate with +// other parties +let mpc = round_based::mpc::connected(delivery); + +// execute the protocol +let keyshare = keygen(mpc, i, n).await?; +``` ## Networking In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to provide a channel which implements a stream and a sink for receiving and sending messages. -Message delivery should meet certain criterias that differ from protocol to protocol (refer to -the documentation of the protocol you're using), but usually they are: +```rust +async fn connect() -> + impl futures::Stream>> + + futures::Sink, Error = Error> + + Unpin +{ + // ... +} + +let delivery = connect().await; +let party = round_based::mpc::connected(delivery); + +// run the protocol +``` + +In order to guarantee the protocol security, it may require: + +* Message Authentication \ + Guarantees message source and integrity. If protocol requires it, make sure + message was sent by claimed sender and that it hasn't been tampered with. \ + This is typically achieved either through public-key cryptography (e.g., + signing with a private key) or through symmetric mechanisms like MACs (e.g., + HMAC) or authenticated encryption (AEAD) in point-to-point scenarios. +* Message Privacy \ + When a p2p message is sent, only recipient shall be able to read the content. \ + It can be achieved by using symmetric or asymmetric encryption, encryption methods + come with their own trade-offs (e.g. simplicity vs forward secrecy). +* Reliable Broadcast \ + When party receives a reliable broadcast message it shall be ensured that + everybody else received the same message. \ + Our library provides `echo_broadcast` support out-of-box that enforces broadcast + reliability by adding an extra communication round per each round that requires + reliable broadcast. \ + More advanced techniques implement [Byzantine fault](https://en.wikipedia.org/wiki/Byzantine_fault) + tolerant broadcast -* Messages should be authenticated \ - Each message should be signed with identity key of the sender. This implies having Public Key - Infrastructure. -* P2P messages should be encrypted \ - Only recipient should be able to learn the content of p2p message -* Broadcast channel should be reliable \ - Some protocols may require broadcast channel to be reliable. Simply saying, when party receives a - broadcast message over reliable channel it should be ensured that everybody else received the same - message. +## Developing MPC protocol with `round_based` +We plan to write a book guiding through MPC protocol development process, but +while it's not done, you may refer to [random beacon example](https://github.com/LFDT-Lockness/round-based/blob/m/examples/random-generation-protocol/src/lib.rs) +and our well-documented API. ## Features @@ -40,6 +101,7 @@ the documentation of the protocol you're using), but usually they are: module * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync API, see `state_machine` module +* `echo-broadcast` adds `echo_broadcast` support * `derive` is needed to use `ProtocolMsg` proc macro * `runtime-tokio` enables tokio-specific implementation of async runtime diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index a8b89b0..3b1e33d 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -6,12 +6,52 @@ //! honest participants of the protocol has received the same message. //! //! One way to achieve the reliable broadcast is by adding an echo round: when we receive -//! messages in a reliable broadcast round, we hash all messages, and send the hash to all +//! messages in a reliable broadcast round, we hash all messages, and we send the hash to all //! other participants. If party receives a the same hash from everyone else, we can be //! assured that messages in the round were reliably broadcasted. //! //! This module provides a mechanism that automatically add an echo round per each //! round of the protocol that requires a reliable broadcast. +//! +//! ## Example +//! +//! ```rust +//! # #[derive(round_based::ProtocolMsg)] +//! # enum KeygenMsg {} +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! // protocol to be executed that **requires** reliable broadcast +//! async fn keygen(mpc: M, i: u16, n: u16) -> Result +//! where +//! M: round_based::Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! // establishes network connection(s) to other parties, but +//! // **does not** support reliable broadcast +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! let delivery = connect().await; +//! +//! # let (i, n) = (1, 3); +//! // constructs an MPC engine as usual +//! let mpc = round_based::mpc::connected(delivery); +//! // wrap an engine to add reliable broadcast support +//! let mpc = round_based::echo_broadcast::wrap(mpc, i, n); +//! +//! // execute the protocol +//! let keyshare = keygen(mpc, i, n).await?; +//! # Ok(()) } +//! ``` use core::marker::PhantomData; @@ -362,9 +402,7 @@ where } } -/// Round registration witness -/// -/// Returned by [`WithEchoBroadcast::add_round()`] +/// Round registration witness returned by [`WithEchoBroadcast::add_round()`] pub struct Round(Inner) where M: MpcExecution, diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index 8991bd8..5134598 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -13,25 +13,103 @@ //! * Simple, configurable \ //! Protocol can be carried out in a few lines of code: check out examples. //! * Independent of networking layer \ -//! We use abstractions [`Stream`] and [`Sink`] to receive and send messages. +//! You may define your own networking layer and you don't need to change anything +//! in protocol implementation: it's agnostic of networking by default! So you can +//! use central delivery server, distributed redis nodes, postgres database, +//! p2p channels, or a public blockchain, or whatever else fits your needs. +//! +//! ## Example +//! MPC protocol execution typically looks like this: +//! +//! ```rust +//! # #[derive(round_based::ProtocolMsg)] +//! # enum KeygenMsg {} +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! // protocol to be executed, takes MPC engine `M`, index of party `i`, +//! // and number of participants `n` +//! async fn keygen(mpc: M, i: u16, n: u16) -> Result +//! where +//! M: round_based::Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! // establishes network connection(s) to other parties so they may communicate +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! let delivery = connect().await; +//! +//! // constructs an MPC engine, which, primarily, is used to communicate with +//! // other parties +//! let mpc = round_based::mpc::connected(delivery); +//! +//! # let (i, n) = (1, 3); +//! // execute the protocol +//! let keyshare = keygen(mpc, i, n).await?; +//! # Ok(()) } +//! ``` //! //! ## Networking //! //! In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to //! provide a channel which implements a stream and a sink for receiving and sending messages. //! -//! Message delivery should meet certain criterias that differ from protocol to protocol (refer to -//! the documentation of the protocol you're using), but usually they are: +//! ```rust,no_run +//! # #[derive(round_based::ProtocolMsg)] +//! # enum Msg {} +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! +//! let delivery = connect().await; +//! let party = round_based::mpc::connected(delivery); //! -//! * Messages should be authenticated \ -//! Each message should be signed with identity key of the sender. This implies having Public Key -//! Infrastructure. -//! * P2P messages should be encrypted \ -//! Only recipient should be able to learn the content of p2p message -//! * Broadcast channel should be reliable \ -//! Some protocols may require broadcast channel to be reliable. Simply saying, when party receives a -//! broadcast message over reliable channel it should be ensured that everybody else received the same -//! message. +//! // run the protocol +//! # Ok(()) } +//! ``` +//! +//! In order to guarantee the protocol security, it may require: +//! +//! * Message Authentication \ +//! Guarantees message source and integrity. If protocol requires it, make sure +//! message was sent by claimed sender and that it hasn't been tampered with. \ +//! This is typically achieved either through public-key cryptography (e.g., +//! signing with a private key) or through symmetric mechanisms like MACs (e.g., +//! HMAC) or authenticated encryption (AEAD) in point-to-point scenarios. +//! * Message Privacy \ +//! When a p2p message is sent, only recipient shall be able to read the content. \ +//! It can be achieved by using symmetric or asymmetric encryption, encryption methods +//! come with their own trade-offs (e.g. simplicity vs forward secrecy). +//! * Reliable Broadcast \ +//! When party receives a reliable broadcast message it shall be ensured that +//! everybody else received the same message. \ +//! Our library provides [`echo_broadcast`] support out-of-box that enforces broadcast +//! reliability by adding an extra communication round per each round that requires +//! reliable broadcast. \ +//! More advanced techniques implement [Byzantine fault](https://en.wikipedia.org/wiki/Byzantine_fault) +//! tolerant broadcast +//! +//! ## Developing MPC protocol with `round_based` +//! We plan to write a book guiding through MPC protocol development process, but +//! while it's not done, you may refer to [random beacon example](https://github.com/LFDT-Lockness/round-based/blob/m/examples/random-generation-protocol/src/lib.rs) +//! and our well-documented API. //! //! ## Features //! @@ -40,6 +118,7 @@ //! module //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync //! API, see [`state_machine`] module +//! * `echo-broadcast` adds [`echo_broadcast`] support //! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro //! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](mpc::party::runtime) //! @@ -53,9 +132,6 @@ extern crate alloc; -#[doc(no_inline)] -pub use futures_util::{Sink, SinkExt, Stream, StreamExt}; - /// Fixes false-positive of `unused_crate_dependencies` lint that only occur in the tests #[cfg(test)] mod false_positives { diff --git a/round-based/src/state_machine/delivery.rs b/round-based/src/state_machine/delivery.rs index ef92fcb..9a710ff 100644 --- a/round-based/src/state_machine/delivery.rs +++ b/round-based/src/state_machine/delivery.rs @@ -26,7 +26,7 @@ impl futures_util::Stream for Delivery { } } -impl crate::Sink> for Delivery { +impl futures_util::Sink> for Delivery { type Error = DeliveryErr; fn poll_ready( From f2faef49123e4d46ecacad7efab27435efb69dcd Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 6 Jun 2025 12:12:44 +0200 Subject: [PATCH 23/29] Fix docs Signed-off-by: Denis Varlakov --- Cargo.lock | 1 + round-based/Cargo.toml | 2 ++ round-based/src/echo_broadcast/mod.rs | 8 +++++--- round-based/src/lib.rs | 2 ++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 909ee6b..0c2c921 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -477,6 +477,7 @@ dependencies = [ "rand", "rand_dev", "round-based-derive", + "sha2", "thiserror", "tokio", "tokio-stream", diff --git a/round-based/Cargo.toml b/round-based/Cargo.toml index 53489cc..dacd8ba 100644 --- a/round-based/Cargo.toml +++ b/round-based/Cargo.toml @@ -44,6 +44,8 @@ rand_dev = "0.1" anyhow = "1" +sha2 = "0.10" + [features] default = [] state-machine = [] diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index 3b1e33d..c221010 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -16,7 +16,7 @@ //! ## Example //! //! ```rust -//! # #[derive(round_based::ProtocolMsg)] +//! # #[derive(round_based::ProtocolMsg, Clone, udigest::Digestable)] //! # enum KeygenMsg {} //! # struct KeyShare; //! # struct Error; @@ -30,11 +30,13 @@ //! // ... //! # unimplemented!() //! } +//! // The full message type, which corresponds to keygen msg + echo broadcast msg +//! type Msg = round_based::echo_broadcast::Msg; //! // establishes network connection(s) to other parties, but //! // **does not** support reliable broadcast //! async fn connect() -> -//! impl futures::Stream>> -//! + futures::Sink, Error = Error> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> //! + Unpin //! { //! // ... diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index 5134598..b6859b0 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -140,6 +140,8 @@ mod false_positives { use trybuild as _; use {hex as _, rand as _, rand_dev as _}; + + use sha2 as _; } mod delivery; From 0e757109a0cdb16f0106d705f2f51b8eecf89649 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 6 Jun 2025 12:27:10 +0200 Subject: [PATCH 24/29] Fix broken link Signed-off-by: Denis Varlakov --- round-based/src/echo_broadcast/error.rs | 2 +- round-based/src/echo_broadcast/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs index 9b68dfa..093502d 100644 --- a/round-based/src/echo_broadcast/error.rs +++ b/round-based/src/echo_broadcast/error.rs @@ -88,7 +88,7 @@ impl From for Error { } } -/// An error returned in [round completion](::complete) +/// An error returned in round completion #[derive(thiserror::Error, Debug)] pub enum CompleteRoundError { /// Error occurred while handling received message(s) diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index c221010..d9e2861 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -47,7 +47,7 @@ //! # let (i, n) = (1, 3); //! // constructs an MPC engine as usual //! let mpc = round_based::mpc::connected(delivery); -//! // wrap an engine to add reliable broadcast support +//! // wraps an engine to add reliable broadcast support //! let mpc = round_based::echo_broadcast::wrap(mpc, i, n); //! //! // execute the protocol From dcf0c5fdec0bfc8954140b643a062ddc928b09ac Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 13 Jun 2025 10:04:01 +0200 Subject: [PATCH 25/29] Rename `finish` to `finis_setup` Signed-off-by: Denis Varlakov --- examples/random-generation-protocol/src/lib.rs | 2 +- round-based/src/echo_broadcast/mod.rs | 4 ++-- round-based/src/mpc/mod.rs | 4 ++-- round-based/src/mpc/party/mod.rs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index 33a52df..cbd59ff 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -65,7 +65,7 @@ where // Define rounds let round1 = mpc.add_round(round_based::round::reliable_broadcast::(i, n)); let round2 = mpc.add_round(round_based::round::broadcast::(i, n)); - let mut mpc = mpc.finish(); + let mut mpc = mpc.finish_setup(); // --- The Protocol --- diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index d9e2861..c519dcd 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -274,8 +274,8 @@ where } } - fn finish(self) -> Self::Exec { - self.map_party(|p| p.finish()) + fn finish_setup(self) -> Self::Exec { + self.map_party(|p| p.finish_setup()) } } diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index ae3b5fd..a5bcf97 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -61,11 +61,11 @@ pub trait Mpc { R: RoundStore, Self::Msg: RoundMsg; - /// Indicates that network setup is complete + /// Completes network setup /// /// Once this method is called, no more rounds can be added, /// but the protocol can receive and send messages. - fn finish(self) -> Self::Exec; + fn finish_setup(self) -> Self::Exec; } /// Abstracts functionalities needed for MPC protocol execution diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs index c270610..866b0f7 100644 --- a/round-based/src/mpc/party/mod.rs +++ b/round-based/src/mpc/party/mod.rs @@ -90,7 +90,7 @@ where self.router.add_round(round) } - fn finish(self) -> Self::Exec { + fn finish_setup(self) -> Self::Exec { MpcParty { router: self.router, io: self.io, From 7ac859b10dfa8b7159332c953b2ea05e62ed43ec Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 13 Jun 2025 10:05:56 +0200 Subject: [PATCH 26/29] Update docs Signed-off-by: Denis Varlakov --- round-based/src/mpc/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index a5bcf97..3bdbd3e 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -50,7 +50,7 @@ pub trait Mpc { /// Protocol message type Msg; - /// Returned in [`Self::finish`] + /// Returned in [`Self::finish_setup`] type Exec: MpcExecution; /// Error indicating that sending a message has failed type SendErr; From eeffe155074ce74e4d5b5ec201ca7f641cd5157e Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 13 Jun 2025 10:13:06 +0200 Subject: [PATCH 27/29] Add comment Signed-off-by: Denis Varlakov --- round-based-tests/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/round-based-tests/src/lib.rs b/round-based-tests/src/lib.rs index 46540a8..6123b29 100644 --- a/round-based-tests/src/lib.rs +++ b/round-based-tests/src/lib.rs @@ -1,5 +1,8 @@ use round_based::{state_machine::ProceedResult, Incoming, Outgoing}; +/// Wraps a state machine and provides convenient methods for feeding to and receiving messages from +/// the state machine, removing a boilerplate for handling `Yield`-ing, and providing convenient +/// methods for output assertions like `output.expect_eq()` pub struct PartySim(S); pub fn new_one_party_sim<'a, M, F>( From 2501d4e92ad5d12f663c8143c912cec9e9f0f3ba Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 13 Jun 2025 10:24:02 +0200 Subject: [PATCH 28/29] Add more comments Signed-off-by: Denis Varlakov --- round-based-tests/src/lib.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/round-based-tests/src/lib.rs b/round-based-tests/src/lib.rs index 6123b29..0f46494 100644 --- a/round-based-tests/src/lib.rs +++ b/round-based-tests/src/lib.rs @@ -5,6 +5,7 @@ use round_based::{state_machine::ProceedResult, Incoming, Outgoing}; /// methods for output assertions like `output.expect_eq()` pub struct PartySim(S); +/// Wraps a state machine and returns [`PartySim`] pub fn new_one_party_sim<'a, M, F>( protocol: impl FnOnce(round_based::state_machine::MpcParty) -> F, ) -> PartySim + 'a> @@ -16,6 +17,10 @@ where } impl PartySim { + /// Feeds an incoming message to the state machine + /// + /// State machine **must be** in the state waiting for the incoming message. If state + /// machine is in other state (e.g. wants to send a message), this function will panic. #[track_caller] pub fn receives(&mut self, msg: Incoming) { loop { @@ -38,6 +43,10 @@ impl PartySim { .expect("couldn't feed a message into simulation") } + /// Retrieves an outgoing message from the state machine + /// + /// State machine **must be** in the state of sending a message. If state machine is in + /// other state (e.g. waits for an incoming message), this function will panic. #[track_caller] pub fn sends(&mut self) -> Expect> { loop { @@ -49,6 +58,10 @@ impl PartySim { } } + /// Retrieves state machine output + /// + /// State machine **must be** in the output state. If state machine is in other state (e.g. + /// waits for an incoming message), this function will panic. #[track_caller] pub fn outputs(&mut self) -> Expect { loop { @@ -61,10 +74,14 @@ impl PartySim { } } +/// Wraps `T` and allows to make assertions on it #[must_use = "you need to make sure the output meets tests expectations"] pub struct Expect(pub T); impl Expect { + /// Wrapped value must be equal to `expected` + /// + /// Panics if it's not #[track_caller] pub fn expect_eq(&self, expected: &T) { assert_eq!(self.0, *expected) @@ -72,12 +89,14 @@ impl Expect { } impl Expect> { + /// Unwraps a result #[track_caller] pub fn unwrap(self) -> Expect { Expect(self.0.unwrap()) } } impl Expect> { + /// Unwraps an error from result #[track_caller] pub fn unwrap_err(self) -> Expect { Expect(self.0.unwrap_err()) From 167b810ee64fac23a8b8e39ade107f31529d07e7 Mon Sep 17 00:00:00 2001 From: Denis Varlakov Date: Fri, 11 Jul 2025 16:38:10 +0200 Subject: [PATCH 29/29] Update docs Signed-off-by: Denis Varlakov --- round-based/src/echo_broadcast/mod.rs | 2 + round-based/src/mpc/mod.rs | 53 ++++++++++++++++++--------- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs index c519dcd..3598e6d 100644 --- a/round-based/src/echo_broadcast/mod.rs +++ b/round-based/src/echo_broadcast/mod.rs @@ -109,6 +109,8 @@ mod sub_msg { } } +// `D` doesn't implement traits like `Clone`, `Eq`, etc. so we have to implement those traits by hand + impl Clone for Msg { fn clone(&self) -> Self { match self { diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs index 3bdbd3e..afeb401 100644 --- a/round-based/src/mpc/mod.rs +++ b/round-based/src/mpc/mod.rs @@ -45,12 +45,19 @@ pub mod party; #[doc(no_inline)] pub use self::party::{Halves, MpcParty}; -/// Abstracts functionalities needed for MPC protocol execution +/// Abstracts functionalities needed for creating an MPC protocol execution. +/// +/// An object implementing this trait is accepted as a parameter of a protocol. It is used to +/// configure the protocol with [`Mpc::add_round`], and then finalized into a protocol executor +/// with [`Mpc::finish_setup`] pub trait Mpc { /// Protocol message type Msg; - /// Returned in [`Self::finish_setup`] + /// A type of a finalized instatiation of a party for this protocol + /// + /// After being created with [`Mpc::finish_setup`], you can use an object with this type to + /// drive the protocol execution using the methods of [`MpcExecution`]. type Exec: MpcExecution; /// Error indicating that sending a message has failed type SendErr; @@ -72,7 +79,8 @@ pub trait Mpc { pub trait MpcExecution { /// Witness that round was registered /// - /// It is used to retrieve messages in [`MpcExecution::complete`]. + /// It's obtained by registering round in [`Mpc::add_round`], which then can be used to retrieve + /// messages from associated round by calling [`MpcExecution::complete`]. type Round; /// Protocol message @@ -88,7 +96,8 @@ pub trait MpcExecution { /// Completes the round /// - /// Waits until all messages in the round `R` are received, returns the received messages. + /// Waits until we receive all the messages in the round `R` from other parties. Returns + /// received messages. async fn complete( &mut self, round: Self::Round, @@ -139,15 +148,25 @@ pub trait MpcExecution { /// Creates a buffer of outgoing messages so they can be sent all at once /// - /// When you have many messages that you want to send at once, using [`.send()`](Self::send) is not efficient, - /// as it will accumulate message delivery cost. Use this method to optimize sending many messages at once. + /// When you have many messages that you want to send at once, using [`.send()`](Self::send) + /// may be inefficient, as delivery implementation may pause the execution until the message is + /// received by the recipient. Use this method to send many messages in a batch. + /// + /// This method takes ownership of `self` to create the [impl SendMany](SendMany) object. After + /// enqueueing all the messages, you need to reclaim `self` back by calling [`SendMany::flush`]. fn send_many(self) -> Self::SendMany; - /// Yields execution + /// Yields execution back to the async runtime + /// + /// Used in MPC protocols with many heavy synchronous computations. The protocol implementors + /// can manually insert yield points to ease the CPU contention async fn yield_now(&self); } /// Buffer, optimized for sending many messages at once +/// +/// It's obtained by calling [`MpcExecution::send_many`], which takes ownership of `MpcExecution`. To reclaim +/// ownership, call [`SendMany::flush`]. pub trait SendMany { /// MPC executor, returned after successful [`.flush()`](Self::flush) type Exec: MpcExecution; @@ -161,8 +180,8 @@ pub trait SendMany { /// Similar to [`MpcExecution::send`], but possibly buffers a message until [`.flush()`](Self::flush) is /// called. /// - /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed - /// until [`.flush()`](Self::flush) is called. + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; /// Adds a p2p message to the sending queue @@ -170,8 +189,8 @@ pub trait SendMany { /// Similar to [`MpcExecution::send_p2p`], but possibly buffers a message until [`.flush()`](Self::flush) is /// called. /// - /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed - /// until [`.flush()`](Self::flush) is called. + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). async fn send_p2p( &mut self, recipient: PartyIndex, @@ -185,8 +204,8 @@ pub trait SendMany { /// Similar to [`MpcExecution::send_to_all`], but possibly buffers a message until [`.flush()`](Self::flush) is /// called. /// - /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed - /// until [`.flush()`](Self::flush) is called. + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { self.send(Outgoing::all_parties(msg)).await } @@ -196,8 +215,8 @@ pub trait SendMany { /// Similar to [`MpcExecution::reliably_broadcast`], but possibly buffers a message until [`.flush()`](Self::flush) is /// called. /// - /// Message may be sent within the call (e.g. if internal buffer is full), but no sending is guaranteed - /// until [`.flush()`](Self::flush) is called. + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`Self::flush`] async fn reliably_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { self.send(Outgoing::reliable_broadcast(msg)).await } @@ -213,7 +232,7 @@ pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRo /// /// MPC protocols typically consist of several rounds, each round has differently typed message. /// `ProtocolMsg` and [`RoundMsg`] traits are used to examine received message: `ProtocolMsg::round` -/// determines which round message belongs to, and then `RoundMessage` trait can be used to retrieve +/// determines which round message belongs to, and then `RoundMsg` trait can be used to retrieve /// actual round-specific message. /// /// You should derive these traits using proc macro (requires `derive` feature): @@ -280,7 +299,7 @@ pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRo /// } /// ``` pub trait ProtocolMsg: Sized { - /// Number of round this message originates from + /// Number of the round that this message originates from fn round(&self) -> u16; }