diff --git a/Cargo.lock b/Cargo.lock index a86357a..0c2c921 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -459,26 +459,31 @@ dependencies = [ "sha2", "thiserror", "tokio", + "udigest", ] [[package]] name = "round-based" -version = "0.4.0" +version = "0.4.1" dependencies = [ "anyhow", + "digest", "futures", "futures-util", "hex", "matches", "phantom-type", + "pin-project-lite", "rand", "rand_dev", "round-based-derive", + "sha2", "thiserror", "tokio", "tokio-stream", "tracing", "trybuild", + "udigest", ] [[package]] @@ -494,12 +499,15 @@ dependencies = [ name = "round-based-tests" version = "0.1.0" dependencies = [ + "anyhow", "futures", + "hex", "hex-literal", "matches", "rand_chacha", "random-generation-protocol", "round-based", + "sha2", "tokio", ] @@ -744,6 +752,27 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "udigest" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cd61fa9fb78569e9fe34acf0048fd8cb9ebdbacc47af740745487287043ff0" +dependencies = [ + "digest", + "udigest-derive", +] + +[[package]] +name = "udigest-derive" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "603329303137e0d59238ee4d6b9c085eada8e2a9d20666f3abd9dadf8f8543f4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/README.md b/README.md index 5308be4..f481d16 100644 --- a/README.md +++ b/README.md @@ -13,25 +13,86 @@ multiparty protocols (e.g. threshold signing, random beacons, etc.). * Simple, configurable \ Protocol can be carried out in a few lines of code: check out examples. * Independent of networking layer \ - We use abstractions `Stream` and `Sink` to receive and send messages. + You may define your own networking layer and you don't need to change anything + in protocol implementation: it's agnostic of networking by default! So you can + use central delivery server, distributed redis nodes, postgres database, + p2p channels, or a public blockchain, or whatever else fits your needs. + +## Example +MPC protocol execution typically looks like this: + +```rust +// protocol to be executed, takes MPC engine `M`, index of party `i`, +// and number of participants `n` +async fn keygen(mpc: M, i: u16, n: u16) -> Result +where + M: round_based::Mpc +{ + // ... +} +// establishes network connection(s) to other parties so they may communicate +async fn connect() -> + impl futures::Stream>> + + futures::Sink, Error = Error> + + Unpin +{ + // ... +} +let delivery = connect().await; + +// constructs an MPC engine, which, primarily, is used to communicate with +// other parties +let mpc = round_based::mpc::connected(delivery); + +// execute the protocol +let keyshare = keygen(mpc, i, n).await?; +``` ## Networking In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to -implement `Delivery` trait which is basically a stream and a sink for receiving and sending messages. +provide a channel which implements a stream and a sink for receiving and sending messages. + +```rust +async fn connect() -> + impl futures::Stream>> + + futures::Sink, Error = Error> + + Unpin +{ + // ... +} + +let delivery = connect().await; +let party = round_based::mpc::connected(delivery); + +// run the protocol +``` + +In order to guarantee the protocol security, it may require: -Message delivery should meet certain criterias that differ from protocol to protocol (refer to -the documentation of the protocol you're using), but usually they are: +* Message Authentication \ + Guarantees message source and integrity. If protocol requires it, make sure + message was sent by claimed sender and that it hasn't been tampered with. \ + This is typically achieved either through public-key cryptography (e.g., + signing with a private key) or through symmetric mechanisms like MACs (e.g., + HMAC) or authenticated encryption (AEAD) in point-to-point scenarios. +* Message Privacy \ + When a p2p message is sent, only recipient shall be able to read the content. \ + It can be achieved by using symmetric or asymmetric encryption, encryption methods + come with their own trade-offs (e.g. simplicity vs forward secrecy). +* Reliable Broadcast \ + When party receives a reliable broadcast message it shall be ensured that + everybody else received the same message. \ + Our library provides `echo_broadcast` support out-of-box that enforces broadcast + reliability by adding an extra communication round per each round that requires + reliable broadcast. \ + More advanced techniques implement [Byzantine fault](https://en.wikipedia.org/wiki/Byzantine_fault) + tolerant broadcast -* Messages should be authenticated \ - Each message should be signed with identity key of the sender. This implies having Public Key - Infrastructure. -* P2P messages should be encrypted \ - Only recipient should be able to learn the content of p2p message -* Broadcast channel should be reliable \ - Some protocols may require broadcast channel to be reliable. Simply saying, when party receives a - broadcast message over reliable channel it should be ensured that everybody else received the same - message. +## Developing MPC protocol with `round_based` +We plan to write a book guiding through MPC protocol development process, but +while it's not done, you may refer to [random beacon example](https://github.com/LFDT-Lockness/round-based/blob/m/examples/random-generation-protocol/src/lib.rs) +and our well-documented API. ## Features @@ -39,8 +100,9 @@ the documentation of the protocol you're using), but usually they are: * `sim-async` enables protocol execution simulation with tokio runtime, see `sim::async_env` module * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync - API, see `state_machine` module -* `derive` is needed to use `ProtocolMessage` proc macro + API, see `state_machine` module +* `echo-broadcast` adds `echo_broadcast` support +* `derive` is needed to use `ProtocolMsg` proc macro * `runtime-tokio` enables tokio-specific implementation of async runtime ## Join us in Discord! diff --git a/examples/random-generation-protocol/Cargo.toml b/examples/random-generation-protocol/Cargo.toml index ddcec5e..edd0188 100644 --- a/examples/random-generation-protocol/Cargo.toml +++ b/examples/random-generation-protocol/Cargo.toml @@ -18,6 +18,11 @@ thiserror = { version = "2", default-features = false } # We don't use it directy, but we need to enable `serde` feature generic-array = { version = "0.14", features = ["serde"] } +udigest = { version = "0.2", default-features = false, features = ["derive"], optional = true } + +[features] +udigest = ["dep:udigest"] + [dev-dependencies] round-based = { path = "../../round-based", features = ["derive", "sim", "state-machine"] } tokio = { version = "1.15", features = ["macros", "rt"] } diff --git a/examples/random-generation-protocol/src/lib.rs b/examples/random-generation-protocol/src/lib.rs index 4f30ded..cbd59ff 100644 --- a/examples/random-generation-protocol/src/lib.rs +++ b/examples/random-generation-protocol/src/lib.rs @@ -18,14 +18,14 @@ use alloc::{vec, vec::Vec}; use serde::{Deserialize, Serialize}; use sha2::{digest::Output, Digest, Sha256}; -use round_based::rounds_router::{ - simple_store::{RoundInput, RoundInputError}, - CompleteRoundError, RoundsRouter, +use round_based::{ + mpc::{Mpc, MpcExecution}, + MsgId, }; -use round_based::{Delivery, Mpc, MpcParty, MsgId, Outgoing, PartyIndex, ProtocolMessage, SinkExt}; /// Protocol message -#[derive(Clone, Debug, PartialEq, ProtocolMessage, Serialize, Deserialize)] +#[derive(round_based::ProtocolMsg, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub enum Msg { /// Round 1 CommitMsg(CommitMsg), @@ -34,38 +34,38 @@ pub enum Msg { } /// Message from round 1 -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub struct CommitMsg { /// Party commitment + #[cfg_attr(feature = "udigest", udigest(as_bytes))] pub commitment: Output, } /// Message from round 2 -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(feature = "udigest", derive(udigest::Digestable))] pub struct DecommitMsg { /// Randomness generated by party + #[cfg_attr(feature = "udigest", udigest(as_bytes))] pub randomness: [u8; 32], } /// Carries out the randomness generation protocol pub async fn protocol_of_random_generation( - party: M, - i: PartyIndex, + mut mpc: M, + i: u16, n: u16, mut rng: R, -) -> Result<[u8; 32], Error> +) -> Result<[u8; 32], ErrorM> where - M: Mpc, + M: Mpc, R: rand_core::RngCore, { - let MpcParty { delivery, .. } = party.into_party(); - let (incoming, mut outgoing) = delivery.split(); - // Define rounds - let mut rounds = RoundsRouter::::builder(); - let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); - let round2 = rounds.add_round(RoundInput::::broadcast(i, n)); - let mut rounds = rounds.listen(incoming); + let round1 = mpc.add_round(round_based::round::reliable_broadcast::(i, n)); + let round2 = mpc.add_round(round_based::round::broadcast::(i, n)); + let mut mpc = mpc.finish_setup(); // --- The Protocol --- @@ -74,33 +74,26 @@ where rng.fill_bytes(&mut local_randomness); // 2. Commit local randomness (broadcast m=sha256(randomness)) + // This message must be reliably broadcasted to guarantee protocol security let commitment = Sha256::digest(local_randomness); - outgoing - .send(Outgoing::broadcast(Msg::CommitMsg(CommitMsg { - commitment, - }))) + mpc.reliably_broadcast(Msg::CommitMsg(CommitMsg { commitment })) .await .map_err(Error::Round1Send)?; // 3. Receive committed randomness from other parties - let commitments = rounds - .complete(round1) - .await - .map_err(Error::Round1Receive)?; + let commitments = mpc.complete(round1).await.map_err(Error::Round1Receive)?; // 4. Open local randomness - outgoing - .send(Outgoing::broadcast(Msg::DecommitMsg(DecommitMsg { - randomness: local_randomness, - }))) - .await - .map_err(Error::Round2Send)?; + // This message will be sent to all other participants, but it doesn't require + // a reliable broadcast + mpc.send_to_all(Msg::DecommitMsg(DecommitMsg { + randomness: local_randomness, + })) + .await + .map_err(Error::Round2Send)?; // 5. Receive opened local randomness from other parties, verify them, and output protocol randomness - let randomness = rounds - .complete(round2) - .await - .map_err(Error::Round2Receive)?; + let randomness = mpc.complete(round2).await.map_err(Error::Round2Receive)?; let mut guilty_parties = vec![]; let mut output = local_randomness; @@ -139,13 +132,13 @@ pub enum Error { Round1Send(#[source] SendErr), /// Couldn't receive a message in the first round #[error("receive messages at round 1")] - Round1Receive(#[source] CompleteRoundError), + Round1Receive(#[source] RecvErr), /// Couldn't send a message in the second round #[error("send a message at round 2")] Round2Send(#[source] SendErr), /// Couldn't receive a message in the second round #[error("receive messages at round 2")] - Round2Receive(#[source] CompleteRoundError), + Round2Receive(#[source] RecvErr), /// Some of the parties cheated #[error("malicious parties: {guilty_parties:?}")] @@ -155,11 +148,17 @@ pub enum Error { }, } +/// Error type deduced from `M: Mpc` +pub type ErrorM = Error< + round_based::mpc::CompleteRoundErr, + ::SendErr, +>; + /// Blames a party in cheating during the protocol #[derive(Debug)] pub struct Blame { /// Index of the cheated party - pub guilty_party: PartyIndex, + pub guilty_party: u16, /// ID of the message that party sent in the first round pub commitment_msg: MsgId, /// ID of the message that party sent in the second round @@ -251,7 +250,7 @@ mod tests { .received_msg(Incoming { id: 0, sender: 1, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: party1_com, }), @@ -265,7 +264,7 @@ mod tests { .received_msg(Incoming { id: 1, sender: 2, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: party2_com, }), @@ -298,7 +297,7 @@ mod tests { .received_msg(Incoming { id: 3, sender: 1, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: party1_rng, }), @@ -312,7 +311,7 @@ mod tests { .received_msg(Incoming { id: 3, sender: 2, - msg_type: round_based::MessageType::Broadcast, + msg_type: round_based::MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: party2_rng, }), diff --git a/round-based-derive/src/lib.rs b/round-based-derive/src/lib.rs index 4b62029..c701e65 100644 --- a/round-based-derive/src/lib.rs +++ b/round-based-derive/src/lib.rs @@ -6,18 +6,18 @@ use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Token, Variant}; -#[proc_macro_derive(ProtocolMessage, attributes(protocol_message))] -pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +#[proc_macro_derive(ProtocolMsg, attributes(protocol_msg))] +pub fn protocol_msg(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let mut root = None; for attr in input.attrs { - if !attr.path.is_ident("protocol_message") { + if !attr.path.is_ident("protocol_msg") { continue; } if root.is_some() { - return quote_spanned! { attr.path.span() => compile_error!("#[protocol_message] attribute appears more than once"); }.into(); + return quote_spanned! { attr.path.span() => compile_error!("#[protocol_msg] attribute appears more than once"); }.into(); } let tokens = attr.tokens.into(); root = Some(parse_macro_input!(tokens as RootAttribute)); @@ -30,10 +30,10 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre let enum_data = match input.data { Data::Enum(e) => e, Data::Struct(s) => { - return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMessage");}.into() + return quote_spanned! {s.struct_token.span => compile_error!("only enum may implement ProtocolMsg");}.into() } Data::Union(s) => { - return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMessage");}.into() + return quote_spanned! {s.union_token.span => compile_error!("only enum may implement ProtocolMsg");}.into() } }; @@ -46,15 +46,15 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre quote! { match *self {} } }; - let impl_protocol_message = quote! { - impl #impl_generics #root_path::ProtocolMessage for #name #ty_generics #where_clause { + let impl_protocol_msg = quote! { + impl #impl_generics #root_path::ProtocolMsg for #name #ty_generics #where_clause { fn round(&self) -> u16 { #round_method_impl } } }; - let impl_round_message = round_messages( + let impl_round_msg = round_msgs( &root_path, &name, &input.generics, @@ -62,8 +62,8 @@ pub fn protocol_message(input: proc_macro::TokenStream) -> proc_macro::TokenStre ); proc_macro::TokenStream::from(quote! { - #impl_protocol_message - #impl_round_message + #impl_protocol_msg + #impl_round_msg }) } @@ -73,11 +73,11 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator quote_spanned! { variant.ident.span() => - #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMessage"), + #enum_name::#variant_name => compile_error!("unit variants are not allowed in ProtocolMsg"), }, Fields::Named(_) => quote_spanned! { variant.ident.span() => - #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMessage"), + #enum_name::#variant_name{..} => compile_error!("named variants are not allowed in ProtocolMsg"), }, Fields::Unnamed(unnamed) => if unnamed.unnamed.len() == 1 { quote_spanned! { @@ -87,7 +87,7 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator - #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMessage"), + #enum_name::#variant_name(..) => compile_error!("this variant must contain exactly one field to be valid ProtocolMsg"), } }, } @@ -99,7 +99,7 @@ fn round_method<'v>(enum_name: &Ident, variants: impl Iterator( +fn round_msgs<'v>( root_path: &RootPath, enum_name: &Ident, generics: &Generics, @@ -113,16 +113,16 @@ fn round_messages<'v>( let msg_type = &unnamed.unnamed[0].ty; quote_spanned! { variant.ident.span() => - impl #impl_generics #root_path::RoundMessage<#msg_type> for #enum_name #ty_generics #where_clause { + impl #impl_generics #root_path::RoundMsg<#msg_type> for #enum_name #ty_generics #where_clause { const ROUND: u16 = #i; - fn to_protocol_message(round_message: #msg_type) -> Self { - #enum_name::#variant_name(round_message) + fn to_protocol_msg(round_msg: #msg_type) -> Self { + #enum_name::#variant_name(round_msg) } - fn from_protocol_message(protocol_message: Self) -> Result<#msg_type, Self> { + fn from_protocol_msg(protocol_msg: Self) -> Result<#msg_type, Self> { #[allow(unreachable_patterns)] - match protocol_message { + match protocol_msg { #enum_name::#variant_name(msg) => Ok(msg), - _ => Err(protocol_message), + _ => Err(protocol_msg), } } } diff --git a/round-based-tests/Cargo.toml b/round-based-tests/Cargo.toml index dfb0ac6..0e7e65e 100644 --- a/round-based-tests/Cargo.toml +++ b/round-based-tests/Cargo.toml @@ -7,13 +7,17 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +round-based = { path = "../round-based", features = ["echo-broadcast", "state-machine"] } + +anyhow = "1" [dev-dependencies] tokio = { version = "1", features = ["rt", "macros"] } hex-literal = "0.3" +hex = "0.4" futures = "0.3" matches = "0.1" rand_chacha = "0.3" +sha2 = "0.10" -round-based = { path = "../round-based" } -random-generation-protocol = { path = "../examples/random-generation-protocol" } +random-generation-protocol = { path = "../examples/random-generation-protocol", features = ["udigest"] } diff --git a/round-based-tests/src/lib.rs b/round-based-tests/src/lib.rs index 8b13789..0f46494 100644 --- a/round-based-tests/src/lib.rs +++ b/round-based-tests/src/lib.rs @@ -1 +1,104 @@ +use round_based::{state_machine::ProceedResult, Incoming, Outgoing}; +/// Wraps a state machine and provides convenient methods for feeding to and receiving messages from +/// the state machine, removing a boilerplate for handling `Yield`-ing, and providing convenient +/// methods for output assertions like `output.expect_eq()` +pub struct PartySim(S); + +/// Wraps a state machine and returns [`PartySim`] +pub fn new_one_party_sim<'a, M, F>( + protocol: impl FnOnce(round_based::state_machine::MpcParty) -> F, +) -> PartySim + 'a> +where + M: round_based::ProtocolMsg + 'static, + F: core::future::Future + 'a, +{ + PartySim(round_based::state_machine::wrap_protocol(protocol)) +} + +impl PartySim { + /// Feeds an incoming message to the state machine + /// + /// State machine **must be** in the state waiting for the incoming message. If state + /// machine is in other state (e.g. wants to send a message), this function will panic. + #[track_caller] + pub fn receives(&mut self, msg: Incoming) { + loop { + match self.0.proceed() { + ProceedResult::NeedsOneMoreMessage => { + // that's exactly what we want + break; + } + ProceedResult::Yielded => { + // Protocol yielded, we ignore that + continue; + } + // everything else is unexpected + r => panic!("state machine proceed: expected NeedsOneMoreMessage, got {r:?}"), + } + } + self.0 + .received_msg(msg) + .ok() + .expect("couldn't feed a message into simulation") + } + + /// Retrieves an outgoing message from the state machine + /// + /// State machine **must be** in the state of sending a message. If state machine is in + /// other state (e.g. waits for an incoming message), this function will panic. + #[track_caller] + pub fn sends(&mut self) -> Expect> { + loop { + match self.0.proceed() { + ProceedResult::SendMsg(m) => break Expect(m), + ProceedResult::Yielded => continue, + r => panic!("state machine proceed: expected SendMsg, got {r:?}"), + } + } + } + + /// Retrieves state machine output + /// + /// State machine **must be** in the output state. If state machine is in other state (e.g. + /// waits for an incoming message), this function will panic. + #[track_caller] + pub fn outputs(&mut self) -> Expect { + loop { + match self.0.proceed() { + ProceedResult::Output(r) => break Expect(r), + ProceedResult::Yielded => continue, + r => panic!("state machine proceed: expected Output, got {r:?}"), + } + } + } +} + +/// Wraps `T` and allows to make assertions on it +#[must_use = "you need to make sure the output meets tests expectations"] +pub struct Expect(pub T); + +impl Expect { + /// Wrapped value must be equal to `expected` + /// + /// Panics if it's not + #[track_caller] + pub fn expect_eq(&self, expected: &T) { + assert_eq!(self.0, *expected) + } +} + +impl Expect> { + /// Unwraps a result + #[track_caller] + pub fn unwrap(self) -> Expect { + Expect(self.0.unwrap()) + } +} +impl Expect> { + /// Unwraps an error from result + #[track_caller] + pub fn unwrap_err(self) -> Expect { + Expect(self.0.unwrap_err()) + } +} diff --git a/round-based-tests/tests/rounds.rs b/round-based-tests/tests/random_beacon.rs similarity index 68% rename from round-based-tests/tests/rounds.rs rename to round-based-tests/tests/random_beacon.rs index 00403d6..c2790f0 100644 --- a/round-based-tests/tests/rounds.rs +++ b/round-based-tests/tests/random_beacon.rs @@ -1,6 +1,6 @@ use std::convert::Infallible; -use futures::{sink, stream, Sink, Stream}; +use futures::{sink, stream, SinkExt}; use hex_literal::hex; use matches::assert_matches; use rand_chacha::rand_core::SeedableRng; @@ -8,9 +8,7 @@ use rand_chacha::rand_core::SeedableRng; use random_generation_protocol::{ protocol_of_random_generation, CommitMsg, DecommitMsg, Error, Msg, }; -use round_based::rounds_router::errors::IoError; -use round_based::rounds_router::{simple_store::RoundInput, CompleteRoundError, RoundsRouter}; -use round_based::{Delivery, Incoming, MessageType, MpcParty, Outgoing}; +use round_based::{mpc::party::CompleteRoundError, Incoming, MessageType}; const PARTY0_SEED: [u8; 32] = hex!("6772d079d5c984b3936a291e36b0d3dc6c474e36ed4afdfc973ef79a431ca870"); @@ -33,7 +31,7 @@ async fn random_generation_completes() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -41,7 +39,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -49,7 +47,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -57,7 +55,7 @@ async fn random_generation_completes() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -69,13 +67,37 @@ async fn random_generation_completes() { assert_eq!(output, PROTOCOL_OUTPUT); } +#[tokio::test] +async fn protocol_terminates_with_error_if_party_broadcasts_msg_unreliably_at_round1() { + let output = run_protocol([Ok::<_, Infallible>(Incoming { + id: 0, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + }), + })]) + .await; + + assert_matches!( + output, + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::MismatchedMessageType { + msg_id: 0, + expected: round_based::MessageType::Broadcast { reliable: true }, + actual: round_based::MessageType::Broadcast { reliable: false } + } + ))) + ) +} + #[tokio::test] async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_round1() { let output = run_protocol([ Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -83,7 +105,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 1, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY_OVERWRITES.into(), }), @@ -93,7 +115,12 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::AttemptToOverwriteReceivedMsg { + msgs_ids: [0, 1], + sender: 1 + } + ))) ) } @@ -103,7 +130,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -111,7 +138,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -119,7 +146,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -127,7 +154,7 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r Ok(Incoming { id: 3, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY_OVERWRITES, }), @@ -137,7 +164,12 @@ async fn protocol_terminates_with_error_if_party_tries_to_overwrite_message_at_r assert_matches!( output, - Err(Error::Round2Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round2Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::AttemptToOverwriteReceivedMsg { + msgs_ids: [2, 3], + sender: 1 + } + ))) ) } @@ -146,7 +178,7 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() let output = run_protocol([Ok::<_, Infallible>(Incoming { id: 0, sender: 3, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -155,7 +187,13 @@ async fn protocol_terminates_if_received_message_from_unknown_sender_at_round1() assert_matches!( output, - Err(Error::Round1Receive(CompleteRoundError::ProcessMessage(_))) + Err(Error::Round1Receive(CompleteRoundError::ProcessMsg( + round_based::round::RoundInputError::SenderIndexOutOfRange { + msg_id: 0, + sender: 3, + n: 3 + } + ))) ) } @@ -165,7 +203,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -173,7 +211,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -181,7 +219,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY_OVERWRITES.into(), }), @@ -189,7 +227,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 3, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -197,7 +235,7 @@ async fn protocol_ignores_message_that_goes_to_completed_round() { Ok(Incoming { id: 4, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -215,7 +253,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -223,7 +261,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -231,7 +269,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -239,7 +277,7 @@ async fn protocol_ignores_io_error_if_it_is_completed() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -258,7 +296,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -266,7 +304,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -274,7 +312,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -283,7 +321,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -291,7 +329,10 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round2() { ]) .await; - assert_matches!(output, Err(Error::Round2Receive(CompleteRoundError::Io(_)))); + assert_matches!( + output, + Err(Error::Round2Receive(CompleteRoundError::Io(DummyError))) + ); } #[tokio::test] @@ -301,7 +342,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -309,7 +350,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -317,7 +358,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -325,7 +366,7 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { Ok(Incoming { id: 3, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY2_RANDOMNESS, }), @@ -333,7 +374,10 @@ async fn protocol_terminates_with_error_if_io_error_happens_at_round1() { ]) .await; - assert_matches!(output, Err(Error::Round1Receive(CompleteRoundError::Io(_)))); + assert_matches!( + output, + Err(Error::Round1Receive(CompleteRoundError::Io(DummyError))) + ); } #[tokio::test] @@ -342,7 +386,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok::<_, Infallible>(Incoming { id: 0, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY1_COMMITMENT.into(), }), @@ -350,7 +394,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok(Incoming { id: 1, sender: 2, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg::CommitMsg(CommitMsg { commitment: PARTY2_COMMITMENT.into(), }), @@ -358,7 +402,7 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { Ok(Incoming { id: 2, sender: 1, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg::DecommitMsg(DecommitMsg { randomness: PARTY1_RANDOMNESS, }), @@ -368,31 +412,19 @@ async fn protocol_terminates_with_error_if_unexpected_eof_happens_at_round2() { assert_matches!( output, - Err(Error::Round2Receive(CompleteRoundError::Io( - IoError::UnexpectedEof - ))) + Err(Error::Round2Receive(CompleteRoundError::UnexpectedEof)) ); } -#[tokio::test] -async fn all_non_completed_rounds_are_terminated_with_unexpected_eof_error_if_incoming_channel_suddenly_closed( -) { - let mut rounds = RoundsRouter::builder(); - let round1 = rounds.add_round(RoundInput::::new(0, 3, MessageType::P2P)); - let round2 = rounds.add_round(RoundInput::::new(0, 3, MessageType::P2P)); - let mut rounds = rounds.listen(stream::empty::, Infallible>>()); - - assert_matches!( - rounds.complete(round1).await, - Err(CompleteRoundError::Io(IoError::UnexpectedEof)) - ); - assert_matches!( - rounds.complete(round2).await, - Err(CompleteRoundError::Io(IoError::UnexpectedEof)) - ); -} - -async fn run_protocol(incomings: I) -> Result<[u8; 32], Error> +async fn run_protocol( + incomings: I, +) -> Result< + [u8; 32], + random_generation_protocol::Error< + round_based::mpc::party::CompleteRoundError, + E, + >, +> where I: IntoIterator, E>>, I::IntoIter: Send + 'static, @@ -400,38 +432,13 @@ where { let rng = rand_chacha::ChaCha8Rng::from_seed(PARTY0_SEED); - let party = MpcParty::connected(MockedDelivery::new(stream::iter(incomings), sink::drain())); + let party = round_based::mpc::connected_halves( + stream::iter(incomings), + sink::drain().sink_map_err(|e| match e {}), + ); protocol_of_random_generation(party, 0, 3, rng).await } -struct MockedDelivery { - incoming: I, - outgoing: O, -} - -impl MockedDelivery { - pub fn new(incoming: I, outgoing: O) -> Self { - Self { incoming, outgoing } - } -} - -impl Delivery for MockedDelivery -where - I: Stream, IErr>> + Send + Unpin + 'static, - O: Sink, Error = OErr> + Send + Unpin, - IErr: std::error::Error + Send + Sync + 'static, - OErr: std::error::Error + Send + Sync + 'static, -{ - type Send = O; - type Receive = I; - type SendError = OErr; - type ReceiveError = IErr; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.incoming, self.outgoing) - } -} - #[derive(Debug)] struct DummyError; diff --git a/round-based-tests/tests/random_beacon_with_echo.rs b/round-based-tests/tests/random_beacon_with_echo.rs new file mode 100644 index 0000000..d6f2ec8 --- /dev/null +++ b/round-based-tests/tests/random_beacon_with_echo.rs @@ -0,0 +1,190 @@ +use hex_literal::hex; +use matches::assert_matches; +use rand_chacha::rand_core::SeedableRng; + +use random_generation_protocol::{protocol_of_random_generation, CommitMsg, DecommitMsg, Msg}; +use round_based::{echo_broadcast as echo, Incoming, MessageType, Outgoing}; + +const PARTY0_SEED: [u8; 32] = + hex!("6772d079d5c984b3936a291e36b0d3dc6c474e36ed4afdfc973ef79a431ca870"); +const PARTY0_COMMITMENT: [u8; 32] = + hex!("6ac69d1b2082536de5d4f5092f807873173fac86117cb7b5e179b191e97fc8d7"); +const PARTY0_RANDOMNESS: [u8; 32] = + hex!("15f88064c7daeb863f37c3a41466a827f5ca07b7f7dd8e58c510ff612e641ea2"); +const PARTY1_COMMITMENT: [u8; 32] = + hex!("2a8c585d9a80cb78bc226f4ab35a75c8e5834ff77a83f41cf6c893ea0f3b2aed"); +const ECHO_MSG: [u8; 32] = hex!("b9c51267eae8dc5ea988111e933709fa8496b5cab4aa11778f61271527d3760a"); +const PARTY1_RANDOMNESS: [u8; 32] = + hex!("12a595f4893fdb4ab9cc38caeec5f7456acb3002ca58457c5056977ce59136a6"); +const PARTY2_COMMITMENT: [u8; 32] = + hex!("01274ef40aece8aa039587cc05620a19b80a5c93fbfb24a9f8e1b77b7936e47d"); +const PARTY2_RANDOMNESS: [u8; 32] = + hex!("6fc78a926c7eebfad4e98e796cd53b771ac5947b460567c7ea441abb957c89c7"); +const PROTOCOL_OUTPUT: [u8; 32] = + hex!("689a9f02229bdb36521275179676641585c4a3ce7b80ace37f0272a65e89a1c3"); +const PARTY_OVERWRITES: [u8; 32] = + hex!("00aa11bb22cc33dd44ee55ff6677889900aa11bb22cc33dd44ee55ff66778899"); + +#[test] +fn random_generation_completes() { + let mut sim = simulation(); + // Round 0 - commitment + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY0_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 0, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 1, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY2_COMMITMENT.into(), + })), + }); + // Round 1 - echo round + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 2, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 3, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + // Round 2 - decommitment + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY0_RANDOMNESS.into(), + })), + }); + sim.receives(Incoming { + id: 4, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY1_RANDOMNESS, + })), + }); + sim.receives(Incoming { + id: 5, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::DecommitMsg(DecommitMsg { + randomness: PARTY2_RANDOMNESS, + })), + }); + + sim.outputs().unwrap().expect_eq(&PROTOCOL_OUTPUT); +} + +#[test] +fn detects_unreliable_broadcast() { + let mut sim = simulation(); + // Round 0 - commitment + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY0_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 0, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY1_COMMITMENT.into(), + })), + }); + sim.receives(Incoming { + id: 1, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Main(Msg::CommitMsg(CommitMsg { + commitment: PARTY2_COMMITMENT.into(), + })), + }); + // Round 1 - echo round + sim.sends().expect_eq(&Outgoing { + recipient: round_based::MessageDestination::AllParties { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 2, + sender: 1, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: ECHO_MSG.into(), + }, + }); + sim.receives(Incoming { + id: 3, + sender: 2, + msg_type: MessageType::Broadcast { reliable: false }, + msg: echo::Msg::Echo { + round: 0, + hash: PARTY_OVERWRITES.into(), + }, + }); + + assert_matches!( + sim.outputs().unwrap_err().0, + random_generation_protocol::Error::Round1Receive(echo::CompleteRoundError::Echo(err)) + if err.reliability_check_failed() + ); +} + +fn simulation() -> round_based_tests::PartySim< + impl round_based::state_machine::StateMachine< + Msg = echo::Msg, + Output = Result< + [u8; 32], + random_generation_protocol::Error< + round_based::echo_broadcast::CompleteRoundError< + round_based::mpc::party::CompleteRoundError< + round_based::echo_broadcast::Error, + round_based::state_machine::DeliveryErr, + >, + round_based::state_machine::DeliveryErr, + >, + round_based::echo_broadcast::Error, + >, + >, + >, +> { + let rng = rand_chacha::ChaCha8Rng::from_seed(PARTY0_SEED); + round_based_tests::new_one_party_sim(|party| async { + let party = round_based::echo_broadcast::wrap(party, 0, 3); + protocol_of_random_generation(party, 0, 3, rng).await + }) +} diff --git a/round-based/Cargo.toml b/round-based/Cargo.toml index 3cea8fe..dacd8ba 100644 --- a/round-based/Cargo.toml +++ b/round-based/Cargo.toml @@ -26,6 +26,12 @@ round-based-derive = { version = "0.2", optional = true, path = "../round-based- tokio = { version = "1", features = ["rt"], optional = true } tokio-stream = { version = "0.1", features = ["sync"], optional = true } +pin-project-lite = "0.2" + +# echo broadcast +digest = { version = "0.10", default-features = false, optional = true } +udigest = { version = "0.2", default-features = false, features = ["alloc", "digest", "inline-struct"], optional = true } + [dev-dependencies] trybuild = "1" matches = "0.1" @@ -38,6 +44,8 @@ rand_dev = "0.1" anyhow = "1" +sha2 = "0.10" + [features] default = [] state-machine = [] @@ -46,6 +54,8 @@ sim-async = ["sim", "tokio/sync", "tokio-stream", "futures-util/alloc"] derive = ["round-based-derive"] runtime-tokio = ["tokio"] +echo-broadcast = ["dep:digest", "dep:udigest"] + [[test]] name = "derive" required-features = ["derive"] diff --git a/round-based/src/_docs.rs b/round-based/src/_docs.rs index 82c5a16..dc57265 100644 --- a/round-based/src/_docs.rs +++ b/round-based/src/_docs.rs @@ -1,21 +1,11 @@ -use core::convert::Infallible; +use futures_util::{Sink, SinkExt, Stream}; -use phantom_type::PhantomType; +use crate::{Incoming, Outgoing}; -use crate::{Delivery, Incoming, Outgoing}; - -pub fn fake_delivery() -> impl Delivery { - struct FakeDelivery(PhantomType); - impl Delivery for FakeDelivery { - type Send = futures_util::sink::Drain>; - type Receive = futures_util::stream::Pending, Infallible>>; - - type SendError = Infallible; - type ReceiveError = Infallible; - - fn split(self) -> (Self::Receive, Self::Send) { - (futures_util::stream::pending(), futures_util::sink::drain()) - } - } - FakeDelivery(PhantomType::new()) +pub fn fake_delivery( +) -> impl Stream, E>> + Sink, Error = E> + Unpin { + crate::mpc::Halves::new( + futures_util::stream::pending::, E>>(), + futures_util::sink::drain().sink_map_err(|e| match e {}), + ) } diff --git a/round-based/src/delivery.rs b/round-based/src/delivery.rs index 25550d6..aa230f9 100644 --- a/round-based/src/delivery.rs +++ b/round-based/src/delivery.rs @@ -1,39 +1,3 @@ -use futures_util::{Sink, Stream}; - -/// Networking abstraction -/// -/// Basically, it's pair of channels: [`Stream`] for receiving messages, and [`Sink`] for sending -/// messages to other parties. -pub trait Delivery { - /// Outgoing delivery channel - type Send: Sink, Error = Self::SendError> + Unpin; - /// Incoming delivery channel - type Receive: Stream, Self::ReceiveError>> + Unpin; - /// Error of outgoing delivery channel - type SendError: core::error::Error + Send + Sync + 'static; - /// Error of incoming delivery channel - type ReceiveError: core::error::Error + Send + Sync + 'static; - /// Returns a pair of incoming and outgoing delivery channels - fn split(self) -> (Self::Receive, Self::Send); -} - -impl Delivery for (I, O) -where - I: Stream, IErr>> + Unpin, - O: Sink, Error = OErr> + Unpin, - IErr: core::error::Error + Send + Sync + 'static, - OErr: core::error::Error + Send + Sync + 'static, -{ - type Send = O; - type Receive = I; - type SendError = OErr; - type ReceiveError = IErr; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.0, self.1) - } -} - /// Incoming message #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct Incoming { @@ -52,7 +16,11 @@ pub struct Incoming { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MessageType { /// Message was broadcasted - Broadcast, + Broadcast { + /// Indicates that message was reliably broadcasted, meaning that it's guaranteed (cryptographically or through + /// other trust assumptions) that all honest participants of the protocol received the same message + reliable: bool, + }, /// P2P message P2P, } @@ -104,11 +72,16 @@ impl Incoming { } } - /// Checks whether it's broadcast message + /// Checks whether it's broadcast message (regardless if it's reliable or not) pub fn is_broadcast(&self) -> bool { matches!(self.msg_type, MessageType::Broadcast { .. }) } + /// Checks if message was reliably broadcasted + pub fn is_reliably_broadcasted(&self) -> bool { + matches!(self.msg_type, MessageType::Broadcast { reliable: true }) + } + /// Checks whether it's p2p message pub fn is_p2p(&self) -> bool { matches!(self.msg_type, MessageType::P2P) @@ -116,7 +89,7 @@ impl Incoming { } /// Outgoing message -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Outgoing { /// Message destination: either one party (p2p message) or all parties (broadcast message) pub recipient: MessageDestination, @@ -126,9 +99,17 @@ pub struct Outgoing { impl Outgoing { /// Constructs an outgoing message addressed to all parties - pub fn broadcast(msg: M) -> Self { + pub fn all_parties(msg: M) -> Self { + Self { + recipient: MessageDestination::AllParties { reliable: false }, + msg, + } + } + + /// Constructs an outgoing message addressed to all parties via reliable broadcast channel + pub fn reliable_broadcast(msg: M) -> Self { Self { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: true }, msg, } } @@ -175,7 +156,12 @@ impl Outgoing { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum MessageDestination { /// Broadcast message - AllParties, + AllParties { + /// Indicates that message needs to be reliably broadcasted, meaning that when recipient receives this message, + /// it must be assured (cryptographically or through other trust assumptions) that all honest participants of the + /// protocol received the same message + reliable: bool, + }, /// P2P message OneParty(PartyIndex), } @@ -185,8 +171,12 @@ impl MessageDestination { pub fn is_p2p(&self) -> bool { matches!(self, MessageDestination::OneParty(_)) } - /// Returns `true` if it's broadcast message + /// Returns `true` if it's broadcast message (regardless if it's reliable or not) pub fn is_broadcast(&self) -> bool { matches!(self, MessageDestination::AllParties { .. }) } + /// Returns `true` if it's reliable broadcast message + pub fn is_reliable_broadcast(&self) -> bool { + matches!(self, MessageDestination::AllParties { reliable: true }) + } } diff --git a/round-based/src/echo_broadcast/error.rs b/round-based/src/echo_broadcast/error.rs new file mode 100644 index 0000000..093502d --- /dev/null +++ b/round-based/src/echo_broadcast/error.rs @@ -0,0 +1,117 @@ +/// An error in echo-broadcast sub-protocol +/// +/// It may indicate that reliability check was not successful, or some other error, for instance, that +/// a round that requires reliable broadcast behaves unexpectedly (e.g. if we received two messages +/// from the same party and the main round didn't return an error). +#[derive(thiserror::Error, Debug)] +#[error(transparent)] +pub struct EchoError(#[from] Reason); + +impl EchoError { + /// Indicates that an error was caused by failed reliability check + pub fn reliability_check_failed(&self) -> bool { + matches!(self.0, Reason::MismatchedHash) + } +} + +#[derive(thiserror::Error, Debug)] +pub(super) enum Reason { + #[error( + "round store received two msgs from same party and \ + didn't return an error" + )] + StoreReceivedTwoMsgsFromSameParty, + #[error("received a msg from principal protocol when round is over")] + ReceivedMainMsgWhenRoundOver, + #[error("unknown sender i={i} (n={n})")] + UnknownSender { i: u16, n: usize }, + #[error("local party index is out of bounds, probably indicates a bug")] + OwnIndexOutOfBounds { i: u16, n: usize }, + #[error("principal round is finished, but store doesn't output")] + MainRoundFinishedButStoreDoesntOutput, + + #[error("handle incoming echo msg")] + HandleEcho(#[source] crate::round::RoundInputError), + + #[error("reliability check error: messages were not reliably broadcasted")] + MismatchedHash, + + #[error("impossible state (it's a bug)")] + StateGone, + + #[error("main round is in unexpected state (it's a bug)")] + UnexpectedMainRoundState, + + #[error("clone error msg: RoundMsg implementation is incorrect")] + RoundMsgClone, + + #[error( + "protocol attempts to send a broadcast msg twice within the same round, it's unsupported" + )] + SendTwice, + + #[error("cannot convert a sent round msg back from proto msg (it's a bug)")] + SentMsgFromProto, + + #[error( + "sent a message that doesn't require reliable broadcast in reliable broadcast \ + round (round: {round}, dest: {dest:?})" + )] + SentNonReliableMsgInReliableRound { + dest: crate::MessageDestination, + round: u16, + }, + + #[error( + "sent a reliable broadcast message in a regular round that doesn't require \ + reliable broadcast: you might have forgotten to register a reliable broadcast \ + round, or round store doesn't expose a property required to identify a reliable \ + broadcast round (round: {round})" + )] + SentReliableMsgInNonReliableRound { round: u16 }, +} + +/// An error originated either from main protocol or echo-broadcast sub-protocol +#[derive(thiserror::Error, Debug)] +pub enum Error { + /// Error originated from main protocol + #[error("error originated in principal protocol")] + Main(#[source] E), + /// Error originated from echo-broadcast sub-protocol + #[error("echo broadcast")] + Echo(#[from] EchoError), +} + +impl From for Error { + fn from(value: Reason) -> Self { + Error::Echo(value.into()) + } +} + +/// An error returned in round completion +#[derive(thiserror::Error, Debug)] +pub enum CompleteRoundError { + /// Error occurred while handling received message(s) + #[error(transparent)] + CompleteRound(CompleteErr), + /// Error occurred while sending a message to another party + /// + /// The only message we send during round completion is echo message + #[error(transparent)] + Send(SendErr), + /// Echo broadcast sub-protocol error + #[error(transparent)] + Echo(EchoError), +} + +impl From for CompleteRoundError { + fn from(err: Reason) -> Self { + CompleteRoundError::Echo(err.into()) + } +} + +impl From for CompleteRoundError { + fn from(err: EchoError) -> Self { + err.0.into() + } +} diff --git a/round-based/src/echo_broadcast/mod.rs b/round-based/src/echo_broadcast/mod.rs new file mode 100644 index 0000000..3598e6d --- /dev/null +++ b/round-based/src/echo_broadcast/mod.rs @@ -0,0 +1,460 @@ +//! Reliable broadcast for any protocol via echo messages +//! +//! Broadcast message is a message meant to be received by all participants of the protocol. +//! +//! We say that message is reliably broadcasted if, upon reception, it is guaranteed that all +//! honest participants of the protocol has received the same message. +//! +//! One way to achieve the reliable broadcast is by adding an echo round: when we receive +//! messages in a reliable broadcast round, we hash all messages, and we send the hash to all +//! other participants. If party receives a the same hash from everyone else, we can be +//! assured that messages in the round were reliably broadcasted. +//! +//! This module provides a mechanism that automatically add an echo round per each +//! round of the protocol that requires a reliable broadcast. +//! +//! ## Example +//! +//! ```rust +//! # #[derive(round_based::ProtocolMsg, Clone, udigest::Digestable)] +//! # enum KeygenMsg {} +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! // protocol to be executed that **requires** reliable broadcast +//! async fn keygen(mpc: M, i: u16, n: u16) -> Result +//! where +//! M: round_based::Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! // The full message type, which corresponds to keygen msg + echo broadcast msg +//! type Msg = round_based::echo_broadcast::Msg; +//! // establishes network connection(s) to other parties, but +//! // **does not** support reliable broadcast +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! let delivery = connect().await; +//! +//! # let (i, n) = (1, 3); +//! // constructs an MPC engine as usual +//! let mpc = round_based::mpc::connected(delivery); +//! // wraps an engine to add reliable broadcast support +//! let mpc = round_based::echo_broadcast::wrap(mpc, i, n); +//! +//! // execute the protocol +//! let keyshare = keygen(mpc, i, n).await?; +//! # Ok(()) } +//! ``` + +use core::marker::PhantomData; + +use alloc::collections::btree_map::BTreeMap; +use digest::Digest; + +use crate::{ + round::{RoundInfo, RoundStore, RoundStoreExt}, + Mpc, MpcExecution, Outgoing, ProtocolMsg, RoundMsg, +}; + +mod error; +mod store; + +pub use self::error::{CompleteRoundError, EchoError, Error}; + +/// Message of the protocol with echo broadcast round(s) +pub enum Msg { + /// Message from echo broadcast sub-protocol + Echo { + /// Indicates for which round of main protocol this echo message is transmitted + /// + /// Note that this field is controlled by potential malicious party. If it sets it + /// to the round that doesn't exist, the protocol will likely be aborted with an error + /// that we received a message from unregistered round, which may appear as implementation + /// error (i.e. API misuse), but in fact it's a malicious abort. + round: u16, + /// Hash of all messages received in `round` + hash: digest::Output, + }, + /// Message from the main protocol + Main(M), +} + +/// Sub-messages of [`Msg`] +/// +/// Sub-messages implement [`RoundMsg`] trait for [`Msg`] +mod sub_msg { + pub struct EchoMsg { + pub hash: digest::Output, + pub _round: core::marker::PhantomData, + } + #[derive(Debug, Clone)] + pub struct Main(pub M); + + impl Clone for EchoMsg { + fn clone(&self) -> Self { + Self { + hash: self.hash.clone(), + _round: core::marker::PhantomData, + } + } + } +} + +// `D` doesn't implement traits like `Clone`, `Eq`, etc. so we have to implement those traits by hand + +impl Clone for Msg { + fn clone(&self) -> Self { + match self { + Self::Echo { round, hash } => Self::Echo { + round: *round, + hash: hash.clone(), + }, + Self::Main(msg) => Self::Main(msg.clone()), + } + } +} + +impl PartialEq for Msg { + fn eq(&self, other: &Self) -> bool { + match self { + Self::Echo { round, hash } => { + matches!(other, Self::Echo { round: r2, hash: h2 } if round == r2 && hash == h2) + } + Self::Main(msg) => matches!(other, Self::Main(m2) if msg == m2), + } + } +} + +impl Eq for Msg {} + +impl core::fmt::Debug for Msg { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Echo { round, hash } => f + .debug_struct("Msg::Echo") + .field("round", round) + .field("hash", hash) + .finish(), + Self::Main(msg) => f.debug_tuple("Msg::Main").field(msg).finish(), + } + } +} + +impl ProtocolMsg for Msg { + fn round(&self) -> u16 { + match self { + Self::Echo { round, .. } => 2 * round + 1, + Self::Main(m) => 2 * m.round(), + } + } +} + +impl RoundMsg> for Msg +where + M: RoundMsg, +{ + const ROUND: u16 = 2 * M::ROUND + 1; + fn to_protocol_msg(round_msg: sub_msg::EchoMsg) -> Self { + Self::Echo { + round: M::ROUND, + hash: round_msg.hash, + } + } + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + match protocol_msg { + Self::Echo { round, hash } if round == M::ROUND => Ok(sub_msg::EchoMsg { + hash, + _round: PhantomData, + }), + _ => Err(protocol_msg), + } + } +} + +impl RoundMsg> for Msg +where + ProtoM: ProtocolMsg + RoundMsg, +{ + const ROUND: u16 = 2 * >::ROUND; + fn to_protocol_msg(round_msg: sub_msg::Main) -> Self { + Self::Main(ProtoM::to_protocol_msg(round_msg.0)) + } + fn from_protocol_msg(protocol_msg: Self) -> Result, Self> { + if let Self::Main(msg) = protocol_msg { + ProtoM::from_protocol_msg(msg) + .map(sub_msg::Main) + .map_err(|m| Self::Main(m)) + } else { + Err(protocol_msg) + } + } +} + +/// Wraps an [`Mpc`] engine and provides echo broadcast capabilities +pub fn wrap(party: M, i: u16, n: u16) -> WithEchoBroadcast +where + D: Digest, + M: Mpc>, + MainMsg: udigest::Digestable, +{ + WithEchoBroadcast { + party, + i, + n, + sent_reliable_msgs: Default::default(), + _ph: PhantomData, + } +} + +/// [`Mpc`] engine with echo-broadcast capabilities +pub struct WithEchoBroadcast { + party: M, + i: u16, + n: u16, + sent_reliable_msgs: BTreeMap>, + _ph: PhantomData, +} + +impl WithEchoBroadcast { + fn map_party

(self, f: impl FnOnce(M) -> P) -> WithEchoBroadcast { + let party = f(self.party); + WithEchoBroadcast { + party, + i: self.i, + n: self.n, + sent_reliable_msgs: self.sent_reliable_msgs, + _ph: PhantomData, + } + } +} + +impl Mpc for WithEchoBroadcast +where + D: Digest + 'static, + M: Mpc>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, +{ + type Msg = MainMsg; + + type Exec = WithEchoBroadcast; + + type SendErr = error::Error; + + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg, + { + let reliable_broadcast_required = round + .read_prop::() + .map(|x| x.0); + if reliable_broadcast_required == Some(true) { + let (main_round, echo_round) = store::new::(self.i, self.n, round); + let main_round = self.party.add_round(store::WithMainMsg(main_round)); + let echo_round = self.party.add_round(store::WithEchoError::from(echo_round)); + + self.sent_reliable_msgs.insert(Self::Msg::ROUND, None); + + Round(Inner::WithReliabilityCheck { + main_round, + echo_round, + }) + } else { + let round = self + .party + .add_round(store::WithError(store::WithMainMsg(round))); + Round(Inner::Unmodified(round)) + } + } + + fn finish_setup(self) -> Self::Exec { + self.map_party(|p| p.finish_setup()) + } +} + +impl WithEchoBroadcast +where + D: Digest, + MainMsg: ProtocolMsg + Clone, +{ + fn on_send(&mut self, outgoing: &mut Outgoing) -> Result<(), error::EchoError> { + if let Some(slot) = self.sent_reliable_msgs.get_mut(&outgoing.msg.round()) { + if !outgoing.recipient.is_reliable_broadcast() { + // it's reliable broadcast round, but message is not reliable broadcast + return Err(error::Reason::SentNonReliableMsgInReliableRound { + dest: outgoing.recipient, + round: outgoing.msg.round(), + } + .into()); + } + // Message delivery layer doesn't need to know that protocol wants this message to be + // reliably broadcasted - echo broadcast takes care of it + outgoing.recipient = crate::MessageDestination::AllParties { reliable: false }; + if slot.is_some() { + return Err(error::Reason::SendTwice.into()); + } + *slot = Some(outgoing.msg.clone()) + } else if outgoing.recipient.is_reliable_broadcast() { + // it's not a reliable broadcast round, but message is a reliable broadcast + return Err(error::Reason::SentReliableMsgInNonReliableRound { + round: outgoing.msg.round(), + } + .into()); + } + + Ok(()) + } +} + +impl MpcExecution for WithEchoBroadcast +where + D: Digest + 'static, + M: MpcExecution>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, +{ + type Round = Round; + type Msg = MainMsg; + type CompleteRoundErr = + error::CompleteRoundError>, M::SendErr>; + type SendErr = error::Error; + type SendMany = WithEchoBroadcast; + + async fn complete( + &mut self, + round: Self::Round, + ) -> Result> + where + R: RoundInfo, + Self::Msg: RoundMsg, + { + match round.0 { + Inner::Unmodified(round) => { + // regular round that doesn't need reliable broadcast + let output = self + .party + .complete(round) + .await + .map_err(error::CompleteRoundError::CompleteRound)?; + Ok(output) + } + Inner::WithReliabilityCheck { + main_round, + echo_round, + } => { + // receive all messages in the main round + let main_output = self + .party + .complete(main_round) + .await + .map_err(error::CompleteRoundError::CompleteRound)?; + // retrieve a msg that we sent in this round + let sent_msg = + if let Some(Some(msg)) = self.sent_reliable_msgs.remove(&Self::Msg::ROUND) { + let msg: R::Msg = Self::Msg::from_protocol_msg(msg) + .map_err(|_| error::Reason::SentMsgFromProto)?; + Some(msg) + } else { + None + }; + // calculate a hash and send it to all other parties + let (main_output, hash) = main_output.with_my_msg(sent_msg)?; + self.party + .send_to_all(Msg::Echo { + round: Self::Msg::ROUND, + hash, + }) + .await + .map_err(error::CompleteRoundError::Send)?; + // receive echoes from other parties + let echoes = self + .party + .complete(echo_round) + .await + .map_err(error::CompleteRoundError::CompleteRound)?; + // check that everyone sent the same hash + let main_output = main_output.with_echo_output(echoes)?; + + Ok(main_output) + } + } + } + + async fn send(&mut self, mut outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&mut outgoing)?; + + self.party + .send(outgoing.map(Msg::Main)) + .await + .map_err(error::Error::Main) + } + + fn send_many(self) -> Self::SendMany { + self.map_party(|p| p.send_many()) + } + + async fn yield_now(&self) { + self.party.yield_now().await + } +} + +/// Round registration witness returned by [`WithEchoBroadcast::add_round()`] +pub struct Round(Inner) +where + M: MpcExecution, + D: Digest + 'static, + ProtoMsg: 'static, + R: RoundInfo; + +enum Inner +where + M: MpcExecution, + D: Digest + 'static, + ProtoMsg: 'static, + R: RoundInfo, +{ + /// Round that we do not modify (round that doesn't require reliable broadcast) + Unmodified(M::Round>>), + WithReliabilityCheck { + main_round: M::Round>>, + echo_round: M::Round, R::Error>>, + }, +} + +impl crate::mpc::SendMany for WithEchoBroadcast +where + D: Digest + 'static, + M: crate::mpc::SendMany>, + MainMsg: ProtocolMsg + udigest::Digestable + Clone + 'static, +{ + type Exec = WithEchoBroadcast; + type Msg = MainMsg; + type SendErr = error::Error; + + async fn send(&mut self, mut outgoing: Outgoing) -> Result<(), Self::SendErr> { + self.on_send(&mut outgoing)?; + self.party + .send(outgoing.map(Msg::Main)) + .await + .map_err(error::Error::Main) + } + + async fn flush(self) -> Result { + let party = self.party.flush().await.map_err(error::Error::Main)?; + Ok(WithEchoBroadcast { + party, + i: self.i, + n: self.n, + sent_reliable_msgs: self.sent_reliable_msgs, + _ph: PhantomData, + }) + } +} diff --git a/round-based/src/echo_broadcast/store.rs b/round-based/src/echo_broadcast/store.rs new file mode 100644 index 0000000..507cc7f --- /dev/null +++ b/round-based/src/echo_broadcast/store.rs @@ -0,0 +1,395 @@ +use alloc::vec::Vec; +use core::marker::PhantomData; +use digest::Digest; + +use crate::{ + round::{RoundInfo, RoundInput, RoundMsgs, RoundStore}, + Incoming, RoundMsg, +}; + +use super::{error, sub_msg}; + +const TAG: &[u8] = b"dfns.round_based.echo_broadcast"; + +pub fn new( + i: u16, + n: u16, + main_round: S, +) -> (MainRound, EchoRound) { + let params = Params { i, n }; + let state = match main_round.output() { + Ok(output) => MainRoundState::Output { output }, + Err(store) => MainRoundState::Ongoing { store }, + }; + let main_round = MainRound { + params, + state, + received_msgs: core::iter::repeat_with(|| None).take(n.into()).collect(), + _ph: PhantomData, + }; + let echo_round = EchoRound { + echo_round: RoundInput::broadcast(i, n), + _round: PhantomData, + }; + (main_round, echo_round) +} + +enum MainRoundState { + Ongoing { store: S }, + Output { output: S::Output }, + Finished, + Gone, +} + +#[derive(Clone, Copy, Debug)] +struct Params { + i: u16, + n: u16, +} + +pub struct MainRound { + params: Params, + state: MainRoundState, + received_msgs: Vec>, + _ph: PhantomData<(D, ProtoMsg)>, +} + +impl RoundInfo for MainRound { + type Msg = S::Msg; + /// When a main round is finished, we output a builder that can be used to + /// calculate a hash of messages received by all parties in the reliable + /// broadcast round. Then this hash needs to be re-sent to all participants. + /// + /// Only if we receive the same hash from all parties in [`EchoRound`], only + /// then we can obtain a main round output. + type Output = NeedsOwnMsg; + type Error = error::Error; +} +impl RoundStore for MainRound +where + ProtoMsg: RoundMsg + Clone + 'static, + D: 'static, +{ + fn add_message(&mut self, mut incoming: Incoming) -> Result<(), Self::Error> { + let wants_more = match &mut self.state { + MainRoundState::Ongoing { store, .. } => { + // We pretend that msg was reliably broadcasted even though the reliability check is + // not yet enforced, however, we do not expose the output of the round unless + // reliability check has passed. + incoming.msg_type = crate::MessageType::Broadcast { reliable: true }; + + // Note: round msg doesn't implement `Clone`, but ProtoMsg does, so we + // use a trick to create a clone of incoming msg + let (incoming1, incoming2) = clone_incoming_round_msg::(incoming) + .ok_or(error::Reason::RoundMsgClone)?; + store.add_message(incoming1).map_err(error::Error::Main)?; + let n = self.received_msgs.len(); + let slot = self + .received_msgs + .get_mut(usize::from(incoming2.sender)) + .ok_or(error::Reason::UnknownSender { + i: incoming2.sender, + n, + })?; + if slot.is_some() { + return Err(error::Reason::StoreReceivedTwoMsgsFromSameParty.into()); + } + *slot = Some(ProtoMsg::to_protocol_msg(incoming2.msg)); + store.wants_more() + } + MainRoundState::Gone => return Err(error::Reason::StateGone.into()), + MainRoundState::Output { .. } | MainRoundState::Finished => { + return Err(error::Reason::ReceivedMainMsgWhenRoundOver.into()) + } + }; + + if !wants_more { + let store = core::mem::replace(&mut self.state, MainRoundState::Gone); + let MainRoundState::Ongoing { store } = store else { + return Err(error::Reason::UnexpectedMainRoundState.into()); + }; + let Ok(output) = store.output() else { + self.state = MainRoundState::Finished; + return Err(error::Reason::MainRoundFinishedButStoreDoesntOutput.into()); + }; + self.state = MainRoundState::Output { output }; + } + + Ok(()) + } + + fn wants_more(&self) -> bool { + match &self.state { + MainRoundState::Ongoing { .. } => { + // Note that on each `add_message` we check if `store.wants_more()`, + // and if it doesn't we change the state to MainRoundState::Output + true + } + _ => false, + } + } + + fn output(self) -> Result { + match self.state { + MainRoundState::Output { output } => Ok(NeedsOwnMsg { + params: self.params, + main_round_output: output, + received_msgs: self.received_msgs, + _hash: PhantomData, + }), + state => Err(Self { + params: self.params, + state, + received_msgs: self.received_msgs, + _ph: PhantomData, + }), + } + } +} + +/// Duplicates a round msg +/// +/// This function doesn't require that round msg implements `Clone`, instead it only requires +/// that protocol msg is cloneable. It works by converting round msg into protocol msg, creating +/// two clones, and converting them back to round msg. +/// +/// We need this function in places where we know that protocol msg is cloneable, but we can't +/// prove to the compiler that round msg is cloneable as well. +/// +/// Function returns `None` only if [`RoundMsg`] implementation is not correct. +fn clone_round_msg(round_msg: R) -> Option<(R, R)> +where + M: RoundMsg + Clone, +{ + let proto_msg = M::to_protocol_msg(round_msg); + + let round_msg1 = M::from_protocol_msg(proto_msg.clone()).ok()?; + let round_msg2 = M::from_protocol_msg(proto_msg).ok()?; + + Some((round_msg1, round_msg2)) +} + +/// Similar to [`clone_round_msg`] but accepts [`Incoming`](Incoming) +fn clone_incoming_round_msg( + incoming_round_msg: Incoming, +) -> Option<(Incoming, Incoming)> +where + M: RoundMsg + Clone, +{ + let (msg1, msg2) = clone_round_msg::(incoming_round_msg.msg)?; + + let incoming = |msg| Incoming { + id: incoming_round_msg.id, + sender: incoming_round_msg.sender, + msg_type: incoming_round_msg.msg_type, + msg, + }; + Some((incoming(msg1), incoming(msg2))) +} + +/// An output of [`MainRound`] which needs an own message sent by local party +/// in this round. Once provided in [`NeedsOwnMsg::with_own_msg`], it outputs +/// a hash of messages received by all parties in this round (that needs to be +/// re-sent to all participants), and [`WithReliabilityCheck`] that takes +/// messages received in echo round and outputs main round result only if reliability +/// check passes. +pub struct NeedsOwnMsg { + params: Params, + main_round_output: S::Output, + received_msgs: Vec>, + _hash: PhantomData, +} + +pub struct ReliabilityCheck { + expected_hash: digest::Output, + main_round_output: S::Output, +} + +impl NeedsOwnMsg +where + D: Digest, + S: RoundInfo, + ProtoMsg: RoundMsg + udigest::Digestable, +{ + pub fn with_my_msg( + mut self, + msg: Option, + ) -> Result<(ReliabilityCheck, digest::Output), error::EchoError> { + let n = self.received_msgs.len(); + let msg = msg.map(ProtoMsg::to_protocol_msg); + *self + .received_msgs + .get_mut(usize::from(self.params.i)) + .ok_or(error::Reason::OwnIndexOutOfBounds { + i: self.params.i, + n, + })? = msg; + + let hash = udigest::hash::(&udigest::inline_struct!(TAG { + msgs: &self.received_msgs, + round: ProtoMsg::ROUND, + n: self.params.n, + })); + let with_reliability_check = ReliabilityCheck { + expected_hash: hash.clone(), + main_round_output: self.main_round_output, + }; + Ok((with_reliability_check, hash)) + } +} + +impl ReliabilityCheck +where + D: Digest, + S: RoundInfo, +{ + pub fn with_echo_output( + self, + echo_output: EchoRoundOutput, + ) -> Result { + if echo_output + .received_echoes + .iter() + .any(|h| *h != self.expected_hash) + { + return Err(error::Reason::MismatchedHash.into()); + } + + Ok(self.main_round_output) + } +} + +pub(super) struct EchoRound { + echo_round: RoundInput>, + _round: PhantomData, +} + +pub struct EchoRoundOutput { + received_echoes: RoundMsgs>, + _round: PhantomData, +} + +impl RoundInfo for EchoRound +where + D: Digest + 'static, + S: RoundInfo, +{ + type Msg = sub_msg::EchoMsg; + type Output = EchoRoundOutput; + type Error = error::EchoError; +} +impl RoundStore for EchoRound +where + D: Digest + 'static, + S: RoundStore, +{ + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.echo_round + .add_message(msg.map(|m| m.hash)) + .map_err(error::Reason::HandleEcho)?; + Ok(()) + } + + fn wants_more(&self) -> bool { + self.echo_round.wants_more() + } + + fn output(self) -> Result { + self.echo_round + .output() + .map(|received_echoes| EchoRoundOutput { + received_echoes, + _round: PhantomData, + }) + .map_err(|echo_round| Self { + echo_round, + _round: PhantomData, + }) + } +} + +/// Wraps a round store `S` and changes its msg type to `sub_msg::Main` +pub struct WithMainMsg(pub S); + +impl RoundInfo for WithMainMsg { + type Msg = sub_msg::Main; + type Output = S::Output; + type Error = S::Error; +} + +impl RoundStore for WithMainMsg { + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.0.add_message(msg.map(|m| m.0)) + } + fn wants_more(&self) -> bool { + self.0.wants_more() + } + fn output(self) -> Result { + self.0.output().map_err(Self) + } +} + +/// Wraps a round store `S` and changes its error to `Error` +pub struct WithError(pub S); + +impl RoundInfo for WithError { + type Msg = S::Msg; + type Output = S::Output; + type Error = error::Error; +} + +impl RoundStore for WithError { + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.0.add_message(msg).map_err(error::Error::Main) + } + fn wants_more(&self) -> bool { + self.0.wants_more() + } + fn output(self) -> Result { + self.0.output().map_err(Self) + } +} + +/// Wraps a round store `S` with `Error = error::EchoError` and changes it to `Error = error::Error` +pub struct WithEchoError { + pub store: S, + _ph: PhantomData, +} + +impl From for WithEchoError { + fn from(store: S) -> Self { + Self { + store, + _ph: PhantomData, + } + } +} + +impl RoundInfo for WithEchoError +where + S: RoundInfo, + E: core::error::Error + 'static, +{ + type Msg = S::Msg; + type Output = S::Output; + type Error = error::Error; +} + +impl RoundStore for WithEchoError +where + S: RoundStore, + E: core::error::Error + 'static, +{ + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.store.add_message(msg).map_err(error::Error::Echo) + } + fn wants_more(&self) -> bool { + self.store.wants_more() + } + fn output(self) -> Result { + self.store.output().map_err(|store| Self { + store, + _ph: PhantomData, + }) + } +} diff --git a/round-based/src/lib.rs b/round-based/src/lib.rs index a653be1..b6859b0 100644 --- a/round-based/src/lib.rs +++ b/round-based/src/lib.rs @@ -13,25 +13,103 @@ //! * Simple, configurable \ //! Protocol can be carried out in a few lines of code: check out examples. //! * Independent of networking layer \ -//! We use abstractions [`Stream`] and [`Sink`] to receive and send messages. +//! You may define your own networking layer and you don't need to change anything +//! in protocol implementation: it's agnostic of networking by default! So you can +//! use central delivery server, distributed redis nodes, postgres database, +//! p2p channels, or a public blockchain, or whatever else fits your needs. +//! +//! ## Example +//! MPC protocol execution typically looks like this: +//! +//! ```rust +//! # #[derive(round_based::ProtocolMsg)] +//! # enum KeygenMsg {} +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! // protocol to be executed, takes MPC engine `M`, index of party `i`, +//! // and number of participants `n` +//! async fn keygen(mpc: M, i: u16, n: u16) -> Result +//! where +//! M: round_based::Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! // establishes network connection(s) to other parties so they may communicate +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! let delivery = connect().await; +//! +//! // constructs an MPC engine, which, primarily, is used to communicate with +//! // other parties +//! let mpc = round_based::mpc::connected(delivery); +//! +//! # let (i, n) = (1, 3); +//! // execute the protocol +//! let keyshare = keygen(mpc, i, n).await?; +//! # Ok(()) } +//! ``` //! //! ## Networking //! //! In order to run an MPC protocol, transport layer needs to be defined. All you have to do is to -//! implement [`Delivery`] trait which is basically a stream and a sink for receiving and sending messages. +//! provide a channel which implements a stream and a sink for receiving and sending messages. +//! +//! ```rust,no_run +//! # #[derive(round_based::ProtocolMsg)] +//! # enum Msg {} +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! +//! let delivery = connect().await; +//! let party = round_based::mpc::connected(delivery); //! -//! Message delivery should meet certain criterias that differ from protocol to protocol (refer to -//! the documentation of the protocol you're using), but usually they are: +//! // run the protocol +//! # Ok(()) } +//! ``` //! -//! * Messages should be authenticated \ -//! Each message should be signed with identity key of the sender. This implies having Public Key -//! Infrastructure. -//! * P2P messages should be encrypted \ -//! Only recipient should be able to learn the content of p2p message -//! * Broadcast channel should be reliable \ -//! Some protocols may require broadcast channel to be reliable. Simply saying, when party receives a -//! broadcast message over reliable channel it should be ensured that everybody else received the same -//! message. +//! In order to guarantee the protocol security, it may require: +//! +//! * Message Authentication \ +//! Guarantees message source and integrity. If protocol requires it, make sure +//! message was sent by claimed sender and that it hasn't been tampered with. \ +//! This is typically achieved either through public-key cryptography (e.g., +//! signing with a private key) or through symmetric mechanisms like MACs (e.g., +//! HMAC) or authenticated encryption (AEAD) in point-to-point scenarios. +//! * Message Privacy \ +//! When a p2p message is sent, only recipient shall be able to read the content. \ +//! It can be achieved by using symmetric or asymmetric encryption, encryption methods +//! come with their own trade-offs (e.g. simplicity vs forward secrecy). +//! * Reliable Broadcast \ +//! When party receives a reliable broadcast message it shall be ensured that +//! everybody else received the same message. \ +//! Our library provides [`echo_broadcast`] support out-of-box that enforces broadcast +//! reliability by adding an extra communication round per each round that requires +//! reliable broadcast. \ +//! More advanced techniques implement [Byzantine fault](https://en.wikipedia.org/wiki/Byzantine_fault) +//! tolerant broadcast +//! +//! ## Developing MPC protocol with `round_based` +//! We plan to write a book guiding through MPC protocol development process, but +//! while it's not done, you may refer to [random beacon example](https://github.com/LFDT-Lockness/round-based/blob/m/examples/random-generation-protocol/src/lib.rs) +//! and our well-documented API. //! //! ## Features //! @@ -39,22 +117,21 @@ //! * `sim-async` enables protocol execution simulation with tokio runtime, see [`sim::async_env`] //! module //! * `state-machine` provides ability to carry out the protocol, defined as async function, via Sync -//! API, see [`state_machine`] module -//! * `derive` is needed to use [`ProtocolMessage`](macro@ProtocolMessage) proc macro -//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](runtime) +//! API, see [`state_machine`] module +//! * `echo-broadcast` adds [`echo_broadcast`] support +//! * `derive` is needed to use [`ProtocolMsg`](macro@ProtocolMsg) proc macro +//! * `runtime-tokio` enables [tokio]-specific implementation of [async runtime](mpc::party::runtime) //! //! ## Join us in Discord! //! Feel free to reach out to us [in Discord](https://discordapp.com/channels/905194001349627914/1285268686147424388)! #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg, doc_cfg_hide))] -#![forbid(unused_crate_dependencies, missing_docs)] +#![warn(unused_crate_dependencies, missing_docs)] +#![allow(async_fn_in_trait)] #![no_std] extern crate alloc; -#[doc(no_inline)] -pub use futures_util::{Sink, SinkExt, Stream, StreamExt}; - /// Fixes false-positive of `unused_crate_dependencies` lint that only occur in the tests #[cfg(test)] mod false_positives { @@ -63,12 +140,15 @@ mod false_positives { use trybuild as _; use {hex as _, rand as _, rand_dev as _}; + + use sha2 as _; } mod delivery; -pub mod party; -pub mod rounds_router; -pub mod runtime; +#[cfg(feature = "echo-broadcast")] +pub mod echo_broadcast; +pub mod mpc; +pub mod round; #[cfg(feature = "state-machine")] pub mod state_machine; @@ -76,17 +156,15 @@ pub mod state_machine; pub mod sim; pub use self::delivery::*; +pub use self::mpc::MpcParty; #[doc(no_inline)] -pub use self::{ - party::{Mpc, MpcParty}, - rounds_router::{ProtocolMessage, RoundMessage}, -}; +pub use self::mpc::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; #[doc(hidden)] pub mod _docs; -/// Derives [`ProtocolMessage`] and [`RoundMessage`] traits +/// Derives [`ProtocolMsg`] and [`RoundMsg`] traits /// -/// See [`ProtocolMessage`] docs for more details +/// See [`ProtocolMsg`] docs for more details #[cfg(feature = "derive")] -pub use round_based_derive::ProtocolMessage; +pub use round_based_derive::ProtocolMsg; diff --git a/round-based/src/mpc/mod.rs b/round-based/src/mpc/mod.rs new file mode 100644 index 0000000..afeb401 --- /dev/null +++ b/round-based/src/mpc/mod.rs @@ -0,0 +1,348 @@ +//! Party of MPC protocol +//! +//! [`MpcParty`] is party of MPC protocol, connected to network, ready to start carrying out the protocol. +//! +//! ```rust +//! use round_based::{Incoming, Outgoing}; +//! +//! # #[derive(round_based::ProtocolMsg)] +//! # enum KeygenMsg {} +//! # struct KeyShare; +//! # struct Error; +//! # type Result = std::result::Result; +//! # async fn doc() -> Result<()> { +//! async fn keygen(party: M, i: u16, n: u16) -> Result +//! where +//! M: round_based::Mpc +//! { +//! // ... +//! # unimplemented!() +//! } +//! async fn connect() -> +//! impl futures::Stream>> +//! + futures::Sink, Error = Error> +//! + Unpin +//! { +//! // ... +//! # round_based::_docs::fake_delivery() +//! } +//! +//! let delivery = connect().await; +//! let party = round_based::mpc::connected(delivery); +//! +//! # let (i, n) = (1, 3); +//! let keyshare = keygen(party, i, n).await?; +//! # Ok(()) } +//! ``` + +use crate::{ + round::{RoundInfo, RoundStore}, + Outgoing, PartyIndex, +}; + +pub mod party; + +#[doc(no_inline)] +pub use self::party::{Halves, MpcParty}; + +/// Abstracts functionalities needed for creating an MPC protocol execution. +/// +/// An object implementing this trait is accepted as a parameter of a protocol. It is used to +/// configure the protocol with [`Mpc::add_round`], and then finalized into a protocol executor +/// with [`Mpc::finish_setup`] +pub trait Mpc { + /// Protocol message + type Msg; + + /// A type of a finalized instatiation of a party for this protocol + /// + /// After being created with [`Mpc::finish_setup`], you can use an object with this type to + /// drive the protocol execution using the methods of [`MpcExecution`]. + type Exec: MpcExecution; + /// Error indicating that sending a message has failed + type SendErr; + + /// Registers a round + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg; + + /// Completes network setup + /// + /// Once this method is called, no more rounds can be added, + /// but the protocol can receive and send messages. + fn finish_setup(self) -> Self::Exec; +} + +/// Abstracts functionalities needed for MPC protocol execution +pub trait MpcExecution { + /// Witness that round was registered + /// + /// It's obtained by registering round in [`Mpc::add_round`], which then can be used to retrieve + /// messages from associated round by calling [`MpcExecution::complete`]. + type Round; + + /// Protocol message + type Msg; + + /// Error indicating that completing a round has failed + type CompleteRoundErr; + /// Error indicating that sending a message has failed + type SendErr; + + /// Returned by [`.send_many()`](Self::send_many) + type SendMany: SendMany; + + /// Completes the round + /// + /// Waits until we receive all the messages in the round `R` from other parties. Returns + /// received messages. + async fn complete( + &mut self, + round: Self::Round, + ) -> Result> + where + R: RoundInfo, + Self::Msg: RoundMsg; + + /// Sends a message + /// + /// This method awaits until the message is sent, which might be not the best method to use if you + /// need to send many messages at once. If it's the case, prefer using [`.send_many()`](Self::send_many). + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; + + /// Sends a p2p message to another party + /// + /// Note: when you send many messages at once (it's most likely the case when you send a p2p message), this method + /// is not efficient, prefer using [`.send_many()`](Self::send_many). + async fn send_p2p( + &mut self, + recipient: PartyIndex, + msg: Self::Msg, + ) -> Result<(), Self::SendErr> { + self.send(Outgoing::p2p(recipient, msg)).await + } + + /// Sends a message that will be received by all parties + /// + /// Message will be broadcasted, but not reliably. If you need a reliable broadcast, use + /// [`MpcExecution::reliably_broadcast`] method. + async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::all_parties(msg)).await + } + + /// Reliably broadcasts a message + /// + /// Message will be received by all participants of the protocol. Moreover, when recipient receives a + /// message, it will be assured (cryptographically or through other trust assumptions) that all honest + /// participants of the protocol received the same message. + /// + /// It's a responsibility of a message delivery layer to provide the reliable broadcast mechanism. If + /// it's not supported, this method returns an error. Note that not every MPC protocol requires the + /// reliable broadcast, so it's totally normal to have a message delivery implementation that does + /// not support it. + async fn reliably_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::reliable_broadcast(msg)).await + } + + /// Creates a buffer of outgoing messages so they can be sent all at once + /// + /// When you have many messages that you want to send at once, using [`.send()`](Self::send) + /// may be inefficient, as delivery implementation may pause the execution until the message is + /// received by the recipient. Use this method to send many messages in a batch. + /// + /// This method takes ownership of `self` to create the [impl SendMany](SendMany) object. After + /// enqueueing all the messages, you need to reclaim `self` back by calling [`SendMany::flush`]. + fn send_many(self) -> Self::SendMany; + + /// Yields execution back to the async runtime + /// + /// Used in MPC protocols with many heavy synchronous computations. The protocol implementors + /// can manually insert yield points to ease the CPU contention + async fn yield_now(&self); +} + +/// Buffer, optimized for sending many messages at once +/// +/// It's obtained by calling [`MpcExecution::send_many`], which takes ownership of `MpcExecution`. To reclaim +/// ownership, call [`SendMany::flush`]. +pub trait SendMany { + /// MPC executor, returned after successful [`.flush()`](Self::flush) + type Exec: MpcExecution; + /// Protocol message + type Msg; + /// Error indicating that sending a message has failed + type SendErr; + + /// Adds a message to the sending queue + /// + /// Similar to [`MpcExecution::send`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr>; + + /// Adds a p2p message to the sending queue + /// + /// Similar to [`MpcExecution::send_p2p`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). + async fn send_p2p( + &mut self, + recipient: PartyIndex, + msg: Self::Msg, + ) -> Result<(), Self::SendErr> { + self.send(Outgoing::p2p(recipient, msg)).await + } + + /// Adds a broadcast message to the sending queue + /// + /// Similar to [`MpcExecution::send_to_all`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`.flush()`](Self::flush). + async fn send_to_all(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::all_parties(msg)).await + } + + /// Adds a reliable broadcast message to the sending queue + /// + /// Similar to [`MpcExecution::reliably_broadcast`], but possibly buffers a message until [`.flush()`](Self::flush) is + /// called. + /// + /// A call to this function may send the message, but this is not guaranteed by the API. To + /// flush the sending queue and send all messages, use [`Self::flush`] + async fn reliably_broadcast(&mut self, msg: Self::Msg) -> Result<(), Self::SendErr> { + self.send(Outgoing::reliable_broadcast(msg)).await + } + + /// Flushes internal buffer by sending all messages in the queue + async fn flush(self) -> Result; +} + +/// Alias to `<::Exec as MpcExecution>::CompleteRoundErr` +pub type CompleteRoundErr = <::Exec as MpcExecution>::CompleteRoundErr; + +/// Message of MPC protocol +/// +/// MPC protocols typically consist of several rounds, each round has differently typed message. +/// `ProtocolMsg` and [`RoundMsg`] traits are used to examine received message: `ProtocolMsg::round` +/// determines which round message belongs to, and then `RoundMsg` trait can be used to retrieve +/// actual round-specific message. +/// +/// You should derive these traits using proc macro (requires `derive` feature): +/// ```rust +/// use round_based::ProtocolMsg; +/// +/// #[derive(ProtocolMsg)] +/// pub enum Message { +/// Round1(Msg1), +/// Round2(Msg2), +/// // ... +/// } +/// +/// pub struct Msg1 { /* ... */ } +/// pub struct Msg2 { /* ... */ } +/// ``` +/// +/// This desugars into: +/// +/// ```rust +/// use round_based::{ProtocolMsg, RoundMsg}; +/// +/// pub enum Message { +/// Round1(Msg1), +/// Round2(Msg2), +/// // ... +/// } +/// +/// pub struct Msg1 { /* ... */ } +/// pub struct Msg2 { /* ... */ } +/// +/// impl ProtocolMsg for Message { +/// fn round(&self) -> u16 { +/// match self { +/// Message::Round1(_) => 1, +/// Message::Round2(_) => 2, +/// // ... +/// } +/// } +/// } +/// impl RoundMsg for Message { +/// const ROUND: u16 = 1; +/// fn to_protocol_msg(round_msg: Msg1) -> Self { +/// Message::Round1(round_msg) +/// } +/// fn from_protocol_msg(protocol_msg: Self) -> Result { +/// match protocol_msg { +/// Message::Round1(msg) => Ok(msg), +/// msg => Err(msg), +/// } +/// } +/// } +/// impl RoundMsg for Message { +/// const ROUND: u16 = 2; +/// fn to_protocol_msg(round_msg: Msg2) -> Self { +/// Message::Round2(round_msg) +/// } +/// fn from_protocol_msg(protocol_msg: Self) -> Result { +/// match protocol_msg { +/// Message::Round2(msg) => Ok(msg), +/// msg => Err(msg), +/// } +/// } +/// } +/// ``` +pub trait ProtocolMsg: Sized { + /// Number of the round that this message originates from + fn round(&self) -> u16; +} + +/// Round message +/// +/// See [`ProtocolMsg`] trait documentation. +pub trait RoundMsg: ProtocolMsg { + /// Number of the round this message belongs to + const ROUND: u16; + + /// Converts round message into protocol message (never fails) + fn to_protocol_msg(round_msg: M) -> Self; + /// Extracts round message from protocol message + /// + /// Returns `Err(protocol_message)` if `protocol_message.round() != Self::ROUND`, otherwise + /// returns `Ok(round_message)` + fn from_protocol_msg(protocol_msg: Self) -> Result; +} + +/// Construct an [`MpcParty`] that can be used to carry out MPC protocol +/// +/// Accepts a channels with incoming and outgoing messages. +/// +/// Alias to [`MpcParty::connected`] +pub fn connected(delivery: D) -> MpcParty +where + M: ProtocolMsg + 'static, + D: futures_util::Stream, D::Error>> + Unpin, + D: futures_util::Sink> + Unpin, +{ + MpcParty::connected(delivery) +} + +/// Construct an [`MpcParty`] that can be used to carry out MPC protocol +/// +/// Accepts separately a channel for incoming and a channel for outgoing messages. +/// +/// Alias to [`MpcParty::connected_halves`] +pub fn connected_halves(incomings: In, outgoings: Out) -> MpcParty> +where + M: ProtocolMsg + 'static, + In: futures_util::Stream, Out::Error>> + Unpin, + Out: futures_util::Sink> + Unpin, +{ + MpcParty::connected_halves(incomings, outgoings) +} diff --git a/round-based/src/mpc/party/mod.rs b/round-based/src/mpc/party/mod.rs new file mode 100644 index 0000000..866b0f7 --- /dev/null +++ b/round-based/src/mpc/party/mod.rs @@ -0,0 +1,319 @@ +//! Provides [`MpcParty`], default engine for MPC protocol execution that implements [`Mpc`] and [`MpcExecution`] traits + +use futures_util::{Sink, SinkExt, Stream, StreamExt}; + +use crate::{ + round::{RoundInfo, RoundStore}, + Incoming, Outgoing, +}; + +use super::{Mpc, MpcExecution, ProtocolMsg, RoundMsg}; + +mod router; +pub mod runtime; + +pub use self::router::{errors::RouterError, Round}; +#[doc(no_inline)] +pub use self::runtime::AsyncRuntime; + +/// MPC engine, carries out the protocol +/// +/// Can be constructed via [`MpcParty::connected`] or [`MpcParty::connected_halves`], which wraps +/// a channel of incoming and outgoing messages, and implements additional logic on top of this +/// to facilitate the MPC protocol execution, such as routing incoming messages between round +/// stores. +/// +/// Implements [`Mpc`] and [`MpcExecution`]. +pub struct MpcParty { + router: router::RoundsRouter, + io: D, + runtime: R, +} + +impl MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, +{ + /// Constructs [`MpcParty`] + pub fn connected(delivery: D) -> Self { + Self { + router: router::RoundsRouter::new(), + io: delivery, + runtime: runtime::DefaultRuntime::default(), + } + } +} + +impl MpcParty> +where + M: ProtocolMsg + 'static, + In: Stream, E>> + Unpin, + Out: Sink, Error = E> + Unpin, +{ + /// Constructs [`MpcParty`] + pub fn connected_halves(incomings: In, outgoings: Out) -> Self { + Self::connected(Halves::new(incomings, outgoings)) + } +} + +impl MpcParty { + /// Changes which async runtime to use + pub fn with_runtime(self, runtime: R) -> MpcParty { + MpcParty { + router: self.router, + io: self.io, + runtime, + } + } +} + +impl Mpc for MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Msg = M; + + type Exec = MpcParty; + + type SendErr = E; + + fn add_round(&mut self, round: R) -> ::Round + where + R: RoundStore, + Self::Msg: RoundMsg, + { + self.router.add_round(round) + } + + fn finish_setup(self) -> Self::Exec { + MpcParty { + router: self.router, + io: self.io, + runtime: self.runtime, + } + } +} + +impl MpcExecution for MpcParty +where + M: ProtocolMsg + 'static, + D: Stream, IoErr>> + Unpin, + D: Sink, Error = IoErr> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Round = router::Round; + type Msg = M; + type CompleteRoundErr = CompleteRoundError; + type SendErr = IoErr; + type SendMany = SendMany; + + async fn complete( + &mut self, + mut round: Self::Round, + ) -> Result> + where + R: RoundInfo, + Self::Msg: RoundMsg, + { + // Check if round is already completed + round = match self.router.complete_round(round) { + Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})), + Err(w) => w, + }; + + // Round is not completed - we need more messages + loop { + let incoming = self + .io + .next() + .await + .ok_or(CompleteRoundError::UnexpectedEof)? + .map_err(CompleteRoundError::Io)?; + self.router.received_msg(incoming)?; + + // Check if round was just completed + round = match self.router.complete_round(round) { + Ok(output) => return output.map_err(|e| e.map_io_err(|e| match e {})), + Err(w) => w, + }; + } + } + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.io.send(msg).await + } + + fn send_many(self) -> Self::SendMany { + SendMany { party: self } + } + + async fn yield_now(&self) { + self.runtime.yield_now().await + } +} + +/// Returned by [`MpcParty::send_many()`] +pub struct SendMany { + party: MpcParty, +} + +impl super::SendMany for SendMany +where + M: ProtocolMsg + 'static, + D: Stream, E>> + Unpin, + D: Sink, Error = E> + Unpin, + AsyncR: runtime::AsyncRuntime, +{ + type Exec = MpcParty; + type Msg = as Mpc>::Msg; + type SendErr = as Mpc>::SendErr; + + async fn send(&mut self, msg: Outgoing) -> Result<(), Self::SendErr> { + self.party.io.feed(msg).await + } + + async fn flush(mut self) -> Result { + self.party.io.flush().await?; + Ok(self.party) + } +} + +pin_project_lite::pin_project! { + /// Merges a stream and a sink into one structure that implements both [`Stream`] and [`Sink`] + pub struct Halves { + #[pin] + incomings: In, + #[pin] + outgoings: Out, + } +} + +impl Halves { + /// Constructs `Halves` + pub fn new(incomings: In, outgoings: Out) -> Self { + Self { + incomings, + outgoings, + } + } + + /// Deconstructs back into halves + pub fn into_inner(self) -> (In, Out) { + (self.incomings, self.outgoings) + } +} + +impl Stream for Halves +where + In: Stream>, +{ + type Item = Result; + + fn poll_next( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.incomings.poll_next(cx) + } +} + +impl Sink for Halves +where + Out: Sink, +{ + type Error = E; + + fn poll_ready( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_ready(cx) + } + fn start_send(self: core::pin::Pin<&mut Self>, item: M) -> Result<(), Self::Error> { + let this = self.project(); + this.outgoings.start_send(item) + } + fn poll_flush( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_flush(cx) + } + fn poll_close( + self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + let this = self.project(); + this.outgoings.poll_close(cx) + } +} + +/// Error returned by [`MpcParty::complete`] +/// +/// May indicate malicious behavior (e.g. adversary sent a message that aborts protocol execution) +/// or some misconfiguration of the protocol network (e.g. received a message from the round that +/// was not registered via [`Mpc::add_round`]). +#[derive(Debug, thiserror::Error)] +pub enum CompleteRoundError { + /// [`RoundStore`] returned an error + /// + /// Refer to this rounds store documentation to understand why it could fail + #[error(transparent)] + ProcessMsg(ProcessErr), + + /// Router error + /// + /// Indicates that for some reason router was not able to process a message. This can be the case of: + /// - Router API misuse \ + /// E.g. when received a message from the round that was not registered in the router + /// - Improper [`RoundStore`] implementation \ + /// Indicates that round store is not properly implemented and contains a flaw. \ + /// For instance, this error is returned when round store indicates that it doesn't need + /// any more messages ([`RoundStore::wants_more`] + /// returns `false`), but then it didn't output anything ([`RoundStore::output`] + /// returns `Err(_)`) + /// - Bug in the router + /// + /// This error is always related to some implementation flaw or bug: either in the code that uses + /// the router, or in the round store implementation, or in the router itself. When implementation + /// is correct, this error never appears. Thus, it should not be possible for the adversary to "make + /// this error happen." + Router(router::errors::RouterError), + + /// Receiving the next message resulted into I/O error + Io(IoErr), + /// Channel of incoming messages was closed before protocol completion + UnexpectedEof, +} + +impl CompleteRoundError { + /// Maps I/O error + pub fn map_io_err(self, f: impl FnOnce(IoErr) -> E) -> CompleteRoundError { + match self { + CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(e), + CompleteRoundError::Router(e) => CompleteRoundError::Router(e), + CompleteRoundError::Io(e) => CompleteRoundError::Io(f(e)), + CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof, + } + } + /// Maps [`CompleteRoundError::ProcessMsg`] + pub fn map_process_err( + self, + f: impl FnOnce(ProcessErr) -> E, + ) -> CompleteRoundError { + match self { + CompleteRoundError::ProcessMsg(e) => CompleteRoundError::ProcessMsg(f(e)), + CompleteRoundError::Router(e) => CompleteRoundError::Router(e), + CompleteRoundError::Io(e) => CompleteRoundError::Io(e), + CompleteRoundError::UnexpectedEof => CompleteRoundError::UnexpectedEof, + } + } +} diff --git a/round-based/src/mpc/party/router.rs b/round-based/src/mpc/party/router.rs new file mode 100644 index 0000000..f1e5d73 --- /dev/null +++ b/round-based/src/mpc/party/router.rs @@ -0,0 +1,446 @@ +//! Routes incoming MPC messages between rounds +//! +//! Router is a building block, used in MpcParty to register rounds and route +//! incoming messages between them + +use alloc::{boxed::Box, collections::BTreeMap}; +use core::{any::Any, convert::Infallible, mem}; + +use phantom_type::PhantomType; +use tracing::{error, trace_span, warn}; + +use crate::{ + round::{RoundInfo, RoundStore}, + Incoming, ProtocolMsg, RoundMsg, +}; + +/// Routes received messages between protocol rounds +pub struct RoundsRouter { + rounds: BTreeMap>>>, +} + +impl RoundsRouter +where + M: ProtocolMsg + 'static, +{ + pub fn new() -> Self { + Self { + rounds: Default::default(), + } + } + + /// Registers new round + /// + /// ## Panics + /// Panics if round `R` was already registered + pub fn add_round(&mut self, message_store: R) -> Round + where + R: RoundStore, + M: RoundMsg, + { + let overridden_round = self.rounds.insert( + M::ROUND, + Some(Box::new(ProcessRoundMessageImpl::new(message_store))), + ); + if overridden_round.is_some() { + panic!("round {} is overridden", M::ROUND); + } + Round { + _ph: PhantomType::new(), + } + } + + pub fn received_msg(&mut self, incoming: Incoming) -> Result<(), errors::UnregisteredRound> { + let msg_round_n = incoming.msg.round(); + let span = trace_span!( + "Round::received_msg", + round = %msg_round_n, + sender = %incoming.sender, + ty = ?incoming.msg_type + ); + let _guard = span.enter(); + + let message_round = match self.rounds.get_mut(&msg_round_n) { + Some(Some(round)) => round, + Some(None) => { + warn!("got message for the round that was already completed, ignoring it"); + return Ok(()); + } + None => { + return Err(errors::UnregisteredRound { + n: msg_round_n, + witness_provided: false, + }) + } + }; + if message_round.needs_more_messages().no() { + warn!("received message for the round that was already completed, ignoring it"); + return Ok(()); + } + message_round.process_message(incoming); + Ok(()) + } + + #[allow(clippy::type_complexity)] + pub fn complete_round( + &mut self, + round: Round, + ) -> Result>, Round> + where + R: RoundInfo, + M: RoundMsg, + { + let message_round = match self.rounds.get_mut(&M::ROUND) { + Some(Some(round)) => round, + Some(None) => { + return Ok(Err( + errors::Bug::RoundGoneButWitnessExists { n: M::ROUND }.into() + )); + } + None => { + return Ok(Err(errors::UnregisteredRound { + n: M::ROUND, + witness_provided: true, + } + .into())) + } + }; + if message_round.needs_more_messages().yes() { + return Err(round); + } + Ok(Self::retrieve_round_output::(message_round)) + } + + fn retrieve_round_output( + round: &mut Box>, + ) -> Result> + where + R: RoundInfo, + { + match round.take_output() { + Ok(Ok(any)) => Ok(*any + .downcast::() + .or(Err(errors::Bug::MismatchedOutputType))?), + Ok(Err(any)) => Err(*any + .downcast::>() + .or(Err(errors::Bug::MismatchedErrorType))?), + Err(err) => Err(errors::Bug::TakeRoundResult(err).into()), + } + } +} + +/// A witness that round has been registered in the router +/// +/// Can be used later to claim messages received in this round +pub struct Round { + _ph: PhantomType, +} + +impl core::fmt::Debug for Round { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Round").finish_non_exhaustive() + } +} + +trait ProcessRoundMessage { + type Msg; + + /// Processes round message + /// + /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`, + /// otherwise calling this method is unexpected. + fn process_message(&mut self, msg: Incoming); + + /// Indicated whether the store needs more messages + /// + /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No` + /// then you need to take the round output by calling `.take_output()`. + fn needs_more_messages(&self) -> NeedsMoreMessages; + + /// Tries to obtain round output + /// + /// Can be called once `process_message()` returned `NeedMoreMessages::No`. + /// + /// Returns: + /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output` + /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError` + /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`] + #[allow(clippy::type_complexity)] + fn take_output(&mut self) -> Result, Box>, TakeOutputError>; +} + +#[derive(Debug, thiserror::Error)] +enum TakeOutputError { + #[error("output is already taken")] + AlreadyTaken, + #[error("output is not ready yet, more messages are needed")] + NotReady, +} + +enum ProcessRoundMessageImpl> { + InProgress { store: S, _ph: PhantomType }, + Completed(Result>), + Gone, +} + +impl> ProcessRoundMessageImpl { + pub fn new(store: S) -> Self { + if store.wants_more() { + Self::InProgress { + store, + _ph: Default::default(), + } + } else { + Self::Completed( + store + .output() + .map_err(|_| errors::ImproperRoundStore::StoreDidntOutput.into()), + ) + } + } +} + +impl ProcessRoundMessageImpl +where + S: RoundStore, + M: ProtocolMsg + RoundMsg, +{ + fn _process_message( + store: &mut S, + msg: Incoming, + ) -> Result<(), errors::CompleteRoundError> { + let msg = msg.try_map(M::from_protocol_msg).map_err(|msg| { + errors::Bug::MessageFromAnotherRound { + actual_number: msg.round(), + expected_round: M::ROUND, + } + })?; + + store + .add_message(msg) + .map_err(errors::CompleteRoundError::ProcessMsg)?; + Ok(()) + } +} + +impl ProcessRoundMessage for ProcessRoundMessageImpl +where + S: RoundStore, + M: ProtocolMsg + RoundMsg, +{ + type Msg = M; + + fn process_message(&mut self, msg: Incoming) { + let store = match self { + Self::InProgress { store, .. } => store, + _ => { + return; + } + }; + + match Self::_process_message(store, msg) { + Ok(()) => { + if store.wants_more() { + return; + } + + let store = match mem::replace(self, Self::Gone) { + Self::InProgress { store, .. } => store, + _ => { + *self = Self::Completed(Err(errors::Bug::IncoherentState { + expected: "InProgress", + justification: + "we checked at beginning of the function that `state` is InProgress", + }.into())); + return; + } + }; + + match store.output() { + Ok(output) => *self = Self::Completed(Ok(output)), + Err(_err) => { + *self = Self::Completed(Err( + errors::ImproperRoundStore::StoreDidntOutput.into() + )) + } + } + } + Err(err) => { + *self = Self::Completed(Err(err)); + } + } + } + + fn needs_more_messages(&self) -> NeedsMoreMessages { + match self { + Self::InProgress { .. } => NeedsMoreMessages::Yes, + _ => NeedsMoreMessages::No, + } + } + + fn take_output(&mut self) -> Result, Box>, TakeOutputError> { + match self { + Self::InProgress { .. } => return Err(TakeOutputError::NotReady), + Self::Gone => return Err(TakeOutputError::AlreadyTaken), + _ => (), + } + match mem::replace(self, Self::Gone) { + Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))), + Self::Completed(Err(err)) => Ok(Err(Box::new(err))), + _ => unreachable!("it's checked to be completed"), + } + } +} + +enum NeedsMoreMessages { + Yes, + No, +} + +#[allow(dead_code)] +impl NeedsMoreMessages { + pub fn yes(&self) -> bool { + matches!(self, Self::Yes) + } + pub fn no(&self) -> bool { + matches!(self, Self::No) + } +} + +/// When something goes wrong +pub mod errors { + pub use crate::mpc::party::CompleteRoundError; + + use super::TakeOutputError; + + #[derive(Debug, thiserror::Error)] + #[error("received a message for unregistered round")] + pub(in crate::mpc) struct UnregisteredRound { + pub n: u16, + pub(super) witness_provided: bool, + } + + /// Router error + /// + /// Refer to [`CompleteRoundError::Router`] docs + #[derive(Debug, thiserror::Error)] + #[error(transparent)] + pub struct RouterError(Reason); + + #[derive(Debug, thiserror::Error)] + pub(super) enum Reason { + /// Router API has been misused + /// + /// For instance, this error is returned when protocol implementation does not register + /// certain round of the protocol, but then a message from this round is received. In + /// this case, router doesn't have anywhere to route the message to, so an [`ApiMisuse`] + /// error is returned. + #[error("api misuse")] + ApiMisuse(#[source] ApiMisuse), + /// Improper [`RoundStore`](crate::round::RoundStore) implementation + /// + /// For instance, this error is returned when round store indicates that it doesn't need + /// any more messages ([`RoundStore::wants_more`](crate::round::RoundStore::wants_more) + /// returns `false`), but then it didn't output anything ([`RoundStore::output`](crate::round::RoundStore::output) + /// returns `Err(_)`) + #[error("improper round store")] + ImproperRoundStore(#[source] ImproperRoundStore), + /// Indicates that there's a bug in the router implementation + #[error("bug (please, open an issue)")] + Bug(#[source] Bug), + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum ApiMisuse { + #[error(transparent)] + UnregisteredRound(#[from] UnregisteredRound), + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum ImproperRoundStore { + /// Store indicated that it received enough messages but didn't output + /// + /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`. + #[error("store didn't output")] + StoreDidntOutput, + } + + #[derive(Debug, thiserror::Error)] + pub(super) enum Bug { + #[error("round is gone, but witness exists")] + RoundGoneButWitnessExists { n: u16 }, + #[error( + "message originates from another round: we process messages from round \ + {expected_round}, got message from round {actual_number}" + )] + MessageFromAnotherRound { + expected_round: u16, + actual_number: u16, + }, + #[error("state is incoherent, it's expected to be {expected}: {justification}")] + IncoherentState { + expected: &'static str, + justification: &'static str, + }, + #[error("take round result")] + TakeRoundResult(#[source] TakeOutputError), + #[error("mismatched output type")] + MismatchedOutputType, + #[error("mismatched error type")] + MismatchedErrorType, + } + + macro_rules! impl_round_complete_from { + ($(|$err:ident: $err_ty:ty| $err_fn:expr),+$(,)?) => {$( + impl From<$err_ty> for CompleteRoundError { + fn from($err: $err_ty) -> Self { + $err_fn + } + } + )+}; + } + + impl_round_complete_from! { + |err: ApiMisuse| CompleteRoundError::Router(RouterError(Reason::ApiMisuse(err))), + |err: ImproperRoundStore| CompleteRoundError::Router(RouterError(Reason::ImproperRoundStore(err))), + |err: Bug| CompleteRoundError::Router(RouterError(Reason::Bug(err))), + |err: UnregisteredRound| ApiMisuse::UnregisteredRound(err).into(), + } +} + +#[cfg(test)] +mod tests { + struct Store; + + #[derive(crate::ProtocolMsg)] + #[protocol_msg(root = crate)] + enum FakeProtocolMsg { + R1(Msg1), + } + struct Msg1; + + impl crate::round::RoundInfo for Store { + type Msg = Msg1; + type Output = (); + type Error = core::convert::Infallible; + } + impl crate::round::RoundStore for Store { + fn add_message(&mut self, _msg: crate::Incoming) -> Result<(), Self::Error> { + Ok(()) + } + fn wants_more(&self) -> bool { + false + } + fn output(self) -> Result { + Ok(()) + } + } + + #[test] + fn complete_round_that_expects_no_messages() { + let mut rounds = super::RoundsRouter::::new(); + let round1 = rounds.add_round(Store); + + rounds.complete_round(round1).unwrap().unwrap(); + } +} diff --git a/round-based/src/runtime.rs b/round-based/src/mpc/party/runtime.rs similarity index 74% rename from round-based/src/runtime.rs rename to round-based/src/mpc/party/runtime.rs index 7daa68e..3331419 100644 --- a/round-based/src/runtime.rs +++ b/round-based/src/mpc/party/runtime.rs @@ -8,15 +8,12 @@ /// Abstracts async runtime like [tokio]. Currently only exposes a [yield_now](Self::yield_now) /// function. pub trait AsyncRuntime { - /// Future type returned by [yield_now](Self::yield_now) - type YieldNowFuture: core::future::Future; - /// Yields the execution back to the runtime /// /// If the protocol performs a long computation, it might be better for performance - /// to split it with yield points, so the signle computation does not starve other + /// to split it with yield points, so the single computation does not starve other /// tasks. - fn yield_now(&self) -> Self::YieldNowFuture; + async fn yield_now(&self); } /// [Tokio](tokio)-specific async runtime @@ -26,11 +23,8 @@ pub struct TokioRuntime; #[cfg(feature = "runtime-tokio")] impl AsyncRuntime for TokioRuntime { - type YieldNowFuture = - core::pin::Pin + Send>>; - - fn yield_now(&self) -> Self::YieldNowFuture { - alloc::boxed::Box::pin(tokio::task::yield_now()) + async fn yield_now(&self) { + tokio::task::yield_now().await } } @@ -47,7 +41,7 @@ pub type DefaultRuntime = TokioRuntime; pub type DefaultRuntime = UnknownRuntime; /// Unknown async runtime -pub mod unknown_runtime { +mod unknown_runtime { /// Unknown async runtime /// /// Tries to implement runtime features using generic futures code. It's better to use @@ -56,15 +50,13 @@ pub mod unknown_runtime { pub struct UnknownRuntime; impl super::AsyncRuntime for UnknownRuntime { - type YieldNowFuture = YieldNow; - - fn yield_now(&self) -> Self::YieldNowFuture { - YieldNow(false) + async fn yield_now(&self) { + YieldNow(false).await } } /// Future for the `yield_now` function. - pub struct YieldNow(bool); + struct YieldNow(bool); impl core::future::Future for YieldNow { type Output = (); diff --git a/round-based/src/party.rs b/round-based/src/party.rs deleted file mode 100644 index 6ab2199..0000000 --- a/round-based/src/party.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Party of MPC protocol -//! -//! [`MpcParty`] is party of MPC protocol, connected to network, ready to start carrying out the protocol. -//! -//! ```rust -//! use round_based::{Mpc, MpcParty, Delivery, PartyIndex}; -//! -//! # struct KeygenMsg; -//! # struct KeyShare; -//! # struct Error; -//! # type Result = std::result::Result; -//! # async fn doc() -> Result<()> { -//! async fn keygen(party: M, i: PartyIndex, n: u16) -> Result -//! where -//! M: Mpc -//! { -//! // ... -//! # unimplemented!() -//! } -//! async fn connect() -> impl Delivery { -//! // ... -//! # round_based::_docs::fake_delivery() -//! } -//! -//! let delivery = connect().await; -//! let party = MpcParty::connected(delivery); -//! -//! # let (i, n) = (1, 3); -//! let keyshare = keygen(party, i, n).await?; -//! # Ok(()) } -//! ``` - -use phantom_type::PhantomType; - -use crate::delivery::Delivery; -use crate::runtime::{self, AsyncRuntime}; - -/// Party of MPC protocol (trait) -/// -/// [`MpcParty`] is the only struct that implement this trait. Motivation to have this trait is to fewer amount of -/// generic bounds that are needed to be specified. -/// -/// Typical usage of this trait when implementing MPC protocol: -/// -/// ```rust -/// use round_based::{Mpc, MpcParty, PartyIndex}; -/// -/// # struct Msg; -/// async fn keygen(party: M, i: PartyIndex, n: u16) -/// where -/// M: Mpc -/// { -/// let MpcParty{ delivery, .. } = party.into_party(); -/// // ... -/// } -/// ``` -/// -/// If we didn't have this trait, generics would be less readable: -/// ```rust -/// use round_based::{MpcParty, Delivery, runtime::AsyncRuntime, PartyIndex}; -/// -/// # struct Msg; -/// async fn keygen(party: MpcParty, i: PartyIndex, n: u16) -/// where -/// D: Delivery, -/// R: AsyncRuntime -/// { -/// // ... -/// } -/// ``` -pub trait Mpc: internal::Sealed { - /// MPC message - type ProtocolMessage; - /// Transport layer implementation - type Delivery: Delivery< - Self::ProtocolMessage, - SendError = Self::SendError, - ReceiveError = Self::ReceiveError, - >; - /// Async runtime - type Runtime: AsyncRuntime; - - /// Sending message error - type SendError: core::error::Error + Send + Sync + 'static; - /// Receiving message error - type ReceiveError: core::error::Error + Send + Sync + 'static; - - /// Converts into [`MpcParty`] - fn into_party(self) -> MpcParty; -} - -mod internal { - pub trait Sealed {} -} - -/// Party of MPC protocol -#[non_exhaustive] -pub struct MpcParty { - /// Defines transport layer - pub delivery: D, - /// Defines how computationally heavy tasks should be handled - pub runtime: R, - _msg: PhantomType, -} - -impl MpcParty -where - D: Delivery, -{ - /// Party connected to the network - /// - /// Takes the delivery object determining how to deliver/receive other parties' messages - pub fn connected(delivery: D) -> Self { - Self { - delivery, - runtime: Default::default(), - _msg: PhantomType::new(), - } - } -} - -impl MpcParty -where - D: Delivery, -{ - /// Modify the delivery of this party while keeping everything else the same - pub fn map_delivery(self, f: impl FnOnce(D) -> D2) -> MpcParty { - let delivery = f(self.delivery); - MpcParty { - delivery, - runtime: self.runtime, - _msg: self._msg, - } - } - - /// Modify the runtime of this party while keeping everything else the same - pub fn map_runtime(self, f: impl FnOnce(X) -> R) -> MpcParty { - let runtime = f(self.runtime); - MpcParty { - delivery: self.delivery, - runtime, - _msg: self._msg, - } - } - - /// Specifies a [async runtime](runtime) - pub fn set_runtime(self, runtime: R) -> MpcParty - where - R: AsyncRuntime, - { - MpcParty { - delivery: self.delivery, - runtime, - _msg: self._msg, - } - } -} - -impl internal::Sealed for MpcParty {} - -impl Mpc for MpcParty -where - D: Delivery, - D::SendError: core::error::Error + Send + Sync + 'static, - D::ReceiveError: core::error::Error + Send + Sync + 'static, - R: AsyncRuntime, -{ - type ProtocolMessage = M; - type Delivery = D; - type Runtime = R; - - type SendError = D::SendError; - type ReceiveError = D::ReceiveError; - - fn into_party(self) -> MpcParty { - self - } -} diff --git a/round-based/src/round/mod.rs b/round-based/src/round/mod.rs new file mode 100644 index 0000000..1684edd --- /dev/null +++ b/round-based/src/round/mod.rs @@ -0,0 +1,188 @@ +//! Primitives that process and collect messages received at certain round + +use core::any::Any; + +use crate::Incoming; + +pub use self::simple_store::{ + broadcast, p2p, reliable_broadcast, RoundInput, RoundInputError, RoundMsgs, +}; + +mod simple_store; + +/// Common information about a round +pub trait RoundInfo: Sized + 'static { + /// Message type + type Msg; + /// Store output (e.g. `Vec<_>` of received messages) + type Output; + /// Store error + type Error: core::error::Error; +} + +/// Stores messages received at particular round +/// +/// In MPC protocol, party at every round usually needs to receive up to `n` messages. `RoundsStore` +/// is a container that stores messages, it knows how many messages are expected to be received, +/// and should implement extra measures against malicious parties (e.g. prohibit message overwrite). +/// +/// ## Flow +/// `RoundStore` stores received messages. Once enough messages are received, it outputs [`RoundInfo::Output`]. +/// In order to save received messages, [`.add_message(msg)`] is called. Then, [`.wants_more()`] tells whether more +/// messages are needed to be received. If it returned `false`, then output can be retrieved by calling [`.output()`]. +/// +/// [`.add_message(msg)`]: Self::add_message +/// [`.wants_more()`]: Self::wants_more +/// [`.output()`]: Self::output +/// +/// ## Example +/// [`RoundInput`] is an simple messages store. Refer to its docs to see usage examples. +pub trait RoundStore: RoundInfo { + /// Adds received message to the store + /// + /// Returns error if message cannot be processed. Usually it means that sender behaves maliciously. + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error>; + /// Indicates if store expects more messages to receive + fn wants_more(&self) -> bool; + /// Retrieves store output if enough messages are received + /// + /// Returns `Err(self)` if more message are needed to be received. + /// + /// If store indicated that it needs no more messages (ie `store.wants_more() == false`), then + /// this function must return `Ok(_)`. + fn output(self) -> Result; + + /// Interface that exposes ability to retrieve generic information about the round store + /// + /// For reading store properties, it's recommended to use [`RoundStoreExt::read_prop`] method which + /// uses this function internally. + /// + /// When implementing `RoundStore` trait, if you wish to expose no extra information, leave the default + /// implementation of this method. If you do want to expose certain properties that will be accessible + /// through [`RoundStoreExt::read_prop`], follow this example: + /// + /// ```rust + /// pub struct MyStore { /* ... */ } + /// + /// #[derive(Debug, PartialEq, Eq)] + /// pub struct SomePropertyWeWantToExpose { value: u64 } + /// #[derive(Debug, PartialEq, Eq)] + /// pub struct AnotherProperty(String); + /// + /// # type Msg = (); + /// # impl round_based::round::RoundInfo for MyStore { + /// # type Msg = Msg; + /// # type Output = Vec; + /// # type Error = core::convert::Infallible; + /// # } + /// impl round_based::round::RoundStore for MyStore { + /// # fn add_message(&mut self, msg: round_based::Incoming) -> Result<(), Self::Error> { unimplemented!() } + /// # fn wants_more(&self) -> bool { unimplemented!() } + /// # fn output(self) -> Result { unimplemented!() } + /// // ... + /// + /// fn read_any_prop(&self, property: &mut dyn core::any::Any) { + /// if let Some(p) = property.downcast_mut::>() { + /// *p = Some(SomePropertyWeWantToExpose { value: 42 }) + /// } else if let Some(p) = property.downcast_mut::>() { + /// *p = Some(AnotherProperty("here we return a string".to_owned())) + /// } + /// } + /// } + /// + /// // Which then can be accessed via `.read_prop()` method: + /// use round_based::round::RoundStoreExt; + /// let store = MyStore { /* ... */ }; + /// assert_eq!( + /// store.read_prop::(), + /// Some(SomePropertyWeWantToExpose { value: 42 }), + /// ); + /// assert_eq!( + /// store.read_prop::(), + /// Some(AnotherProperty("here we return a string".to_owned())), + /// ); + /// ``` + fn read_any_prop(&self, property: &mut dyn Any) { + let _ = property; + } +} + +/// Extra functionalities defined for any [`RoundStore`] +pub trait RoundStoreExt: RoundStore { + /// Reads a property `P` of the store + /// + /// Returns `Some(property_value)` if this store exposes property `P`, otherwise returns `None` + fn read_prop(&self) -> Option

