Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions lib/cusf_enforcer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use futures::{FutureExt as _, TryFutureExt as _, TryStreamExt as _};
use thiserror::Error;
use tracing::instrument;

// TODO: Enable specifying txs that can be restored to the mempool
#[derive(Clone, Debug)]
pub enum ConnectBlockAction {
Accept { remove_mempool_txs: HashSet<Txid> },
Expand All @@ -27,6 +28,12 @@ impl Default for ConnectBlockAction {
}
}

// TODO: Enable specifying txs that can be restored to the mempool
#[derive(Clone, Debug, Default)]
pub struct DisconnectBlockAction {
pub remove_mempool_txs: HashSet<Txid>,
}

#[derive(Clone, Debug)]
pub enum TxAcceptAction {
Accept {
Expand Down Expand Up @@ -60,7 +67,9 @@ pub trait CusfEnforcer {
fn disconnect_block(
&mut self,
block_hash: BlockHash,
) -> impl Future<Output = Result<(), Self::DisconnectBlockError>> + Send;
) -> impl Future<
Output = Result<DisconnectBlockAction, Self::DisconnectBlockError>,
> + Send;

type AcceptTxError: std::error::Error + Send + Sync + 'static;

Expand Down Expand Up @@ -315,7 +324,9 @@ where
}
}
BlockHashEvent::Disconnected => {
let () = enforcer
let DisconnectBlockAction {
remove_mempool_txs: _,
} = enforcer
.disconnect_block(block_hash)
.map_err(TaskError::DisconnectBlock)
.await?;
Expand Down Expand Up @@ -410,7 +421,9 @@ where
},
) => {
// Disconnect block on right enforcer
let () = self
let DisconnectBlockAction {
remove_mempool_txs: _,
} = self
.1
.disconnect_block(block.block_hash())
.map_err(|err| {
Expand All @@ -428,7 +441,9 @@ where
ConnectBlockAction::Reject,
) => {
// Disconnect block on left enforcer
let () = self
let DisconnectBlockAction {
remove_mempool_txs: _,
} = self
.0
.disconnect_block(block.block_hash())
.map_err(|err| {
Expand All @@ -451,16 +466,19 @@ where
async fn disconnect_block(
&mut self,
block_hash: BlockHash,
) -> Result<(), Self::DisconnectBlockError> {
let () = self
) -> Result<DisconnectBlockAction, Self::DisconnectBlockError> {
let mut res = self
.0
.disconnect_block(block_hash)
.map_err(Either::Left)
.await?;
self.1
let DisconnectBlockAction { remove_mempool_txs } = self
.1
.disconnect_block(block_hash)
.map_err(Either::Right)
.await
.await?;
res.remove_mempool_txs.extend(remove_mempool_txs);
Ok(res)
}

type AcceptTxError = Either<C0::AcceptTxError, C1::AcceptTxError>;
Expand Down Expand Up @@ -521,8 +539,8 @@ impl CusfEnforcer for DefaultEnforcer {
async fn disconnect_block(
&mut self,
_block_hash: BlockHash,
) -> Result<(), Self::DisconnectBlockError> {
Ok(())
) -> Result<DisconnectBlockAction, Self::DisconnectBlockError> {
Ok(DisconnectBlockAction::default())
}

type AcceptTxError = Infallible;
Expand Down Expand Up @@ -592,7 +610,7 @@ where
async fn disconnect_block(
&mut self,
block_hash: BlockHash,
) -> Result<(), Self::DisconnectBlockError> {
) -> Result<DisconnectBlockAction, Self::DisconnectBlockError> {
match self {
Self::Left(left) => {
left.disconnect_block(block_hash)
Expand Down
10 changes: 8 additions & 2 deletions lib/mempool/sync/initial_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ use super::{
RequestItem, RequestQueue, ResponseItem,
};
use crate::{
cusf_enforcer::{self, ConnectBlockAction, CusfEnforcer},
cusf_enforcer::{
self, ConnectBlockAction, CusfEnforcer, DisconnectBlockAction,
},
zmq::{
BlockHashEvent, BlockHashMessage, SequenceMessage, SequenceStream,
SequenceStreamError, TxHashEvent, TxHashMessage,
Expand Down Expand Up @@ -173,11 +175,15 @@ where
// FIXME: insert without info
let () = todo!();
}
let () = sync_state
let DisconnectBlockAction { remove_mempool_txs } = sync_state
.enforcer
.disconnect_block(block.hash)
.await
.map_err(cusf_enforcer::Error::DisconnectBlock)?;
sync_state
.post_sync
.remove_mempool_txs
.extend(remove_mempool_txs);
sync_state.mempool.chain.tip =
block.previousblockhash.unwrap_or_else(BlockHash::all_zeros);
Ok(())
Expand Down
Loading