diff --git a/simtfl/bc/__init__.py b/simtfl/bc/__init__.py index a9e4180..fb361ba 100644 --- a/simtfl/bc/__init__.py +++ b/simtfl/bc/__init__.py @@ -14,301 +14,3 @@ The simulation of the shielded protocol does not attempt to model any actual privacy properties. """ - - -from __future__ import annotations -from typing import Iterable, Optional -from collections.abc import Sequence -from dataclasses import dataclass -from enum import Enum, auto - -from collections import deque -from itertools import chain, islice -from sys import version_info - -from ..util import Unique - - -class BlockHash(Unique): - """Unique value representing a best-chain block hash.""" - pass - - -class BCTransaction: - """A transaction for a best-chain protocol.""" - - @dataclass(frozen=True) - class _TXO: - tx: BCTransaction - index: int - value: int - - @dataclass(eq=False) - class _Note(Unique): - """ - A shielded note. Unlike in the actual protocol, we conflate notes, note - commitments, and nullifiers. This will be sufficient because we don't - need to maintain any actual privacy. - - This is not a frozen dataclass; its identity is important, and models the - fact that each note has a unique commitment and nullifier in the actual - protocol. - """ - value: int - - def __init__(self, - transparent_inputs: Sequence[BCTransaction._TXO], - transparent_output_values: Sequence[int], - shielded_inputs: Sequence[BCTransaction._Note], - shielded_output_values: Sequence[int], - fee: int, - anchor: Optional[BCContext]=None, - issuance: int=0): - """ - Constructs a `BCTransaction` with the given transparent inputs, transparent - output values, anchor, shielded inputs, shielded output values, fee, and - (if it is a coinbase transaction) issuance. - - The elements of `transparent_inputs` are TXO objects obtained from the - `transparent_output` method of another `BCTransaction`. The elements of - `shielded_inputs` are Note objects obtained from the `shielded_output` - method of another `BCTransaction`. The TXO and Note classes are private, - and these objects should not be constructed directly. - - The anchor is modelled as a `BCContext` such that - `anchor.can_spend(shielded_inputs)`. If there are no shielded inputs, - `anchor` must be `None`. The anchor object must not be modified after - passing it to this constructor (copy it if necessary). - - For a coinbase transaction, pass `[]` for `transparent_inputs` and - `shielded_inputs`, and pass `fee` as a negative value of magnitude equal - to the total amount of fees paid by other transactions in the block. - """ - assert issuance >= 0 - coinbase = len(transparent_inputs) + len(shielded_inputs) == 0 - assert fee >= 0 or coinbase - assert issuance == 0 or coinbase - assert all((v >= 0 for v in chain(transparent_output_values, shielded_output_values))) - assert ( - sum((txin.value for txin in transparent_inputs)) - + sum((note.value for note in shielded_inputs)) - + issuance == - sum(transparent_output_values) - + sum(shielded_output_values) - + fee - ) - assert anchor is None if len(shielded_inputs) == 0 else ( - anchor is not None and anchor.can_spend(shielded_inputs)) - - self.transparent_inputs = transparent_inputs - self.transparent_outputs = [self._TXO(self, i, v) - for (i, v) in enumerate(transparent_output_values)] - self.shielded_inputs = shielded_inputs - self.shielded_outputs = [self._Note(v) for v in shielded_output_values] - self.fee = fee - self.anchor = anchor - self.issuance = issuance - - def transparent_input(self, index: int) -> BCTransaction._TXO: - """Returns the transparent input TXO with the given index.""" - return self.transparent_inputs[index] - - def transparent_output(self, index: int) -> BCTransaction._TXO: - """Returns the transparent output TXO with the given index.""" - return self.transparent_outputs[index] - - def shielded_input(self, index: int) -> BCTransaction._Note: - """Returns the shielded input note with the given index.""" - return self.shielded_inputs[index] - - def shielded_output(self, index: int) -> BCTransaction._Note: - """Returns the shielded output note with the given index.""" - return self.shielded_outputs[index] - - def is_coinbase(self) -> bool: - """ - Returns `True` if this is a coinbase transaction (it has no inputs). - """ - return len(self.transparent_inputs) + len(self.shielded_inputs) == 0 - - -class Spentness(Enum): - """The spentness status of a note.""" - Unspent = auto() - """The note is unspent.""" - Spent = auto() - """The note is spent.""" - - -class BCContext: - """ - A context that allows checking transactions for contextual validity in a - best-chain protocol. - """ - - assert version_info >= (3, 7), "This code relies on insertion-ordered dicts." - - def __init__(self): - """Constructs an empty `BCContext`.""" - self.transactions: deque[BCTransaction] = deque() - self.utxo_set: set[BCTransaction._TXO] = set() - - # Since dicts are insertion-ordered, this models the sequence in which - # notes are committed as well as their spentness. - self.notes: dict[BCTransaction._Note, Spentness] = {} - - self.total_issuance = 0 - - def committed_notes(self) -> list[(BCTransaction._Note, Spentness)]: - """ - Returns a list of (`Note`, `Spentness`) for notes added to this context, - preserving the commitment order. - """ - return list(self.notes.items()) - - def can_spend(self, tospend: Iterable[BCTransaction._Note]) -> bool: - """Can all of the notes in `tospend` be spent in this context?""" - return all((self.notes.get(note) == Spentness.Unspent for note in tospend)) - - def _check(self, tx: BCTransaction) -> tuple[bool, set[BCTransaction._TXO]]: - """ - Checks whether `tx` is valid. To avoid recomputation, this returns - a pair of the validity, and the set of transparent inputs of `tx`. - """ - txins = set(tx.transparent_inputs) - valid = txins.issubset(self.utxo_set) and self.can_spend(tx.shielded_inputs) - return (valid, txins) - - def is_valid(self, tx: BCTransaction) -> bool: - """Is `tx` valid in this context?""" - return self._check(tx)[0] - - def add_if_valid(self, tx: BCTransaction) -> bool: - """ - If `tx` is valid in this context, add it to the context and return `True`. - Otherwise leave the context unchanged and return `False`. - """ - (valid, txins) = self._check(tx) - if valid: - self.utxo_set -= txins - self.utxo_set |= set(tx.transparent_outputs) - - for note in tx.shielded_inputs: - self.notes[note] = Spentness.Spent - for note in tx.shielded_outputs: - assert note not in self.notes - self.notes[note] = Spentness.Unspent - - self.total_issuance += tx.issuance - self.transactions.append(tx) - - return valid - - def copy(self) -> BCContext: - """Returns an independent copy of this `BCContext`.""" - ctx = BCContext() - ctx.transactions = self.transactions.copy() - ctx.utxo_set = self.utxo_set.copy() - ctx.notes = self.notes.copy() - ctx.total_issuance = self.total_issuance - return ctx - - -class BCBlock: - """A block in a best-chain protocol.""" - - def __init__(self, - parent: Optional[BCBlock], - added_score: int, - transactions: Sequence[BCTransaction], - allow_invalid: bool=False): - """ - Constructs a `BCBlock` with the given parent block, score relative to the - parent, and sequence of transactions. `transactions` must not be modified - after passing it to this constructor (copy it if necessary). - If `allow_invalid` is set, the block need not be valid. - Use `parent=None` to construct the genesis block. - """ - self.parent = parent - self.score = added_score - if self.parent is not None: - self.score += self.parent.score - self.transactions = transactions - self.hash = BlockHash() - if not allow_invalid: - self.assert_noncontextually_valid() - - def assert_noncontextually_valid(self) -> None: - """Assert that non-contextual consensus rules are satisfied for this block.""" - assert len(self.transactions) > 0 - assert self.transactions[0].is_coinbase() - assert not any((tx.is_coinbase() for tx in islice(self.transactions, 1, None))) - assert sum((tx.fee for tx in self.transactions)) == 0 - - def is_noncontextually_valid(self) -> bool: - """Are non-contextual consensus rules satisfied for this block?""" - try: - self.assert_noncontextually_valid() - return True - except AssertionError: - return False - - -@dataclass -class BCProtocol: - """A best-chain protocol.""" - - Transaction: type[object] = BCTransaction - """The type of transactions for this protocol.""" - - Context: type[object] = BCContext - """The type of contexts for this protocol.""" - - Block: type[object] = BCBlock - """The type of blocks for this protocol.""" - - -__all__ = ['BCTransaction', 'BCContext', 'BCBlock', 'BCProtocol', 'BlockHash', 'Spentness'] - - -import unittest - - -class TestBC(unittest.TestCase): - def test_basic(self) -> None: - ctx = BCContext() - coinbase_tx0 = BCTransaction([], [10], [], [], 0, issuance=10) - self.assertTrue(ctx.add_if_valid(coinbase_tx0)) - genesis = BCBlock(None, 1, [coinbase_tx0]) - self.assertEqual(genesis.score, 1) - self.assertEqual(ctx.total_issuance, 10) - - coinbase_tx1 = BCTransaction([], [6], [], [], -1, issuance=5) - spend_tx = BCTransaction([coinbase_tx0.transparent_output(0)], [9], [], [], 1) - self.assertTrue(ctx.add_if_valid(coinbase_tx1)) - self.assertTrue(ctx.add_if_valid(spend_tx)) - block1 = BCBlock(genesis, 1, [coinbase_tx1, spend_tx]) - self.assertEqual(block1.score, 2) - self.assertEqual(ctx.total_issuance, 15) - - coinbase_tx2 = BCTransaction([], [6], [], [], -1, issuance=5) - shielding_tx = BCTransaction([coinbase_tx1.transparent_output(0), spend_tx.transparent_output(0)], - [], [], [8, 6], 1) - self.assertTrue(ctx.add_if_valid(coinbase_tx2)) - self.assertTrue(ctx.add_if_valid(shielding_tx)) - block2 = BCBlock(block1, 2, [coinbase_tx2, shielding_tx]) - block2_anchor = ctx.copy() - self.assertEqual(block2.score, 4) - self.assertEqual(ctx.total_issuance, 20) - - coinbase_tx3 = BCTransaction([], [7], [], [], -2, issuance=5) - shielded_tx = BCTransaction([], [], [shielding_tx.shielded_output(0)], [7], 1, - anchor=block2_anchor) - deshielding_tx = BCTransaction([], [5], [shielding_tx.shielded_output(1)], [], 1, - anchor=block2_anchor) - self.assertTrue(ctx.add_if_valid(coinbase_tx3)) - self.assertTrue(ctx.add_if_valid(shielded_tx)) - self.assertTrue(ctx.add_if_valid(deshielding_tx)) - block3 = BCBlock(block2, 3, [coinbase_tx3, shielded_tx, deshielding_tx]) - self.assertEqual(block3.score, 7) - self.assertEqual(ctx.total_issuance, 25) diff --git a/simtfl/bc/chain.py b/simtfl/bc/chain.py new file mode 100644 index 0000000..0aca192 --- /dev/null +++ b/simtfl/bc/chain.py @@ -0,0 +1,301 @@ +""" +Abstractions for best-chain transactions, contexts, and blocks. +""" + + +from __future__ import annotations +from typing import Iterable, Optional, TypeAlias +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum, auto + +from collections import deque +from itertools import chain, islice +from sys import version_info + +from ..util import Unique + + +class BlockHash(Unique): + """Unique value representing a best-chain block hash.""" + pass + + +class BCTransaction: + """A transaction for a best-chain protocol.""" + + @dataclass(frozen=True) + class _TXO: + tx: BCTransaction + index: int + value: int + + @dataclass(eq=False) + class _Note(Unique): + """ + A shielded note. Unlike in the actual protocol, we conflate notes, note + commitments, and nullifiers. This will be sufficient because we don't + need to maintain any actual privacy. + + This is not a frozen dataclass; its identity is important, and models the + fact that each note has a unique commitment and nullifier in the actual + protocol. + """ + value: int + + def __init__(self, + transparent_inputs: Sequence[BCTransaction._TXO], + transparent_output_values: Sequence[int], + shielded_inputs: Sequence[BCTransaction._Note], + shielded_output_values: Sequence[int], + fee: int, + anchor: Optional[BCContext]=None, + issuance: int=0): + """ + Constructs a `BCTransaction` with the given transparent inputs, transparent + output values, anchor, shielded inputs, shielded output values, fee, and + (if it is a coinbase transaction) issuance. + + The elements of `transparent_inputs` are TXO objects obtained from the + `transparent_output` method of another `BCTransaction`. The elements of + `shielded_inputs` are Note objects obtained from the `shielded_output` + method of another `BCTransaction`. The TXO and Note classes are private, + and these objects should not be constructed directly. + + The anchor is modelled as a `BCContext` such that + `anchor.can_spend(shielded_inputs)`. If there are no shielded inputs, + `anchor` must be `None`. The anchor object must not be modified after + passing it to this constructor (copy it if necessary). + + For a coinbase transaction, pass `[]` for `transparent_inputs` and + `shielded_inputs`, and pass `fee` as a negative value of magnitude equal + to the total amount of fees paid by other transactions in the block. + """ + assert issuance >= 0 + coinbase = len(transparent_inputs) + len(shielded_inputs) == 0 + assert fee >= 0 or coinbase + assert issuance == 0 or coinbase + assert all((v >= 0 for v in chain(transparent_output_values, shielded_output_values))) + assert ( + sum((txin.value for txin in transparent_inputs)) + + sum((note.value for note in shielded_inputs)) + + issuance == + sum(transparent_output_values) + + sum(shielded_output_values) + + fee + ) + assert anchor is None if len(shielded_inputs) == 0 else ( + anchor is not None and anchor.can_spend(shielded_inputs)) + + self.transparent_inputs = transparent_inputs + self.transparent_outputs = [self._TXO(self, i, v) + for (i, v) in enumerate(transparent_output_values)] + self.shielded_inputs = shielded_inputs + self.shielded_outputs = [self._Note(v) for v in shielded_output_values] + self.fee = fee + self.anchor = anchor + self.issuance = issuance + + def transparent_input(self, index: int) -> BCTransaction._TXO: + """Returns the transparent input TXO with the given index.""" + return self.transparent_inputs[index] + + def transparent_output(self, index: int) -> BCTransaction._TXO: + """Returns the transparent output TXO with the given index.""" + return self.transparent_outputs[index] + + def shielded_input(self, index: int) -> BCTransaction._Note: + """Returns the shielded input note with the given index.""" + return self.shielded_inputs[index] + + def shielded_output(self, index: int) -> BCTransaction._Note: + """Returns the shielded output note with the given index.""" + return self.shielded_outputs[index] + + def is_coinbase(self) -> bool: + """ + Returns `True` if this is a coinbase transaction (it has no inputs). + """ + return len(self.transparent_inputs) + len(self.shielded_inputs) == 0 + + +class Spentness(Enum): + """The spentness status of a note.""" + Unspent = auto() + """The note is unspent.""" + Spent = auto() + """The note is spent.""" + + +class BCContext: + """ + A context that allows checking transactions for contextual validity in a + best-chain protocol. + """ + + assert version_info >= (3, 7), "This code relies on insertion-ordered dicts." + + def __init__(self): + """Constructs an empty `BCContext`.""" + self.transactions: deque[BCTransaction] = deque() + self.utxo_set: set[BCTransaction._TXO] = set() + + # Since dicts are insertion-ordered, this models the sequence in which + # notes are committed as well as their spentness. + self.notes: dict[BCTransaction._Note, Spentness] = {} + + self.total_issuance = 0 + + def committed_notes(self) -> list[(BCTransaction._Note, Spentness)]: + """ + Returns a list of (`Note`, `Spentness`) for notes added to this context, + preserving the commitment order. + """ + return list(self.notes.items()) + + def can_spend(self, tospend: Iterable[BCTransaction._Note]) -> bool: + """Can all of the notes in `tospend` be spent in this context?""" + return all((self.notes.get(note) == Spentness.Unspent for note in tospend)) + + def _check(self, tx: BCTransaction) -> tuple[bool, set[BCTransaction._TXO]]: + """ + Checks whether `tx` is valid. To avoid recomputation, this returns + a pair of the validity, and the set of transparent inputs of `tx`. + """ + txins = set(tx.transparent_inputs) + valid = txins.issubset(self.utxo_set) and self.can_spend(tx.shielded_inputs) + return (valid, txins) + + def is_valid(self, tx: BCTransaction) -> bool: + """Is `tx` valid in this context?""" + return self._check(tx)[0] + + def add_if_valid(self, tx: BCTransaction) -> bool: + """ + If `tx` is valid in this context, add it to the context and return `True`. + Otherwise leave the context unchanged and return `False`. + """ + (valid, txins) = self._check(tx) + if valid: + self.utxo_set -= txins + self.utxo_set |= set(tx.transparent_outputs) + + for note in tx.shielded_inputs: + self.notes[note] = Spentness.Spent + for note in tx.shielded_outputs: + assert note not in self.notes + self.notes[note] = Spentness.Unspent + + self.total_issuance += tx.issuance + self.transactions.append(tx) + + return valid + + def copy(self) -> BCContext: + """Returns an independent copy of this `BCContext`.""" + ctx = BCContext() + ctx.transactions = self.transactions.copy() + ctx.utxo_set = self.utxo_set.copy() + ctx.notes = self.notes.copy() + ctx.total_issuance = self.total_issuance + return ctx + + +class BCBlock: + """A block in a best-chain protocol.""" + + def __init__(self, + parent: Optional[BCBlock], + added_score: int, + transactions: Sequence[BCTransaction], + allow_invalid: bool=False): + """ + Constructs a `BCBlock` with the given parent block, score relative to the + parent, and sequence of transactions. `transactions` must not be modified + after passing it to this constructor (copy it if necessary). + If `allow_invalid` is set, the block need not be valid. + Use `parent=None` to construct the genesis block. + """ + self.parent = parent + self.score = added_score + if self.parent is not None: + self.score += self.parent.score + self.transactions = transactions + self.hash = BlockHash() + if not allow_invalid: + self.assert_noncontextually_valid() + + def assert_noncontextually_valid(self) -> None: + """Assert that non-contextual consensus rules are satisfied for this block.""" + assert len(self.transactions) > 0 + assert self.transactions[0].is_coinbase() + assert not any((tx.is_coinbase() for tx in islice(self.transactions, 1, None))) + assert sum((tx.fee for tx in self.transactions)) == 0 + + def is_noncontextually_valid(self) -> bool: + """Are non-contextual consensus rules satisfied for this block?""" + try: + self.assert_noncontextually_valid() + return True + except AssertionError: + return False + + +@dataclass +class BCProtocol: + """A best-chain protocol.""" + + Transaction: TypeAlias = BCTransaction + """The type of transactions for this protocol.""" + + Context: TypeAlias = BCContext + """The type of contexts for this protocol.""" + + Block: TypeAlias = BCBlock + """The type of blocks for this protocol.""" + + +__all__ = ['BCTransaction', 'BCContext', 'BCBlock', 'BCProtocol', 'BlockHash', 'Spentness'] + + +import unittest + + +class TestBC(unittest.TestCase): + def test_basic(self) -> None: + ctx = BCContext() + coinbase_tx0 = BCTransaction([], [10], [], [], 0, issuance=10) + self.assertTrue(ctx.add_if_valid(coinbase_tx0)) + genesis = BCBlock(None, 1, [coinbase_tx0]) + self.assertEqual(genesis.score, 1) + self.assertEqual(ctx.total_issuance, 10) + + coinbase_tx1 = BCTransaction([], [6], [], [], -1, issuance=5) + spend_tx = BCTransaction([coinbase_tx0.transparent_output(0)], [9], [], [], 1) + self.assertTrue(ctx.add_if_valid(coinbase_tx1)) + self.assertTrue(ctx.add_if_valid(spend_tx)) + block1 = BCBlock(genesis, 1, [coinbase_tx1, spend_tx]) + self.assertEqual(block1.score, 2) + self.assertEqual(ctx.total_issuance, 15) + + coinbase_tx2 = BCTransaction([], [6], [], [], -1, issuance=5) + shielding_tx = BCTransaction([coinbase_tx1.transparent_output(0), spend_tx.transparent_output(0)], + [], [], [8, 6], 1) + self.assertTrue(ctx.add_if_valid(coinbase_tx2)) + self.assertTrue(ctx.add_if_valid(shielding_tx)) + block2 = BCBlock(block1, 2, [coinbase_tx2, shielding_tx]) + block2_anchor = ctx.copy() + self.assertEqual(block2.score, 4) + self.assertEqual(ctx.total_issuance, 20) + + coinbase_tx3 = BCTransaction([], [7], [], [], -2, issuance=5) + shielded_tx = BCTransaction([], [], [shielding_tx.shielded_output(0)], [7], 1, + anchor=block2_anchor) + deshielding_tx = BCTransaction([], [5], [shielding_tx.shielded_output(1)], [], 1, + anchor=block2_anchor) + self.assertTrue(ctx.add_if_valid(coinbase_tx3)) + self.assertTrue(ctx.add_if_valid(shielded_tx)) + self.assertTrue(ctx.add_if_valid(deshielding_tx)) + block3 = BCBlock(block2, 3, [coinbase_tx3, shielded_tx, deshielding_tx]) + self.assertEqual(block3.score, 7) + self.assertEqual(ctx.total_issuance, 25) diff --git a/simtfl/bft/__init__.py b/simtfl/bft/__init__.py index da82aea..8120bb5 100644 --- a/simtfl/bft/__init__.py +++ b/simtfl/bft/__init__.py @@ -7,168 +7,6 @@ protocols — but that's okay; it's a prototype. [CS2020] https://eprint.iacr.org/2020/088.pdf + [Crosslink] https://hackmd.io/JqENg--qSmyqRt_RqY7Whw?view """ - - -from __future__ import annotations - - -def two_thirds_threshold(n: int) -> int: - """ - Calculate the notarization threshold used in most permissioned BFT protocols: - `ceiling(n * 2/3)`. - """ - return (n * 2 + 2) // 3 - - -class PermissionedBFTBase: - """ - This class is used for the genesis block in a permissioned BFT protocol - (which is taken to be notarized, and therefore valid, by definition). - - It is also used as a base class for other BFT block and proposal classes. - """ - def __init__(self, n: int, t: int): - """ - Constructs a genesis block for a permissioned BFT protocol with - `n` nodes, of which at least `t` must sign each proposal. - """ - self.n = n - self.t = t - self.parent = None - - def last_final(self) -> PermissionedBFTBase: - """ - Returns the last final block in this block's ancestor chain. - For the genesis block, this is itself. - """ - return self - - -class PermissionedBFTBlock(PermissionedBFTBase): - """ - A block for a BFT protocol. Each non-genesis block is based on a - notarized proposal, and in practice consists of the proposer's signature - over the notarized proposal. - - Honest proposers must only ever sign at most one valid proposal for the - given epoch in which they are a proposer. - - BFT blocks are taken to be notarized, and therefore valid, by definition. - """ - - def __init__(self, proposal: PermissionedBFTProposal): - """Constructs a `PermissionedBFTBlock` for the given proposal.""" - super().__init__(proposal.n, proposal.t) - - proposal.assert_notarized() - self.proposal = proposal - self.parent = proposal.parent - - def last_final(self): - """ - Returns the last final block in this block's ancestor chain. - This should be overridden by subclasses; the default implementation - will (inefficiently) just return the genesis block. - """ - return self if self.parent is None else self.parent.last_final() - - -class PermissionedBFTProposal(PermissionedBFTBase): - """A proposal for a BFT protocol.""" - - def __init__(self, parent: PermissionedBFTBase): - """ - Constructs a `PermissionedBFTProposal` with the given parent - `PermissionedBFTBlock`. The parameters are determined by the parent - block. - """ - super().__init__(parent.n, parent.t) - self.parent = parent - self.signers = set() - - def assert_valid(self) -> None: - """ - Assert that this proposal is valid. This does not assert that it is - notarized. This should be overridden by subclasses. - """ - pass - - def is_valid(self) -> bool: - """Is this proposal valid?""" - try: - self.assert_valid() - return True - except AssertionError: - return False - - def assert_notarized(self) -> None: - """ - Assert that this proposal is notarized. A `PermissionedBFTProposal` - is notarized iff it is valid and has at least the threshold number of - signatures. - """ - self.assert_valid() - assert len(self.signers) >= self.t - - def is_notarized(self) -> bool: - """Is this proposal notarized?""" - try: - self.assert_notarized() - return True - except AssertionError: - return False - - def add_signature(self, index: int) -> None: - """ - Record that the node with the given `index` has signed this proposal. - If the same node signs more than once, the subsequent signatures are - ignored. - """ - self.signers.add(index) - assert len(self.signers) <= self.n - - -__all__ = ['two_thirds_threshold', 'PermissionedBFTBase', 'PermissionedBFTBlock', 'PermissionedBFTProposal'] - -import unittest - - -class TestPermissionedBFT(unittest.TestCase): - def test_basic(self) -> None: - # Construct the genesis block. - genesis = PermissionedBFTBase(5, 2) - current = genesis - self.assertEqual(current.last_final(), genesis) - - for _ in range(2): - proposal = PermissionedBFTProposal(current) - proposal.assert_valid() - self.assertTrue(proposal.is_valid()) - self.assertFalse(proposal.is_notarized()) - - # not enough signatures - proposal.add_signature(0) - self.assertFalse(proposal.is_notarized()) - - # same index, so we still only have one signature - proposal.add_signature(0) - self.assertFalse(proposal.is_notarized()) - - # different index, now we have two signatures as required - proposal.add_signature(1) - proposal.assert_notarized() - self.assertTrue(proposal.is_notarized()) - - current = PermissionedBFTBlock(proposal) - self.assertEqual(current.last_final(), genesis) - - def test_assertions(self) -> None: - genesis = PermissionedBFTBase(5, 2) - proposal = PermissionedBFTProposal(genesis) - self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) - proposal.add_signature(0) - self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) - proposal.add_signature(1) - _ = PermissionedBFTBlock(proposal) diff --git a/simtfl/bft/chain.py b/simtfl/bft/chain.py new file mode 100644 index 0000000..2b75f96 --- /dev/null +++ b/simtfl/bft/chain.py @@ -0,0 +1,209 @@ +""" +Abstractions for Byzantine Fault-Tolerant protocols. +""" + + +from __future__ import annotations + + +def two_thirds_threshold(n: int) -> int: + """ + Calculate the notarization threshold used in most permissioned BFT protocols: + `ceiling(n * 2/3)`. + """ + return (n * 2 + 2) // 3 + + +class PermissionedBFTBase: + """ + This class is used for the genesis block in a permissioned BFT protocol + (which is taken to be notarized, and therefore valid, by definition). + + It is also used as a base class for other BFT block and proposal classes. + """ + def __init__(self, n: int, t: int): + """ + Constructs a genesis block for a permissioned BFT protocol with + `n` nodes, of which at least `t` must sign each proposal. + """ + + self.n = n + """The number of voters.""" + + self.t = t + """The threshold of votes required for notarization.""" + + self.parent = None + """The genesis block has no parent (represented as `None`).""" + + self.length = 1 + """The genesis chain length is 1.""" + + self.last_final = self + """The last final block for the genesis block is itself.""" + + def preceq(self, other: PermissionedBFTBase): + """Return True if this block is an ancestor of `other`.""" + if self.length > other.length: + return False # optimization + return self == other or (other.parent is not None and self.preceq(other.parent)) + + def __eq__(self, other) -> bool: + return other.parent is None and (self.n, self.t) == (other.n, other.t) + + def __hash__(self) -> int: + return hash((self.n, self.t)) + + +class PermissionedBFTBlock(PermissionedBFTBase): + """ + A block for a BFT protocol. Each non-genesis block is based on a + notarized proposal, and in practice consists of the proposer's signature + over the notarized proposal. + + Honest proposers must only ever sign at most one valid proposal for the + given epoch in which they are a proposer. + + BFT blocks are taken to be notarized, and therefore valid, by definition. + """ + + def __init__(self, proposal: PermissionedBFTProposal): + """Constructs a `PermissionedBFTBlock` for the given proposal.""" + super().__init__(proposal.n, proposal.t) + + proposal.assert_notarized() + self.proposal = proposal + """The proposal for this block.""" + + assert proposal.parent is not None + self.parent = proposal.parent + """The parent of this block.""" + + self.length = proposal.length + """The chain length of this block.""" + + self.last_final = self.parent.last_final + """The last final block for this block.""" + + def __eq__(self, other) -> bool: + return (isinstance(other, PermissionedBFTBlock) and + (self.n, self.t, self.proposal) == (other.n, other.t, other.proposal)) + + def __hash__(self) -> int: + return hash((self.n, self.t, self.proposal)) + + +class PermissionedBFTProposal(PermissionedBFTBase): + """A proposal for a BFT protocol.""" + + def __init__(self, parent: PermissionedBFTBase): + """ + Constructs a `PermissionedBFTProposal` with the given parent + `PermissionedBFTBlock`. The parameters are determined by the parent + block. + """ + super().__init__(parent.n, parent.t) + + self.parent = parent + """The parent block of this proposal.""" + + self.length = parent.length + 1 + """The chain length of this proposal is one greater than its parent block.""" + + self.votes = set() + """The set of voter indices that have voted for this proposal.""" + + def __eq__(self, other): + """Two proposals are equal iff they are the same object.""" + return self is other + + def __hash__(self) -> int: + return id(self) + + def assert_valid(self) -> None: + """ + Assert that this proposal is valid. This does not assert that it is + notarized. This should be overridden by subclasses. + """ + pass + + def is_valid(self) -> bool: + """Is this proposal valid?""" + try: + self.assert_valid() + return True + except AssertionError: + return False + + def assert_notarized(self) -> None: + """ + Assert that this proposal is notarized. A `PermissionedBFTProposal` + is notarized iff it is valid and has at least the threshold number of + signatures. + """ + self.assert_valid() + assert len(self.votes) >= self.t + + def is_notarized(self) -> bool: + """Is this proposal notarized?""" + try: + self.assert_notarized() + return True + except AssertionError: + return False + + def add_vote(self, index: int) -> None: + """ + Record that the node with the given `index` has voted for this proposal. + Calls that add the same vote more than once are ignored. + """ + self.votes.add(index) + assert len(self.votes) <= self.n + + +__all__ = ['two_thirds_threshold', 'PermissionedBFTBase', 'PermissionedBFTBlock', 'PermissionedBFTProposal'] + +import unittest + + +class TestPermissionedBFT(unittest.TestCase): + def test_basic(self) -> None: + # Construct the genesis block. + genesis = PermissionedBFTBase(5, 2) + current = genesis + self.assertEqual(current.last_final, genesis) + + for _ in range(2): + parent = current + proposal = PermissionedBFTProposal(parent) + proposal.assert_valid() + self.assertTrue(proposal.is_valid()) + self.assertFalse(proposal.is_notarized()) + + # not enough votes + proposal.add_vote(0) + self.assertFalse(proposal.is_notarized()) + + # same index, so we still only have one vote + proposal.add_vote(0) + self.assertFalse(proposal.is_notarized()) + + # different index, now we have two votes as required + proposal.add_vote(1) + proposal.assert_notarized() + self.assertTrue(proposal.is_notarized()) + + current = PermissionedBFTBlock(proposal) + self.assertTrue(parent.preceq(current)) + self.assertFalse(current.preceq(parent)) + self.assertNotEqual(current, parent) + self.assertEqual(current.last_final, genesis) + + def test_assertions(self) -> None: + genesis = PermissionedBFTBase(5, 2) + proposal = PermissionedBFTProposal(genesis) + self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) + proposal.add_vote(0) + self.assertRaises(AssertionError, PermissionedBFTBlock, proposal) + proposal.add_vote(1) + _ = PermissionedBFTBlock(proposal) diff --git a/simtfl/bft/streamlet/__init__.py b/simtfl/bft/streamlet/__init__.py index 135f7fe..e9d7395 100644 --- a/simtfl/bft/streamlet/__init__.py +++ b/simtfl/bft/streamlet/__init__.py @@ -2,179 +2,6 @@ An implementation of adapted-Streamlet ([CS2020] as modified in [Crosslink]). [CS2020] https://eprint.iacr.org/2020/088.pdf + [Crosslink] https://hackmd.io/JqENg--qSmyqRt_RqY7Whw?view """ - - -from __future__ import annotations -from typing import Optional -from collections.abc import Sequence - -from .. import PermissionedBFTBase, PermissionedBFTBlock, PermissionedBFTProposal, \ - two_thirds_threshold - - -class StreamletProposal(PermissionedBFTProposal): - """An adapted-Streamlet proposal.""" - - def __init__(self, parent: StreamletBlock | StreamletGenesis, epoch: int): - """ - Constructs a `StreamletProposal` with the given parent `StreamletBlock`, - for the given `epoch`. The parameters are determined by the parent block. - A proposal must be for an epoch after its parent's epoch. - """ - super().__init__(parent) - assert epoch > parent.epoch - self.epoch = epoch - """The epoch of this proposal.""" - - def __repr__(self) -> str: - return "StreamletProposal(parent=%r, epoch=%r)" % (self.parent, self.epoch) - - -class StreamletGenesis(PermissionedBFTBase): - """An adapted-Streamlet genesis block.""" - - def __init__(self, n: int): - """ - Constructs a genesis block for adapted-Streamlet with `n` nodes. - """ - super().__init__(n, two_thirds_threshold(n)) - self.epoch = 0 - """The genesis block has epoch 0.""" - - def __repr__(self) -> str: - return "StreamletGenesis(n=%r)" % (self.n,) - - -class StreamletBlock(PermissionedBFTBlock): - """ - An adapted-Streamlet block. Each non-genesis Streamlet block is - based on a notarized `StreamletProposal`. - - `StreamletBlock`s are taken to be notarized by definition. - All validity conditions are enforced in the contructor. - """ - - def __init__(self, proposal: StreamletProposal): - """Constructs a `StreamletBlock` for the given proposal.""" - super().__init__(proposal) - self.epoch = proposal.epoch - - def last_final(self) -> StreamletBlock | StreamletGenesis: - """ - Returns the last final block in this block's ancestor chain. - In Streamlet this is the middle block of the last group of three - that were proposed in consecutive epochs. - """ - last = self - if last.parent is None: - return last - middle = last.parent - if middle.parent is None: - return middle - first = middle.parent - while True: - if first.parent is None: - return first - if (first.epoch + 1, middle.epoch + 1) == (middle.epoch, last.epoch): - return middle - (first, middle, last) = (first.parent, first, middle) - - def __repr__(self) -> str: - return "StreamletBlock(proposal=%r)" % (self.proposal,) - - -import unittest -from itertools import count - - -class TestStreamlet(unittest.TestCase): - def test_simple(self) -> None: - """ - Very simple example. - - 0 --- 1 --- 2 --- 3 - """ - self._test_last_final([0, 1, 2], [0, 0, 2]) - - def test_figure_1(self) -> None: - """ - Figure 1: Streamlet finalization example (without the invalid 'X' proposal). - - 0 --- 2 --- 5 --- 6 --- 7 - \ - -- 1 --- 3 - - 0 - Genesis - N - Notarized block - - This diagram implies the epoch 6 block is the last-final block in the - context of the epoch 7 block, because it is in the middle of 3 blocks - with consecutive epoch numbers, and 6 is the most recent such block. - - (We don't include the block/proposal with the red X because that's not - what we're testing.) - """ - self._test_last_final([0, 0, 1, None, 2, 5, 6], [0, 0, 0, 0, 0, 0, 6]) - - def test_complex(self) -> None: - """ - Safety Violation: due to three simultaneous properties: - - - 6 is `last_final` in the context of 7 - - 9 is `last_final` in the context of 10 - - 9 is not a descendant of 6 - - 0 --- 2 --- 5 --- 6 --- 7 - \ - -- 1 --- 3 --- 8 --- 9 --- 10 - """ - self._test_last_final([0, 0, 1, None, 2, 5, 6, 3, 8, 9], [0, 0, 0, 0, 0, 0, 6, 0, 0, 9]) - - def _test_last_final(self, parent_map: Sequence[Optional[int]], final_map: Sequence[int]) -> None: - """ - This test constructs a tree of proposals with structure determined by - `parent_map`, and asserts `block.last_final()` matches the structure - determined by `final_map`. - - parent_map: sequence of parent epoch numbers - final_map: sequence of final epoch numbers - """ - - assert len(parent_map) == len(final_map) - - # Construct the genesis block. - genesis = StreamletGenesis(3) - current = genesis - self.assertEqual(current.last_final(), genesis) - blocks = [genesis] - - for (epoch, parent_epoch, final_epoch) in zip(count(1), parent_map, final_map): - if parent_epoch is None: - blocks.append(None) - continue - - parent = blocks[parent_epoch] - assert parent is not None - proposal = StreamletProposal(parent, epoch) - proposal.assert_valid() - self.assertTrue(proposal.is_valid()) - self.assertFalse(proposal.is_notarized()) - - # not enough signatures - proposal.add_signature(0) - self.assertFalse(proposal.is_notarized()) - - # same index, so we still only have one signature - proposal.add_signature(0) - self.assertFalse(proposal.is_notarized()) - - # different index, now we have two signatures as required - proposal.add_signature(1) - proposal.assert_notarized() - self.assertTrue(proposal.is_notarized()) - - current = StreamletBlock(proposal) - blocks.append(current) - self.assertEqual(current.last_final(), blocks[final_epoch]) diff --git a/simtfl/bft/streamlet/chain.py b/simtfl/bft/streamlet/chain.py new file mode 100644 index 0000000..ac64ed4 --- /dev/null +++ b/simtfl/bft/streamlet/chain.py @@ -0,0 +1,100 @@ +""" +Adapted-Streamlet chain classes. +""" + + +from __future__ import annotations +from typing import Optional + +from ..chain import PermissionedBFTBase, PermissionedBFTBlock, PermissionedBFTProposal, \ + two_thirds_threshold + + +class StreamletProposal(PermissionedBFTProposal): + """An adapted-Streamlet proposal.""" + + def __init__(self, parent: StreamletBlock | StreamletGenesis, epoch: int): + """ + Constructs a `StreamletProposal` with the given parent `StreamletBlock`, + for the given `epoch`. The parameters are determined by the parent block. + A proposal must be for an epoch after its parent's epoch. + """ + super().__init__(parent) + self.parent: StreamletBlock | StreamletGenesis = parent + + assert epoch > parent.epoch + self.epoch = epoch + """The epoch of this proposal.""" + + def __str__(self) -> str: + return f"StreamletProposal(parent={self.parent}, epoch={self.epoch}, length={self.length})" + + +class StreamletGenesis(PermissionedBFTBase): + """An adapted-Streamlet genesis block.""" + + def __init__(self, n: int): + """ + Constructs a genesis block for adapted-Streamlet with `n` nodes. + """ + super().__init__(n, two_thirds_threshold(n)) + + self.parent: Optional[StreamletBlock | StreamletGenesis] = None + """The genesis block has no parent (represented as `None`).""" + + self.epoch = 0 + """The epoch of the genesis block is 0.""" + + self.last_final = self + """The last final block of the genesis block is itself.""" + + def __str__(self) -> str: + return f"StreamletGenesis(n={self.n})" + + def proposer_for_epoch(self, epoch: int): + assert epoch > 0 + return (epoch - 1) % self.n + + +class StreamletBlock(PermissionedBFTBlock): + """ + An adapted-Streamlet block. Each non-genesis Streamlet block is + based on a notarized `StreamletProposal`. + + `StreamletBlock`s are taken to be notarized by definition. + All validity conditions are enforced in the contructor. + """ + + def __init__(self, proposal: StreamletProposal): + """Constructs a `StreamletBlock` for the given proposal.""" + super().__init__(proposal) + + self.epoch = proposal.epoch + """The epoch of this proposal.""" + + self.parent: StreamletBlock | StreamletGenesis = proposal.parent + + self.last_final = self._compute_last_final() + """ + The last final block in this block's ancestor chain. + In Streamlet this is the middle block of the last group of three + that were proposed in consecutive epochs. + """ + + def _compute_last_final(self) -> StreamletBlock | StreamletGenesis: + last: StreamletBlock | StreamletGenesis = self + if last.parent is None: + return last + middle: StreamletBlock | StreamletGenesis = last.parent + if middle.parent is None: + return middle + first: StreamletBlock | StreamletGenesis = middle.parent + while True: + if first.parent is None: + return first + if (first.epoch + 1, middle.epoch + 1) == (middle.epoch, last.epoch): + return middle + (first, middle, last) = (first.parent, first, middle) + + def __str__(self) -> str: + return f"StreamletBlock(proposal={self.proposal})" diff --git a/simtfl/bft/streamlet/node.py b/simtfl/bft/streamlet/node.py index 9838738..484500d 100644 --- a/simtfl/bft/streamlet/node.py +++ b/simtfl/bft/streamlet/node.py @@ -4,12 +4,15 @@ from __future__ import annotations +from typing import Optional +from collections.abc import Sequence +from dataclasses import dataclass from ...node import SequentialNode from ...message import Message, PayloadMessage from ...util import skip, ProcessEffect -from . import StreamletGenesis, StreamletBlock, StreamletProposal +from .chain import StreamletGenesis, StreamletBlock, StreamletProposal class Echo(PayloadMessage): @@ -20,6 +23,35 @@ class Echo(PayloadMessage): pass +@dataclass(frozen=True) +class Ballot(Message): + """ + A ballot message, recording that a voter has voted for a `StreamletProposal`. + Ballots should not be forged unless modelling an attack that allows doing so. + """ + proposal: StreamletProposal + """The proposal.""" + voter: int + """The voter.""" + + def __str__(self) -> str: + return f"Ballot({self.proposal}, voter={self.voter})" + + +class Proposal(PayloadMessage): + """ + A message containing a `StreamletProposal`. + """ + pass + + +class Block(PayloadMessage): + """ + A message containing a `StreamletBlock`. + """ + pass + + class StreamletNode(SequentialNode): """ A Streamlet node. @@ -32,7 +64,32 @@ def __init__(self, genesis: StreamletGenesis): """ assert genesis.epoch == 0 self.genesis = genesis + """The genesis block.""" + self.voted_epoch = genesis.epoch + """The last epoch on which this node voted.""" + + self.tip: StreamletBlock | StreamletGenesis = genesis + """ + A longest chain seen by this node. The node's last final block is given by + `self.tip.last_final`. + """ + + self.proposal: Optional[StreamletProposal] = None + """The current proposal by this node, when it is the proposer.""" + + self.safety_violations: set[tuple[StreamletBlock | StreamletGenesis, + StreamletBlock | StreamletGenesis]] = set() + """The set of safety violations detected by this node.""" + + def propose(self, proposal: StreamletProposal) -> ProcessEffect: + """ + (process) Ask the node to make a proposal. + """ + assert proposal.is_valid() + assert proposal.epoch > self.voted_epoch + self.proposal = proposal + return self.broadcast(Proposal(proposal), False) def handle(self, sender: int, message: Message) -> ProcessEffect: """ @@ -42,32 +99,242 @@ def handle(self, sender: int, message: Message) -> ProcessEffect: (This causes the number of messages to blow up by a factor of `n`, but it's what the Streamlet paper specifies and is necessary for its liveness proof.) - * Received non-duplicate proposals may cause us to send a `Vote`. - * ... + * Receiving a non-duplicate `Proposal` may cause us to broadcast a `Ballot`. + * If we are the current proposer, keep track of ballots for our proposal. + * Receiving a `Block` may cause us to update our `tip`. """ if isinstance(message, Echo): message = message.payload else: - yield from self.broadcast(Echo(message)) + yield from self.broadcast(Echo(message), False) - if isinstance(message, StreamletProposal): - yield from self.handle_proposal(message) - elif isinstance(message, StreamletBlock): - yield from self.handle_block(message) + if isinstance(message, Proposal): + yield from self.handle_proposal(message.payload) + elif isinstance(message, Block): + yield from self.handle_block(message.payload) + elif isinstance(message, Ballot): + yield from self.handle_ballot(message) else: yield from super().handle(sender, message) def handle_proposal(self, proposal: StreamletProposal) -> ProcessEffect: """ (process) If we already voted in the epoch specified by the proposal or a - later epoch, ignore this proposal. + later epoch, ignore this proposal. Otherwise, cast a vote for it iff it + is valid. """ if proposal.epoch <= self.voted_epoch: - self.log("handle", + self.log("proposal", f"received proposal for epoch {proposal.epoch} but we already voted in epoch {self.voted_epoch}") return skip() - return skip() + if proposal.is_valid(): + self.log("proposal", f"voting for {proposal}") + # For now we just forget that we made a proposal if we receive a different + # valid one from another node. This is not realistic. Note that we can and + # should vote for our own proposal. + if proposal != self.proposal: + self.proposal = None + + self.voted_epoch = proposal.epoch + return self.broadcast(Ballot(proposal, self.ident), True) + else: + return skip() def handle_block(self, block: StreamletBlock) -> ProcessEffect: - raise NotImplementedError + """ + If `block.last_final` does not descend from `self.tip.last_final`, reject the block. + (In this case, if also `self.tip.last_final` does not descend from `block.last_final`, + this is a detected safety violation.) + + Otherwise, update `self.tip` to `block` iff `block` is later in lexicographic ordering + by `(length, epoch)`. + """ + if not self.tip.last_final.preceq(block.last_final): + self.log("block", f"× not ⪰ last_final: {block}") + if not block.last_final.preceq(self.tip.last_final): + self.log("block", f"! safety violation: ({block}, {self.tip})") + self.safety_violations.add((block, self.tip)) + return skip() + + # TODO: analyse tie-breaking rule. + if (self.tip.length, self.tip.epoch) >= (block.length, block.epoch): + self.log("block", f"× not updating tip: {block}") + return skip() + + self.log("block", f"✓ updating tip: {block}") + self.tip = block + return skip() + + def handle_ballot(self, ballot: Ballot) -> ProcessEffect: + """ + If we have made a proposal that is not yet notarized and the ballot is + for that proposal, add the vote. If it is now notarized, broadcast it + as a block. + """ + proposal = ballot.proposal + if proposal == self.proposal: + self.log("count", f"{ballot.voter} voted for our proposal in epoch {proposal.epoch}") + proposal.add_vote(ballot.voter) + if proposal.is_notarized(): + yield from self.broadcast(Block(StreamletBlock(proposal)), True) + # It's fine to forget that we made the proposal now. + self.proposal = None + + def final_block(self) -> StreamletBlock | StreamletGenesis: + """ + Return the last final block seen by this node. + """ + return self.tip.last_final + + +__all__ = ['Echo', 'Ballot', 'StreamletNode'] + +import unittest +from itertools import count +from simpy import Environment +from simpy.events import Process, Timeout + +from ...network import Network +from ...logging import PrintLogger + + +class TestStreamlet(unittest.TestCase): + def test_simple(self) -> None: + """ + Very simple example. + + 0 --- 1 --- 2 --- 3 + """ + self._test_last_final([0, 1, 2], + [0, 0, 2]) + + def test_figure_1(self) -> None: + """ + Figure 1: Streamlet finalization example (without the invalid 'X' proposal). + + 0 --- 2 --- 5 --- 6 --- 7 + \ + -- 1 --- 3 + + 0 - Genesis + N - Notarized block + + This diagram implies the epoch 6 block is the last-final block in the + context of the epoch 7 block, because it is in the middle of 3 blocks + with consecutive epoch numbers, and 6 is the most recent such block. + + (We don't include the block/proposal with the red X because that's not + what we're testing.) + """ + N = None + self._test_last_final([0, 0, 1, N, 2, 5, 6], + [0, 0, 0, 0, 0, 0, 6]) + + def test_complex(self) -> None: + """ + Safety Violation: due to three simultaneous properties: + + - 6 is `last_final` in the context of 7 + - 9 is `last_final` in the context of 10 + - 9 is not a descendant of 6 + + 0 --- 2 --- 5 --- 6 --- 7 + \ + -- 1 --- 3 --- 8 --- 9 --- 10 + """ + N = None + self._test_last_final([0, 0, 1, N, 2, 5, 6, 3, 8, 9], + [0, 0, 0, 0, 0, 0, 6, 0, 0, 9], + expect_divergence_at_epoch=8, + expect_safety_violations={(10, 7)}) + + def _test_last_final(self, + parent_map: Sequence[Optional[int]], + final_map: Sequence[int], + expect_divergence_at_epoch: Optional[int]=None, + expect_safety_violations: set[tuple[int, int]]=set()) -> None: + """ + This test constructs a tree of proposals with structure determined by + `parent_map`, and asserts `block.last_final` matches the structure + determined by `final_map`. + + parent_map: sequence of parent epoch numbers + final_map: sequence of final epoch numbers + expect_divergence_at_epoch: first epoch at which a block does not become the new tip + expect_safety_violations: safety violation proofs + """ + + assert len(parent_map) == len(final_map) + + # Construct the genesis block. + genesis = StreamletGenesis(3) + network = Network(Environment(), logger=PrintLogger()) + for _ in range(genesis.n): + network.add_node(StreamletNode(genesis)) + + current = genesis + self.assertEqual(current.last_final, genesis) + blocks: list[Optional[StreamletBlock | StreamletGenesis]] = [genesis] + + def run() -> ProcessEffect: + for (epoch, parent_epoch, final_epoch) in zip(count(1), parent_map, final_map): + yield Timeout(network.env, 10) + if parent_epoch is None: + blocks.append(None) + continue + + parent = blocks[parent_epoch] + assert parent is not None + proposer = network.node(genesis.proposer_for_epoch(epoch)) + proposal = StreamletProposal(parent, epoch) + self.assertEqual(proposal.length, parent.length + 1) + proposal.assert_valid() + self.assertFalse(proposal.is_notarized()) + + proposer.propose(proposal) + yield Timeout(network.env, 10) + + # The proposer should have sent the block. + assert proposer.proposal is None + + # Make a fake block `current` from the proposal so that we can append + # it to `blocks` and check its `last_final`. + current = StreamletBlock(proposal) + self.assertEqual(current.length, proposal.length) + self.assertTrue(parent.preceq(current)) + self.assertFalse(current.preceq(parent)) + self.assertEqual(len(blocks), current.epoch) + blocks.append(current) + final_block = blocks[final_epoch] + assert final_block is not None + self.assertEqual(current.last_final, final_block) + + # All nodes' tips should be the same. + tip = network.node(0).tip + for i in range(1, network.num_nodes()): + self.assertEqual(network.node(i).tip, tip) + + # If we try to create a new block on top of a chain that is not the longest, + # the nodes will ignore it. + if epoch == expect_divergence_at_epoch: + self.assertLess(current.length, tip.length) + elif expect_divergence_at_epoch is None or epoch < expect_divergence_at_epoch: + self.assertEqual(current.length, tip.length) + self.assertEqual(tip.epoch, epoch) + self.assertEqual(tip.proposal, proposal) + + for node in network.nodes: + node_final = node.final_block() + self.assertEqual(node_final, final_block, + f"epoch {node_final.epoch} != epoch {final_block.epoch}") + + for node in network.nodes: + self.assertEqual(set(((a.epoch, b.epoch) for (a, b) in node.safety_violations)), + expect_safety_violations) + + network.done = True + + Process(network.env, run()) + network.run_all() + self.assertTrue(network.done) diff --git a/simtfl/message.py b/simtfl/message.py index 003919e..d1226f0 100644 --- a/simtfl/message.py +++ b/simtfl/message.py @@ -22,3 +22,6 @@ class PayloadMessage(Message): """ payload: Any """The payload.""" + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.payload})" diff --git a/simtfl/network.py b/simtfl/network.py index 0b98eb3..e5f3e88 100644 --- a/simtfl/network.py +++ b/simtfl/network.py @@ -30,7 +30,7 @@ def initialize(self, ident: int, env: Environment, network: Network): self.env = env self.network = network - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}" def log(self, event: str, detail: str): @@ -47,13 +47,13 @@ def send(self, target: int, message: Message, delay: Optional[Number]=None) -> P """ return self.network.send(self.ident, target, message, delay=delay) - def broadcast(self, message: Message, delay: Optional[Number]=None) -> ProcessEffect: + def broadcast(self, message: Message, include_self: bool, delay: Optional[Number]=None) -> ProcessEffect: """ (process) This method can be overridden to intercept messages being broadcast by this node. The implementation in this class calls `self.network.broadcast` with this node as the sender. """ - return self.network.broadcast(self.ident, message, delay=delay) + return self.network.broadcast(self.ident, message, include_self, delay=delay) def receive(self, sender: int, message: Message) -> ProcessEffect: """ @@ -86,8 +86,14 @@ def __init__(self, env: Environment, nodes: Optional[list[Node]]=None, delay: Nu a set of initial nodes, message propagation delay, and logger. """ self.env = env + """The `simpy.Environment`.""" + self.nodes = nodes or [] + """The nodes in this network.""" + self.delay = delay + """The message propagation delay.""" + self._logger = logger logger.header() @@ -166,19 +172,21 @@ def send(self, sender: int, target: int, message: Message, delay: Optional[Numbe # TODO: make it take some time on the sending node. return skip() - def broadcast(self, sender: int, message: Message, delay: Optional[Number]=None) -> ProcessEffect: + def broadcast(self, sender: int, message: Message, include_self: bool, + delay: Optional[Number]=None) -> ProcessEffect: """ - (process) Broadcasts a message to every other node. The message - propagation delay is normally given by `self.delay`, but can be - overridden by the `delay` parameter. + (process) Broadcasts a message to every node (including ourself only when + `include_self` is set). The message propagation delay is normally given by + `self.delay`, but can be overridden by the `delay` parameter. """ if delay is None: delay = self.delay - self.log(sender, "broadcast", f"to * with delay {delay:2d}: {message}") + c = "+" if include_self else "-" + self.log(sender, "broadcast", f"to {c}* with delay {delay:2d}: {message}") # Run `convey` in a new process for each node. for target in range(self.num_nodes()): - if target != sender: + if include_self or target != sender: Process(self.env, self.convey(delay, sender, target, message)) # Broadcasting is currently instantaneous. diff --git a/simtfl/node.py b/simtfl/node.py index f921f2c..ea00ec8 100644 --- a/simtfl/node.py +++ b/simtfl/node.py @@ -89,7 +89,7 @@ def run(self) -> ProcessEffect: while True: while len(self._mailbox) > 0: (sender, message) = self._mailbox.popleft() - self.log("handle", f"from {sender:2d}: {message}") + self.log("handle", f"from {sender:2d}: {message}") yield from self.handle(sender, message) # This naive implementation is fine because we have no actual @@ -147,7 +147,7 @@ def run(self) -> ProcessEffect: yield Timeout(self.env, 1) # This message is broadcast at time 4 and received at time 5. - yield from self.broadcast(PayloadMessage(4)) + yield from self.broadcast(PayloadMessage(4), False) class TestFramework(unittest.TestCase):