; + + /// Constructs a new store that exposes property `P` with provided value + /// + /// If store already provides a property `P`, it will be overwritten + fn set_prop(self, value: P) -> WithProp; +} + +impl RoundStoreExt for S { + fn read_prop(&self) -> Option

{ + let mut p: Option

= None; + self.read_any_prop(&mut p); + p + } + + fn set_prop(self, value: P) -> WithProp { + WithProp { + prop: value, + store: self, + } + } +} + +/// Returned by [`RoundStoreExt::set_prop`] +pub struct WithProp { + prop: P, + store: S, +} + +impl RoundInfo for WithProp +where + S: RoundInfo, + P: 'static, +{ + type Msg = S::Msg; + type Output = S::Output; + type Error = S::Error; +} + +impl RoundStore for WithProp +where + S: RoundStore, + P: Clone + 'static, +{ + #[inline(always)] + fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { + self.store.add_message(msg) + } + #[inline(always)] + fn wants_more(&self) -> bool { + self.store.wants_more() + } + #[inline(always)] + fn output(self) -> Result { + self.store.output().map_err(|store| Self { + prop: self.prop, + store, + }) + } + + fn read_any_prop(&self, property: &mut dyn Any) { + if let Some(p) = property.downcast_mut::>() { + *p = Some(self.prop.clone()) + } else { + self.store.read_any_prop(property); + } + } +} + +/// Properties that may be exposed by [`RoundStore`] +pub mod props { + /// Indicates whether the round requires messages to be reliably broadcasted + pub struct RequiresReliableBroadcast(pub bool); +} diff --git a/round-based/src/rounds_router/simple_store.rs b/round-based/src/round/simple_store.rs similarity index 80% rename from round-based/src/rounds_router/simple_store.rs rename to round-based/src/round/simple_store.rs index 265dd80..cd550b1 100644 --- a/round-based/src/rounds_router/simple_store.rs +++ b/round-based/src/round/simple_store.rs @@ -1,13 +1,13 @@ -//! Simple implementation of `MessagesStore` +//! Simple implementation of [`RoundStore`] use alloc::{vec, vec::Vec}; use core::iter; use crate::{Incoming, MessageType, MsgId, PartyIndex}; -use super::MessagesStore; +use super::{RoundInfo, RoundStore}; -/// Simple implementation of [MessagesStore] that waits for all parties to send a message +/// Simple implementation of [`RoundStore`] that waits for all parties to send a message /// /// Round is considered complete when the store received a message from every party. Note that the /// store will ignore all the messages such as `msg.sender == local_party_index`. @@ -16,28 +16,32 @@ use super::MessagesStore; /// /// ## Example /// ```rust -/// # use round_based::rounds_router::{MessagesStore, simple_store::RoundInput}; -/// # use round_based::{Incoming, MessageType}; +/// use round_based::{Incoming, MessageType}; +/// use round_based::round::{RoundStore, RoundInput}; +/// /// # fn main() -> Result<(), Box> { /// let mut input = RoundInput::<&'static str>::broadcast(1, 3); /// input.add_message(Incoming{ /// id: 0, /// sender: 0, -/// msg_type: MessageType::Broadcast, +/// msg_type: MessageType::Broadcast { reliable: false }, /// msg: "first party message", /// })?; /// input.add_message(Incoming{ /// id: 1, /// sender: 2, -/// msg_type: MessageType::Broadcast, +/// msg_type: MessageType::Broadcast { reliable: false }, /// msg: "third party message", /// })?; /// assert!(!input.wants_more()); /// /// let output = input.output().unwrap(); -/// assert_eq!(output.clone().into_vec_without_me(), ["first party message", "third party message"]); /// assert_eq!( -/// output.clone().into_vec_including_me("my msg"), +/// output.clone().into_vec_without_me(), +/// ["first party message", "third party message"] +/// ); +/// assert_eq!( +/// output.into_vec_including_me("my msg"), /// ["first party message", "my msg", "third party message"] /// ); /// # Ok(()) } @@ -85,9 +89,16 @@ impl RoundInput { /// Construct a new store for broadcast messages /// - /// The same as `RoundInput::new(i, n, MessageType::Broadcast)` + /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: false })` pub fn broadcast(i: PartyIndex, n: u16) -> Self { - Self::new(i, n, MessageType::Broadcast) + Self::new(i, n, MessageType::Broadcast { reliable: false }) + } + + /// Construct a new store for reliable broadcast messages + /// + /// The same as `RoundInput::new(i, n, MessageType::Broadcast { reliable: true })` + pub fn reliable_broadcast(i: PartyIndex, n: u16) -> Self { + Self::new(i, n, MessageType::Broadcast { reliable: true }) } /// Construct a new store for p2p messages @@ -97,19 +108,34 @@ impl RoundInput { Self::new(i, n, MessageType::P2P) } - fn is_expected_type_of_msg(&self, msg_type: MessageType) -> bool { - self.expected_msg_type == msg_type + fn is_expected_type_of_msg(&self, actual_msg_type: MessageType) -> bool { + matches!( + (self.expected_msg_type, actual_msg_type), + (MessageType::P2P, MessageType::P2P) + | ( + MessageType::Broadcast { reliable: false }, + MessageType::Broadcast { .. } + ) + | ( + MessageType::Broadcast { reliable: true }, + MessageType::Broadcast { reliable: true }, + ) + ) } } -impl MessagesStore for RoundInput +impl RoundInfo for RoundInput where M: 'static, { type Msg = M; type Output = RoundMsgs; type Error = RoundInputError; - +} +impl RoundStore for RoundInput +where + M: 'static, +{ fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error> { if !self.is_expected_type_of_msg(msg.msg_type) { return Err(RoundInputError::MismatchedMessageType { @@ -163,6 +189,17 @@ where }) } } + + fn read_any_prop(&self, property: &mut dyn core::any::Any) { + if let Some(p) = + property.downcast_mut::>() + { + *p = Some(crate::round::props::RequiresReliableBroadcast(matches!( + self.expected_msg_type, + MessageType::Broadcast { reliable: true } + ))); + } + } } impl RoundMsgs { @@ -251,7 +288,7 @@ impl RoundMsgs { } } -/// Error explaining why `RoundInput` wasn't able to process a message +/// Error explaining why [`RoundInput`] wasn't able to process a message #[derive(Debug, thiserror::Error)] pub enum RoundInputError { /// Party sent two messages in one round @@ -292,12 +329,31 @@ pub enum RoundInputError { }, } +/// p2p round +/// +/// Alias to [`RoundInput::p2p`] +pub fn p2p(i: u16, n: u16) -> RoundInput { + RoundInput::p2p(i, n) +} +/// Broadcast round +/// +/// Alias to [`RoundInput::broadcast`] +pub fn broadcast(i: u16, n: u16) -> RoundInput { + RoundInput::broadcast(i, n) +} +/// Reliable broadcast round +/// +/// Alias to [`RoundInput::broadcast`] +pub fn reliable_broadcast(i: u16, n: u16) -> RoundInput { + RoundInput::reliable_broadcast(i, n) +} + #[cfg(test)] mod tests { use alloc::vec::Vec; use matches::assert_matches; - use crate::rounds_router::store::MessagesStore; + use crate::round::RoundStore; use crate::{Incoming, MessageType}; use super::{RoundInput, RoundInputError}; @@ -414,11 +470,31 @@ mod tests { #[test] fn store_returns_error_if_message_type_mismatched() { let mut store = RoundInput::::p2p(3, 5); + for reliable in [true, false] { + let err = store + .add_message(Incoming { + id: 0, + sender: 0, + msg_type: MessageType::Broadcast { reliable }, + msg: Msg(1), + }) + .unwrap_err(); + assert_matches!( + err, + RoundInputError::MismatchedMessageType { + msg_id: 0, + expected: MessageType::P2P, + actual: MessageType::Broadcast { reliable: r } + } if r == reliable + ); + } + + let mut store = RoundInput::::broadcast(3, 5); let err = store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::Broadcast, + msg_type: MessageType::P2P, msg: Msg(1), }) .unwrap_err(); @@ -426,12 +502,12 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::P2P, - actual: MessageType::Broadcast + expected: MessageType::Broadcast { reliable: false }, + actual: MessageType::P2P, } ); - let mut store = RoundInput::::broadcast(3, 5); + let mut store = RoundInput::::reliable_broadcast(3, 5); let err = store .add_message(Incoming { id: 0, @@ -444,27 +520,15 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::Broadcast, + expected: MessageType::Broadcast { reliable: true }, actual: MessageType::P2P, } ); - for sender in 0u16..5 { - store - .add_message(Incoming { - id: 0, - sender, - msg_type: MessageType::Broadcast, - msg: Msg(1), - }) - .unwrap(); - } - - let mut store = RoundInput::::broadcast(3, 5); let err = store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::P2P, + msg_type: MessageType::Broadcast { reliable: false }, msg: Msg(1), }) .unwrap_err(); @@ -472,15 +536,20 @@ mod tests { err, RoundInputError::MismatchedMessageType { msg_id: 0, - expected: MessageType::Broadcast, - actual, - } if actual == MessageType::P2P + expected: MessageType::Broadcast { reliable: true }, + actual: MessageType::Broadcast { reliable: false }, + } ); + } + + #[test] + fn non_reliable_broadcast_round_accepts_reliable_broadcast_messages() { + let mut store = RoundInput::::broadcast(3, 5); store .add_message(Incoming { id: 0, sender: 0, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable: true }, msg: Msg(1), }) .unwrap(); diff --git a/round-based/src/rounds_router/mod.rs b/round-based/src/rounds_router/mod.rs deleted file mode 100644 index ad52299..0000000 --- a/round-based/src/rounds_router/mod.rs +++ /dev/null @@ -1,638 +0,0 @@ -//! Routes incoming MPC messages between rounds -//! -//! [`RoundsRouter`] is an essential building block of MPC protocol, it processes incoming messages, groups -//! them by rounds, and provides convenient API for retrieving received messages at certain round. -//! -//! ## Example -//! -//! ```rust -//! use round_based::{Mpc, MpcParty, ProtocolMessage, Delivery, PartyIndex}; -//! use round_based::rounds_router::{RoundsRouter, simple_store::{RoundInput, RoundMsgs}}; -//! -//! #[derive(ProtocolMessage)] -//! pub enum Msg { -//! Round1(Msg1), -//! Round2(Msg2), -//! } -//! -//! pub struct Msg1 { /* ... */ } -//! pub struct Msg2 { /* ... */ } -//! -//! pub async fn some_mpc_protocol(party: M, i: PartyIndex, n: u16) -> Result -//! where -//! M: Mpc, -//! { -//! let MpcParty{ delivery, .. } = party.into_party(); -//! -//! let (incomings, _outgoings) = delivery.split(); -//! -//! // Build `Rounds` -//! let mut rounds = RoundsRouter::builder(); -//! let round1 = rounds.add_round(RoundInput::::broadcast(i, n)); -//! let round2 = rounds.add_round(RoundInput::::p2p(i, n)); -//! let mut rounds = rounds.listen(incomings); -//! -//! // Receive messages from round 1 -//! let msgs: RoundMsgs = rounds.complete(round1).await?; -//! -//! // ... process received messages -//! -//! // Receive messages from round 2 -//! let msgs = rounds.complete(round2).await?; -//! -//! // ... -//! # todo!() -//! } -//! # type Output = (); -//! # type Error = Box; -//! ``` - -use alloc::{boxed::Box, collections::BTreeMap}; -use core::{any::Any, convert::Infallible, mem}; - -use futures_util::{Stream, StreamExt}; -use phantom_type::PhantomType; -use tracing::{debug, error, trace, trace_span, warn, Span}; - -use crate::Incoming; - -#[doc(inline)] -pub use self::errors::CompleteRoundError; -pub use self::store::*; - -pub mod simple_store; -mod store; - -/// Routes received messages between protocol rounds -/// -/// See [module level](self) documentation to learn more about it. -pub struct RoundsRouter { - incomings: S, - rounds: BTreeMap + Send>>>, -} - -impl RoundsRouter { - /// Instantiates [`RoundsRouterBuilder`] - pub fn builder() -> RoundsRouterBuilder { - RoundsRouterBuilder::new() - } -} - -impl RoundsRouter -where - M: ProtocolMessage, - S: Stream, E>> + Unpin, - E: core::error::Error, -{ - /// Completes specified round - /// - /// Waits until all messages at specified round are received. Returns received - /// messages if round is successfully completed, or error otherwise. - #[inline(always)] - pub async fn complete( - &mut self, - round: Round, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let round_number = >::ROUND; - let span = trace_span!("Round", n = round_number); - debug!(parent: &span, "pending round to complete"); - - match self.complete_with_span(&span, round).await { - Ok(output) => { - trace!(parent: &span, "round successfully completed"); - Ok(output) - } - Err(err) => { - error!(parent: &span, %err, "round terminated with error"); - Err(err) - } - } - } - - async fn complete_with_span( - &mut self, - span: &Span, - _round: Round, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let pending_round = >::ROUND; - if let Some(output) = self.retrieve_round_output_if_its_completed::() { - return output; - } - - loop { - let incoming = match self.incomings.next().await { - Some(Ok(msg)) => msg, - Some(Err(err)) => return Err(errors::IoError::Io(err).into()), - None => return Err(errors::IoError::UnexpectedEof.into()), - }; - let message_round_n = incoming.msg.round(); - - let message_round = match self.rounds.get_mut(&message_round_n) { - Some(Some(round)) => round, - Some(None) => { - warn!( - parent: span, - n = message_round_n, - "got message for the round that was already completed, ignoring it" - ); - continue; - } - None => { - return Err( - errors::RoundsMisuse::UnregisteredRound { n: message_round_n }.into(), - ) - } - }; - if message_round.needs_more_messages().no() { - warn!( - parent: span, - n = message_round_n, - "received message for the round that was already completed, ignoring it" - ); - continue; - } - message_round.process_message(incoming); - - if pending_round == message_round_n { - if let Some(output) = self.retrieve_round_output_if_its_completed::() { - return output; - } - } - } - } - - #[allow(clippy::type_complexity)] - fn retrieve_round_output_if_its_completed( - &mut self, - ) -> Option>> - where - R: MessagesStore, - M: RoundMessage, - { - let round_number = >::ROUND; - let round_slot = match self - .rounds - .get_mut(&round_number) - .ok_or(errors::RoundsMisuse::UnregisteredRound { n: round_number }) - { - Ok(slot) => slot, - Err(err) => return Some(Err(err.into())), - }; - let round = match round_slot - .as_mut() - .ok_or(errors::RoundsMisuse::RoundAlreadyCompleted) - { - Ok(round) => round, - Err(err) => return Some(Err(err.into())), - }; - if round.needs_more_messages().no() { - Some(Self::retrieve_round_output::(round_slot)) - } else { - None - } - } - - fn retrieve_round_output( - slot: &mut Option + Send>>, - ) -> Result> - where - R: MessagesStore, - M: RoundMessage, - { - let mut round = slot.take().ok_or(errors::RoundsMisuse::UnregisteredRound { - n: >::ROUND, - })?; - match round.take_output() { - Ok(Ok(any)) => Ok(*any - .downcast::() - .or(Err(CompleteRoundError::from( - errors::Bug::MismatchedOutputType, - )))?), - Ok(Err(any)) => Err(any - .downcast::>() - .or(Err(CompleteRoundError::from( - errors::Bug::MismatchedErrorType, - )))? - .map_io_err(|e| match e {})), - Err(err) => Err(errors::Bug::TakeRoundResult(err).into()), - } - } -} - -/// Builds [`RoundsRouter`] -pub struct RoundsRouterBuilder { - rounds: BTreeMap + Send>>>, -} - -impl Default for RoundsRouterBuilder -where - M: ProtocolMessage + 'static, -{ - fn default() -> Self { - Self::new() - } -} - -impl RoundsRouterBuilder -where - M: ProtocolMessage + 'static, -{ - /// Constructs [`RoundsRouterBuilder`] - /// - /// Alias to [`RoundsRouter::builder`] - pub fn new() -> Self { - Self { - rounds: BTreeMap::new(), - } - } - - /// Registers new round - /// - /// ## Panics - /// Panics if round `R` was already registered - pub fn add_round(&mut self, message_store: R) -> Round - where - R: MessagesStore + Send + 'static, - R::Output: Send, - R::Error: Send, - M: RoundMessage, - { - let overridden_round = self.rounds.insert( - M::ROUND, - Some(Box::new(ProcessRoundMessageImpl::new(message_store))), - ); - if overridden_round.is_some() { - panic!("round {} is overridden", M::ROUND); - } - Round { - _ph: PhantomType::new(), - } - } - - /// Builds [`RoundsRouter`] - /// - /// Takes a stream of incoming messages which will be routed between registered rounds - pub fn listen(self, incomings: S) -> RoundsRouter - where - S: Stream, E>>, - { - RoundsRouter { - incomings, - rounds: self.rounds, - } - } -} - -/// A round of MPC protocol -/// -/// `Round` can be used to retrieve messages received at this round by calling [`RoundsRouter::complete`]. See -/// [module level](self) documentation to see usage. -pub struct Round { - _ph: PhantomType, -} - -trait ProcessRoundMessage { - type Msg; - - /// Processes round message - /// - /// Before calling this method you must ensure that `.needs_more_messages()` returns `Yes`, - /// otherwise calling this method is unexpected. - fn process_message(&mut self, msg: Incoming); - - /// Indicated whether the store needs more messages - /// - /// If it returns `Yes`, then you need to collect more messages to complete round. If it's `No` - /// then you need to take the round output by calling `.take_output()`. - fn needs_more_messages(&self) -> NeedsMoreMessages; - - /// Tries to obtain round output - /// - /// Can be called once `process_message()` returned `NeedMoreMessages::No`. - /// - /// Returns: - /// * `Ok(Ok(any))` — round is successfully completed, `any` needs to be downcasted to `MessageStore::Output` - /// * `Ok(Err(any))` — round has terminated with an error, `any` needs to be downcasted to `CompleteRoundError` - /// * `Err(err)` — couldn't retrieve the output, see [`TakeOutputError`] - #[allow(clippy::type_complexity)] - fn take_output(&mut self) -> Result, Box>, TakeOutputError>; -} - -#[derive(Debug, thiserror::Error)] -enum TakeOutputError { - #[error("output is already taken")] - AlreadyTaken, - #[error("output is not ready yet, more messages are needed")] - NotReady, -} - -enum ProcessRoundMessageImpl> { - InProgress { store: S, _ph: PhantomType }, - Completed(Result>), - Gone, -} - -impl> ProcessRoundMessageImpl { - pub fn new(store: S) -> Self { - if store.wants_more() { - Self::InProgress { - store, - _ph: Default::default(), - } - } else { - Self::Completed( - store - .output() - .map_err(|_| errors::ImproperStoreImpl::StoreDidntOutput.into()), - ) - } - } -} - -impl ProcessRoundMessageImpl -where - S: MessagesStore, - M: ProtocolMessage + RoundMessage, -{ - fn _process_message( - store: &mut S, - msg: Incoming, - ) -> Result<(), CompleteRoundError> { - let msg = msg.try_map(M::from_protocol_message).map_err(|msg| { - errors::Bug::MessageFromAnotherRound { - actual_number: msg.round(), - expected_round: M::ROUND, - } - })?; - - store - .add_message(msg) - .map_err(CompleteRoundError::ProcessMessage)?; - Ok(()) - } -} - -impl ProcessRoundMessage for ProcessRoundMessageImpl -where - S: MessagesStore, - M: ProtocolMessage + RoundMessage, -{ - type Msg = M; - - fn process_message(&mut self, msg: Incoming) { - let store = match self { - Self::InProgress { store, .. } => store, - _ => { - return; - } - }; - - match Self::_process_message(store, msg) { - Ok(()) => { - if store.wants_more() { - return; - } - - let store = match mem::replace(self, Self::Gone) { - Self::InProgress { store, .. } => store, - _ => { - *self = Self::Completed(Err(errors::Bug::IncoherentState { - expected: "InProgress", - justification: - "we checked at beginning of the function that `state` is InProgress", - } - .into())); - return; - } - }; - - match store.output() { - Ok(output) => *self = Self::Completed(Ok(output)), - Err(_err) => { - *self = - Self::Completed(Err(errors::ImproperStoreImpl::StoreDidntOutput.into())) - } - } - } - Err(err) => { - *self = Self::Completed(Err(err)); - } - } - } - - fn needs_more_messages(&self) -> NeedsMoreMessages { - match self { - Self::InProgress { .. } => NeedsMoreMessages::Yes, - _ => NeedsMoreMessages::No, - } - } - - fn take_output(&mut self) -> Result, Box>, TakeOutputError> { - match self { - Self::InProgress { .. } => return Err(TakeOutputError::NotReady), - Self::Gone => return Err(TakeOutputError::AlreadyTaken), - _ => (), - } - match mem::replace(self, Self::Gone) { - Self::Completed(Ok(output)) => Ok(Ok(Box::new(output))), - Self::Completed(Err(err)) => Ok(Err(Box::new(err))), - _ => unreachable!("it's checked to be completed"), - } - } -} - -enum NeedsMoreMessages { - Yes, - No, -} - -#[allow(dead_code)] -impl NeedsMoreMessages { - pub fn yes(&self) -> bool { - matches!(self, Self::Yes) - } - pub fn no(&self) -> bool { - matches!(self, Self::No) - } -} - -/// When something goes wrong -pub mod errors { - use super::TakeOutputError; - - /// Error indicating that `Rounds` failed to complete certain round - #[derive(Debug, thiserror::Error)] - pub enum CompleteRoundError { - /// [`MessagesStore`](super::MessagesStore) failed to process this message - #[error("failed to process the message")] - ProcessMessage(#[source] ProcessErr), - /// Receiving next message resulted into i/o error - #[error("receive next message")] - Io(#[from] IoError), - /// Some implementation specific error - /// - /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug - /// in `Rounds` implementation - #[error("implementation error")] - Other(#[source] OtherError), - } - - /// Error indicating that receiving next message resulted into i/o error - #[derive(Debug, thiserror::Error)] - pub enum IoError { - /// I/O error - #[error("i/o error")] - Io(#[source] E), - /// Encountered unexpected EOF - #[error("unexpected eof")] - UnexpectedEof, - } - - /// Some implementation specific error - /// - /// Error may be result of improper `MessagesStore` implementation, API misuse, or bug - /// in `Rounds` implementation - #[derive(Debug, thiserror::Error)] - #[error(transparent)] - pub struct OtherError(OtherReason); - - #[derive(Debug, thiserror::Error)] - pub(super) enum OtherReason { - #[error("improper `MessagesStore` implementation")] - ImproperStoreImpl(#[source] ImproperStoreImpl), - #[error("`Rounds` API misuse")] - RoundsMisuse(#[source] RoundsMisuse), - #[error("bug in `Rounds` (please, open a issue)")] - Bug(#[source] Bug), - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum ImproperStoreImpl { - /// Store indicated that it received enough messages but didn't output - /// - /// I.e. [`store.wants_more()`] returned `false`, but `store.output()` returned `Err(_)`. - #[error("store didn't output")] - StoreDidntOutput, - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum RoundsMisuse { - #[error("round is already completed")] - RoundAlreadyCompleted, - #[error("round {n} is not registered")] - UnregisteredRound { n: u16 }, - } - - #[derive(Debug, thiserror::Error)] - pub(super) enum Bug { - #[error( - "message originates from another round: we process messages from round \ - {expected_round}, got message from round {actual_number}" - )] - MessageFromAnotherRound { - expected_round: u16, - actual_number: u16, - }, - #[error("state is incoherent, it's expected to be {expected}: {justification}")] - IncoherentState { - expected: &'static str, - justification: &'static str, - }, - #[error("mismatched output type")] - MismatchedOutputType, - #[error("mismatched error type")] - MismatchedErrorType, - #[error("take round result")] - TakeRoundResult(#[source] TakeOutputError), - } - - impl CompleteRoundError { - pub(super) fn map_io_err(self, f: F) -> CompleteRoundError - where - F: FnOnce(IoErr) -> E, - { - match self { - CompleteRoundError::Io(err) => CompleteRoundError::Io(err.map_err(f)), - CompleteRoundError::ProcessMessage(err) => CompleteRoundError::ProcessMessage(err), - CompleteRoundError::Other(err) => CompleteRoundError::Other(err), - } - } - } - - impl IoError { - pub(super) fn map_err(self, f: F) -> IoError - where - F: FnOnce(E) -> B, - { - match self { - IoError::Io(e) => IoError::Io(f(e)), - IoError::UnexpectedEof => IoError::UnexpectedEof, - } - } - } - - macro_rules! impl_from_other_error { - ($($err:ident),+,) => {$( - impl From<$err> for CompleteRoundError { - fn from(err: $err) -> Self { - Self::Other(OtherError(OtherReason::$err(err))) - } - } - )+}; - } - - impl_from_other_error! { - ImproperStoreImpl, - RoundsMisuse, - Bug, - } -} - -#[cfg(test)] -mod tests { - struct Store; - - #[derive(crate::ProtocolMessage)] - #[protocol_message(root = crate)] - enum FakeProtocolMsg { - R1(Msg1), - } - struct Msg1; - - impl super::MessagesStore for Store { - type Msg = Msg1; - type Output = (); - type Error = core::convert::Infallible; - - fn add_message(&mut self, _msg: crate::Incoming) -> Result<(), Self::Error> { - Ok(()) - } - fn wants_more(&self) -> bool { - false - } - fn output(self) -> Result { - Ok(()) - } - } - - #[tokio::test] - async fn complete_round_that_expects_no_messages() { - let incomings = futures::stream::pending::< - Result, core::convert::Infallible>, - >(); - - let mut rounds = super::RoundsRouter::builder(); - let round1 = rounds.add_round(Store); - let mut rounds = rounds.listen(incomings); - - rounds.complete(round1).await.unwrap(); - } -} diff --git a/round-based/src/rounds_router/store.rs b/round-based/src/rounds_router/store.rs deleted file mode 100644 index a7758df..0000000 --- a/round-based/src/rounds_router/store.rs +++ /dev/null @@ -1,132 +0,0 @@ -use crate::Incoming; - -/// Stores messages received at particular round -/// -/// In MPC protocol, party at every round usually needs to receive up to `n` messages. `MessagesStore` -/// is a container that stores messages, it knows how many messages are expected to be received, -/// and should implement extra measures against malicious parties (e.g. prohibit message overwrite). -/// -/// ## Procedure -/// `MessagesStore` stores received messages. Once enough messages are received, it outputs [`MessagesStore::Output`]. -/// In order to save received messages, [`.add_message(msg)`] is called. Then, [`.wants_more()`] tells whether more -/// messages are needed to be received. If it returned `false`, then output can be retrieved by calling [`.output()`]. -/// -/// [`.add_message(msg)`]: Self::add_message -/// [`.wants_more()`]: Self::wants_more -/// [`.output()`]: Self::output -/// -/// ## Example -/// [`RoundInput`](super::simple_store::RoundInput) is an simple messages store. Refer to its docs to see usage examples. -pub trait MessagesStore: Sized + 'static { - /// Message type - type Msg; - /// Store output (e.g. `Vec<_>` of received messages) - type Output; - /// Store error - type Error: core::error::Error; - - /// Adds received message to the store - /// - /// Returns error if message cannot be processed. Usually it means that sender behaves maliciously. - fn add_message(&mut self, msg: Incoming) -> Result<(), Self::Error>; - /// Indicates if store expects more messages to receive - fn wants_more(&self) -> bool; - /// Retrieves store output if enough messages are received - /// - /// Returns `Err(self)` if more message are needed to be received. - /// - /// If store indicated that it needs no more messages (ie `store.wants_more() == false`), then - /// this function must return `Ok(_)`. - fn output(self) -> Result; -} - -/// Message of MPC protocol -/// -/// MPC protocols typically consist of several rounds, each round has differently typed message. -/// `ProtocolMessage` and [`RoundMessage`] traits are used to examine received message: `ProtocolMessage::round` -/// determines which round message belongs to, and then `RoundMessage` trait can be used to retrieve -/// actual round-specific message. -/// -/// You should derive these traits using proc macro (requires `derive` feature): -/// ```rust -/// use round_based::ProtocolMessage; -/// -/// #[derive(ProtocolMessage)] -/// pub enum Message { -/// Round1(Msg1), -/// Round2(Msg2), -/// // ... -/// } -/// -/// pub struct Msg1 { /* ... */ } -/// pub struct Msg2 { /* ... */ } -/// ``` -/// -/// This desugars into: -/// -/// ```rust -/// use round_based::rounds_router::{ProtocolMessage, RoundMessage}; -/// -/// pub enum Message { -/// Round1(Msg1), -/// Round2(Msg2), -/// // ... -/// } -/// -/// pub struct Msg1 { /* ... */ } -/// pub struct Msg2 { /* ... */ } -/// -/// impl ProtocolMessage for Message { -/// fn round(&self) -> u16 { -/// match self { -/// Message::Round1(_) => 1, -/// Message::Round2(_) => 2, -/// // ... -/// } -/// } -/// } -/// impl RoundMessage for Message { -/// const ROUND: u16 = 1; -/// fn to_protocol_message(round_message: Msg1) -> Self { -/// Message::Round1(round_message) -/// } -/// fn from_protocol_message(protocol_message: Self) -> Result { -/// match protocol_message { -/// Message::Round1(msg) => Ok(msg), -/// msg => Err(msg), -/// } -/// } -/// } -/// impl RoundMessage for Message { -/// const ROUND: u16 = 2; -/// fn to_protocol_message(round_message: Msg2) -> Self { -/// Message::Round2(round_message) -/// } -/// fn from_protocol_message(protocol_message: Self) -> Result { -/// match protocol_message { -/// Message::Round2(msg) => Ok(msg), -/// msg => Err(msg), -/// } -/// } -/// } -/// ``` -pub trait ProtocolMessage: Sized { - /// Number of round this message originates from - fn round(&self) -> u16; -} - -/// Round message -/// -/// See [`ProtocolMessage`] trait documentation. -pub trait RoundMessage: ProtocolMessage { - /// Number of the round this message belongs to - const ROUND: u16; - - /// Converts round message into protocol message (never fails) - fn to_protocol_message(round_message: M) -> Self; - /// Extracts round message from protocol message - /// - /// Returns `Err(protocol_message)` if `protocol_message.round() != Self::ROUND`, otherwise - /// returns `Ok(round_message)` - fn from_protocol_message(protocol_message: Self) -> Result; -} diff --git a/round-based/src/sim/async_env.rs b/round-based/src/sim/async_env.rs index 8bb43df..727d093 100644 --- a/round-based/src/sim/async_env.rs +++ b/round-based/src/sim/async_env.rs @@ -33,7 +33,8 @@ //! //! # type Result = std::result::Result; //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol you want to test //! pub async fn protocol_of_random_generation( //! party: M, @@ -41,7 +42,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -75,7 +76,10 @@ use futures_util::{Sink, Stream}; use tokio::sync::broadcast; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use crate::delivery::{Delivery, Incoming, Outgoing}; +use crate::{ + delivery::{Incoming, Outgoing}, + ProtocolMsg, +}; use crate::{MessageDestination, MessageType, MpcParty, MsgId, PartyIndex}; use super::SimResult; @@ -91,7 +95,7 @@ pub struct Network { impl Network where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, { /// Instantiates a new simulation pub fn new() -> Self { @@ -126,23 +130,23 @@ where let local_party_idx = self.next_party_idx; self.next_party_idx += 1; - MockedDelivery { - incoming: MockedIncoming { + MockedDelivery::new( + MockedIncoming { local_party_idx, receiver: BroadcastStream::new(self.channel.subscribe()), }, - outgoing: MockedOutgoing { + MockedOutgoing { local_party_idx, sender: self.channel.clone(), next_msg_id: self.next_msg_id.clone(), }, - } + ) } } impl Default for Network where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, { fn default() -> Self { Self::new() @@ -150,23 +154,17 @@ where } /// Mocked networking -pub struct MockedDelivery { - incoming: MockedIncoming, - outgoing: MockedOutgoing, -} +pub type MockedDelivery = crate::mpc::Halves, MockedOutgoing>; -impl Delivery for MockedDelivery -where - M: Clone + Send + Unpin + 'static, -{ - type Send = MockedOutgoing; - type Receive = MockedIncoming; - type SendError = broadcast::error::SendError<()>; - type ReceiveError = BroadcastStreamRecvError; - - fn split(self) -> (Self::Receive, Self::Send) { - (self.incoming, self.outgoing) - } +/// Delivery error +#[derive(Debug, thiserror::Error)] +pub enum MockedDeliveryError { + /// Error occurred when sending a message + #[error(transparent)] + Recv(BroadcastStreamRecvError), + /// Error occurred when receiving a message + #[error(transparent)] + Send(broadcast::error::SendError<()>), } /// Incoming channel of mocked network @@ -179,13 +177,13 @@ impl Stream for MockedIncoming where M: Clone + Send + 'static, { - type Item = Result, BroadcastStreamRecvError>; + type Item = Result, MockedDeliveryError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { let msg = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) { Some(Ok(m)) => m, - Some(Err(e)) => return Poll::Ready(Some(Err(e))), + Some(Err(e)) => return Poll::Ready(Some(Err(MockedDeliveryError::Recv(e)))), None => return Poll::Ready(None), }; if msg.recipient.is_p2p() @@ -206,7 +204,7 @@ pub struct MockedOutgoing { } impl Sink> for MockedOutgoing { - type Error = broadcast::error::SendError<()>; + type Error = MockedDeliveryError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -214,7 +212,7 @@ impl Sink> for MockedOutgoing { fn start_send(self: Pin<&mut Self>, msg: Outgoing) -> Result<(), Self::Error> { let msg_type = match msg.recipient { - MessageDestination::AllParties => MessageType::Broadcast, + MessageDestination::AllParties { reliable } => MessageType::Broadcast { reliable }, MessageDestination::OneParty(_) => MessageType::P2P, }; self.sender @@ -224,7 +222,7 @@ impl Sink> for MockedOutgoing { msg_type, msg: m, })) - .map_err(|_| broadcast::error::SendError(()))?; + .map_err(|_| MockedDeliveryError::Send(broadcast::error::SendError(())))?; Ok(()) } @@ -260,7 +258,8 @@ impl NextMessageId { /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// party: M, @@ -268,7 +267,7 @@ impl NextMessageId { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -294,7 +293,7 @@ pub async fn run( party_start: impl FnMut(u16, MpcParty>) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity(DEFAULT_CAPACITY, n, party_start).await @@ -311,12 +310,12 @@ pub async fn run_with_capacity( mut party_start: impl FnMut(u16, MpcParty>) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity_and_setup( capacity, - core::iter::repeat(()).take(n.into()), + core::iter::repeat_n((), n.into()), |i, party, ()| party_start(i, party), ) .await @@ -337,7 +336,8 @@ where /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// rng: impl rand::RngCore, @@ -346,7 +346,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -372,7 +372,7 @@ pub async fn run_with_setup( party_start: impl FnMut(u16, MpcParty>, S) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { run_with_capacity_and_setup::(DEFAULT_CAPACITY, setups, party_start).await @@ -389,7 +389,7 @@ pub async fn run_with_capacity_and_setup( mut party_start: impl FnMut(u16, MpcParty>, S) -> F, ) -> SimResult where - M: Clone + Send + Unpin + 'static, + M: ProtocolMsg + Clone + Send + Unpin + 'static, F: Future, { let mut network = Network::::with_capacity(capacity); diff --git a/round-based/src/sim/mod.rs b/round-based/src/sim/mod.rs index 56aa794..e1974f5 100644 --- a/round-based/src/sim/mod.rs +++ b/round-based/src/sim/mod.rs @@ -35,7 +35,8 @@ //! //! # type Result = std::result::Result; //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol you want to test //! pub async fn protocol_of_random_generation( //! party: M, @@ -43,7 +44,7 @@ //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: Mpc //! { //! // ... //! # todo!() @@ -67,7 +68,9 @@ use alloc::{boxed::Box, collections::VecDeque, string::ToString, vec::Vec}; use core::future::Future; -use crate::{state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing}; +use crate::{ + state_machine::ProceedResult, Incoming, MessageDestination, MessageType, Outgoing, ProtocolMsg, +}; #[cfg(feature = "sim-async")] pub mod async_env; @@ -229,7 +232,7 @@ enum Party<'a, O, M> { impl<'a, O, M> Simulation<'a, O, M> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, { /// Creates empty simulation containing no parties /// @@ -415,7 +418,7 @@ impl MessagesQueue { fn send_message(&mut self, sender: u16, msg: Outgoing) -> Result<(), SimError> { match msg.recipient { - MessageDestination::AllParties => { + MessageDestination::AllParties { reliable } => { let mut msg_ids = self.next_id..; for (destination, msg_id) in (0..) .zip(&mut self.queue) @@ -426,7 +429,7 @@ impl MessagesQueue { destination.push_back(Incoming { id: msg_id, sender, - msg_type: MessageType::Broadcast, + msg_type: MessageType::Broadcast { reliable }, msg: msg.msg.clone(), }) } @@ -471,7 +474,8 @@ impl MessagesQueue { /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// party: M, @@ -479,7 +483,7 @@ impl MessagesQueue { /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -504,10 +508,10 @@ pub fn run( mut party_start: impl FnMut(u16, crate::state_machine::MpcParty) -> F, ) -> Result, SimError> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, F: Future, { - run_with_setup(core::iter::repeat(()).take(n.into()), |i, party, ()| { + run_with_setup(core::iter::repeat_n((), n.into()), |i, party, ()| { party_start(i, party) }) } @@ -525,7 +529,8 @@ where /// /// # type Result = std::result::Result; /// # type Randomness = [u8; 32]; -/// # type Msg = (); +/// # #[derive(round_based::ProtocolMsg, Clone)] +/// # enum Msg {} /// // Any MPC protocol you want to test /// pub async fn protocol_of_random_generation( /// rng: impl rand::RngCore, @@ -534,7 +539,7 @@ where /// n: u16 /// ) -> Result /// where -/// M: Mpc +/// M: Mpc /// { /// // ... /// # todo!() @@ -559,7 +564,7 @@ pub fn run_with_setup( mut party_start: impl FnMut(u16, crate::state_machine::MpcParty, S) -> F, ) -> Result, SimError> where - M: Clone + 'static, + M: ProtocolMsg + Clone + 'static, F: Future, { let mut sim = Simulation::empty(); diff --git a/round-based/src/state_machine/delivery.rs b/round-based/src/state_machine/delivery.rs index 34b5936..9a710ff 100644 --- a/round-based/src/state_machine/delivery.rs +++ b/round-based/src/state_machine/delivery.rs @@ -1,18 +1,18 @@ use core::task::{ready, Poll}; -/// Stream of incoming messages -pub struct Incomings { +/// Provides a stream of incoming and sink for outgoing messages +pub struct Delivery { shared_state: super::shared_state::SharedStateRef, } -impl Incomings { +impl Delivery { pub(super) fn new(shared_state: super::shared_state::SharedStateRef) -> Self { Self { shared_state } } } -impl crate::Stream for Incomings { - type Item = Result, core::convert::Infallible>; +impl futures_util::Stream for Delivery { + type Item = Result, DeliveryErr>; fn poll_next( self: core::pin::Pin<&mut Self>, @@ -26,19 +26,8 @@ impl crate::Stream for Incomings { } } -/// Sink for outgoing messages -pub struct Outgoings { - shared_state: super::shared_state::SharedStateRef, -} - -impl Outgoings { - pub(super) fn new(shared_state: super::shared_state::SharedStateRef) -> Self { - Self { shared_state } - } -} - -impl crate::Sink> for Outgoings { - type Error = SendErr; +impl futures_util::Sink> for Delivery { + type Error = DeliveryErr; fn poll_ready( self: core::pin::Pin<&mut Self>, @@ -54,7 +43,7 @@ impl crate::Sink> for Outgoings { ) -> Result<(), Self::Error> { self.shared_state .protocol_saves_msg_to_be_sent(msg) - .map_err(|_| SendErr(SendErrReason::NotReady)) + .map_err(|_| DeliveryErr(Reason::NotReady)) } fn poll_flush( @@ -73,13 +62,13 @@ impl crate::Sink> for Outgoings { } } -/// Error returned by [`Outgoings`] sink +/// Error returned by [`Delivery`] #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct SendErr(SendErrReason); +pub struct DeliveryErr(Reason); #[derive(Debug, thiserror::Error)] -enum SendErrReason { +enum Reason { #[error("sink is not ready")] NotReady, } diff --git a/round-based/src/state_machine/mod.rs b/round-based/src/state_machine/mod.rs index d674122..dded810 100644 --- a/round-based/src/state_machine/mod.rs +++ b/round-based/src/state_machine/mod.rs @@ -8,19 +8,19 @@ //! ## Example //! ```rust,no_run //! # fn main() -> anyhow::Result<()> { -//! use round_based::{Mpc, PartyIndex}; //! use anyhow::{Result, Error, Context as _}; //! //! # type Randomness = [u8; 32]; -//! # type Msg = (); +//! # #[derive(round_based::ProtocolMsg, Clone)] +//! # enum Msg {} //! // Any MPC protocol //! pub async fn protocol_of_random_generation( //! party: M, -//! i: PartyIndex, +//! i: u16, //! n: u16 //! ) -> Result //! where -//! M: Mpc +//! M: round_based::Mpc //! { //! // ... //! # todo!() @@ -67,8 +67,10 @@ mod shared_state; use core::{future::Future, task::Poll}; +use crate::ProtocolMsg; + pub use self::{ - delivery::{Incomings, Outgoings, SendErr}, + delivery::{Delivery, DeliveryErr}, runtime::{Runtime, YieldNow}, }; @@ -226,9 +228,6 @@ where } } -/// Delivery implementation used in the state machine -pub type Delivery = (Incomings, Outgoings); - /// MpcParty instantiated with state machine implementation of delivery and async runtime pub type MpcParty = crate::MpcParty, Runtime>; @@ -245,15 +244,13 @@ pub fn wrap_protocol<'a, M, F>( ) -> impl StateMachine + 'a where F: Future + 'a, - M: 'static, + M: ProtocolMsg + 'static, { let shared_state = shared_state::SharedStateRef::new(); - let incomings = Incomings::new(shared_state.clone()); - let outgoings = Outgoings::new(shared_state.clone()); - let delivery = (incomings, outgoings); + let delivery = Delivery::new(shared_state.clone()); let runtime = Runtime::new(shared_state.clone()); - let future = protocol(crate::MpcParty::connected(delivery).set_runtime(runtime)); + let future = protocol(crate::mpc::connected(delivery).with_runtime(runtime)); let future = alloc::boxed::Box::pin(future); StateMachineImpl { diff --git a/round-based/src/state_machine/runtime.rs b/round-based/src/state_machine/runtime.rs index 005796c..6d64577 100644 --- a/round-based/src/state_machine/runtime.rs +++ b/round-based/src/state_machine/runtime.rs @@ -11,14 +11,13 @@ impl Runtime { } } -impl crate::runtime::AsyncRuntime for Runtime { - type YieldNowFuture = YieldNow; - - fn yield_now(&self) -> Self::YieldNowFuture { +impl crate::mpc::party::AsyncRuntime for Runtime { + async fn yield_now(&self) { YieldNow { shared_state: self.shared_state.clone(), yielded: false, } + .await } } diff --git a/round-based/src/state_machine/shared_state.rs b/round-based/src/state_machine/shared_state.rs index cd8ddf8..86bc622 100644 --- a/round-based/src/state_machine/shared_state.rs +++ b/round-based/src/state_machine/shared_state.rs @@ -29,7 +29,7 @@ impl SharedStateRef { ))) } - /// Any protocol-initated work (like flushing message to be sent, receiving message, etc.) can + /// Any protocol-initiated work (like flushing message to be sent, receiving message, etc.) can /// only be scheduled when there was no other task scheduled. /// /// This method checks whether a task can be scheduled, and returns [`CanSchedule`] which @@ -185,7 +185,7 @@ mod test { let executor_state = shared_state; let msg = Outgoing { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: false }, msg: 1, }; outgoings_state @@ -235,7 +235,7 @@ mod test { let incoming_msg = Incoming { id: 0, sender: 1, - msg_type: crate::MessageType::Broadcast, + msg_type: crate::MessageType::Broadcast { reliable: false }, msg: "hello", }; executor_state.executor_received_msg(incoming_msg).unwrap(); @@ -293,7 +293,7 @@ mod test { let shared_state = SharedStateRef::new(); shared_state .protocol_saves_msg_to_be_sent(Outgoing { - recipient: MessageDestination::AllParties, + recipient: MessageDestination::AllParties { reliable: false }, msg: 1, }) .expect("msg slot isn't empty"); diff --git a/round-based/tests/derive/compile-fail/wrong_usage.rs b/round-based/tests/derive/compile-fail/wrong_usage.rs index d081241..fb35bcc 100644 --- a/round-based/tests/derive/compile-fail/wrong_usage.rs +++ b/round-based/tests/derive/compile-fail/wrong_usage.rs @@ -1,9 +1,9 @@ -use round_based::ProtocolMessage; +use round_based::ProtocolMsg; -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] enum Msg { // Unnamed variant with single field is the only correct enum variant - // that doesn't contradicts with ProtocolMessage derivation + // that doesn't contradicts with ProtocolMsg derivation VariantA(u16), // Error: You can't have named variants VariantB { n: u32 }, @@ -15,38 +15,38 @@ enum Msg { VariantE, } -// Structure cannot implement ProtocolMessage -#[derive(ProtocolMessage)] +// Structure cannot implement ProtocolMsg +#[derive(ProtocolMsg)] struct Msg2 { some_field: u64, } -// Union cannot implement ProtocolMessage -#[derive(ProtocolMessage)] +// Union cannot implement ProtocolMsg +#[derive(ProtocolMsg)] union Msg3 { variant: u64, } // protocol_message is repeated twice -#[derive(ProtocolMessage)] -#[protocol_message(root = one)] -#[protocol_message(root = two)] +#[derive(ProtocolMsg)] +#[protocol_msg(root = one)] +#[protocol_msg(root = two)] enum Msg4 { One(u32), Two(u16), } // ", blah blah" is not permitted input -#[derive(ProtocolMessage)] -#[protocol_message(root = one, blah blah)] +#[derive(ProtocolMsg)] +#[protocol_msg(root = one, blah blah)] enum Msg5 { One(u32), Two(u16), } -// `protocol_message` must not be empty -#[derive(ProtocolMessage)] -#[protocol_message()] +// `protocol_msh` must not be empty +#[derive(ProtocolMsg)] +#[protocol_msg()] enum Msg6 { One(u32), Two(u16), diff --git a/round-based/tests/derive/compile-fail/wrong_usage.stderr b/round-based/tests/derive/compile-fail/wrong_usage.stderr index f3d8341..335630c 100644 --- a/round-based/tests/derive/compile-fail/wrong_usage.stderr +++ b/round-based/tests/derive/compile-fail/wrong_usage.stderr @@ -1,53 +1,53 @@ -error: named variants are not allowed in ProtocolMessage +error: named variants are not allowed in ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:9:5 | 9 | VariantB { n: u32 }, | ^^^^^^^^ -error: this variant must contain exactly one field to be valid ProtocolMessage +error: this variant must contain exactly one field to be valid ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:11:5 | 11 | VariantC(u32, String), | ^^^^^^^^ -error: this variant must contain exactly one field to be valid ProtocolMessage +error: this variant must contain exactly one field to be valid ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:13:5 | 13 | VariantD(), | ^^^^^^^^ -error: unit variants are not allowed in ProtocolMessage +error: unit variants are not allowed in ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:15:5 | 15 | VariantE, | ^^^^^^^^ -error: only enum may implement ProtocolMessage +error: only enum may implement ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:20:1 | 20 | struct Msg2 { | ^^^^^^ -error: only enum may implement ProtocolMessage +error: only enum may implement ProtocolMsg --> tests/derive/compile-fail/wrong_usage.rs:26:1 | 26 | union Msg3 { | ^^^^^ -error: #[protocol_message] attribute appears more than once +error: #[protocol_msg] attribute appears more than once --> tests/derive/compile-fail/wrong_usage.rs:33:3 | -33 | #[protocol_message(root = two)] - | ^^^^^^^^^^^^^^^^ +33 | #[protocol_msg(root = two)] + | ^^^^^^^^^^^^ error: unexpected token - --> tests/derive/compile-fail/wrong_usage.rs:41:30 + --> tests/derive/compile-fail/wrong_usage.rs:41:26 | -41 | #[protocol_message(root = one, blah blah)] - | ^ +41 | #[protocol_msg(root = one, blah blah)] + | ^ error: unexpected end of input, expected `root` - --> tests/derive/compile-fail/wrong_usage.rs:49:20 + --> tests/derive/compile-fail/wrong_usage.rs:49:16 | -49 | #[protocol_message()] - | ^ +49 | #[protocol_msg()] + | ^ diff --git a/round-based/tests/derive/compile-pass/correct_usage.rs b/round-based/tests/derive/compile-pass/correct_usage.rs index e4be823..21ee289 100644 --- a/round-based/tests/derive/compile-pass/correct_usage.rs +++ b/round-based/tests/derive/compile-pass/correct_usage.rs @@ -1,14 +1,14 @@ -use round_based::ProtocolMessage; +use round_based::ProtocolMsg; -#[derive(ProtocolMessage)] +#[derive(ProtocolMsg)] enum Msg { VariantA(u16), VariantB(String), VariantC((u16, String)), VariantD(MyStruct), } -#[derive(ProtocolMessage)] -#[protocol_message(root = round_based)] +#[derive(ProtocolMsg)] +#[protocol_msg(root = round_based)] enum Msg2 { VariantA(u16), VariantB(String),