diff --git a/mempool/recheck_pool.go b/mempool/recheck_pool.go index 6616006d8..1567cc64b 100644 --- a/mempool/recheck_pool.go +++ b/mempool/recheck_pool.go @@ -203,7 +203,7 @@ func (m *RecheckMempool) Insert(_ context.Context, tx sdk.Tx) error { } write() - m.markTxRechecked(tx) + m.markTxInserted(tx) return nil } @@ -452,6 +452,18 @@ func (m *RecheckMempool) markTxRechecked(txn sdk.Tx) { m.recheckedTxs.Do(func(store *CosmosTxStore) { store.AddTx(txn) }) } +// markTxInserted conservatively updates the current height snapshot for live inserts. +// If the inserted tx replaces an existing tx, any other txs from the same sender with +// a higher nonce is dropped and rebuilt by the next recheck. +func (m *RecheckMempool) markTxInserted(txn sdk.Tx) { + m.recheckedTxs.Do(func(store *CosmosTxStore) { + if store.InvalidateFrom(txn) > 0 { + return + } + store.AddTx(txn) + }) +} + type signerSequence struct { account string seq uint64 diff --git a/mempool/recheck_pool_test.go b/mempool/recheck_pool_test.go index 566a8b001..395233ab2 100644 --- a/mempool/recheck_pool_test.go +++ b/mempool/recheck_pool_test.go @@ -23,6 +23,7 @@ import ( "github.com/cosmos/evm/mempool/reserver" "cosmossdk.io/log/v2" + sdkmath "cosmossdk.io/math" storetypes "cosmossdk.io/store/types" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" @@ -924,6 +925,16 @@ func newRecheckTestTxWithNonce(t *testing.T, key *ecdsa.PrivateKey, nonce uint64 return &recheckTestTx{key: key, sequence: nonce} } +func newRecheckTestTxWithGasPrice(t *testing.T, key *ecdsa.PrivateKey, nonce uint64, gasPrice int64) sdk.Tx { + t.Helper() + return &recheckTestTx{ + key: key, + sequence: nonce, + gas: 100_000, + fee: sdk.NewCoins(sdk.NewInt64Coin(recheckTestFeeDenom, gasPrice*100_000)), + } +} + // newNonceTrackingAnteHandler returns an ante handler that enforces sequential // nonce ordering per account. Nonces are tracked in a map keyed by signer // address — each successful call increments the expected nonce. @@ -1037,6 +1048,110 @@ func TestRecheckMempool_InsertAfterRecheck(t *testing.T) { require.Equal(t, 3, mp.CountTx()) } +func TestRecheckMempool_InsertReplacementInvalidatesRechecked(t *testing.T) { + ctx := newRecheckTestContext() + tracker := reserver.NewReservationTracker() + handle := tracker.NewHandle(1) + pool := &recheckMockPool{} + bc := newMockContextProvider(ctx) + rc := newMockRechecker(ctx, noopAnteHandler) + recheckedTxs := newTestRecheckedTxs() + + mp := mempool.NewRecheckMempool(log.NewNopLogger(), pool, handle, rc, recheckedTxs, bc) + mp.Start(testHeader(0)) + t.Cleanup(func() { + require.NoError(t, mp.Close()) + }) + + key, err := crypto.GenerateKey() + require.NoError(t, err) + + tx4 := newRecheckTestTxWithNonce(t, key, 4) + tx5 := newRecheckTestTxWithNonce(t, key, 5) + tx6 := newRecheckTestTxWithNonce(t, key, 6) + + recheckedTxs.Do(func(store *mempool.CosmosTxStore) { + store.AddTx(tx4) + store.AddTx(tx5) + store.AddTx(tx6) + }) + + replacement := newRecheckTestTxWithNonce(t, key, 4) + require.NoError(t, mp.Insert(ctx, replacement)) + + iter := mp.RecheckedTxs(context.Background(), big.NewInt(0)) + rechecked := collectIteratorTxs(iter) + require.Empty(t, rechecked) + require.Equal(t, 1, mp.CountTx()) +} + +func TestRecheckMempool_RecheckRebuildsSnapshotAfterReplacement(t *testing.T) { + ctx := newRecheckTestContext() + tracker := reserver.NewReservationTracker() + handle := tracker.NewHandle(1) + pool := sdkmempool.NewPriorityMempool(sdkmempool.PriorityNonceMempoolConfig[sdkmath.Int]{ + TxPriority: sdkmempool.TxPriority[sdkmath.Int]{ + GetTxPriority: func(goCtx context.Context, tx sdk.Tx) sdkmath.Int { + _ = sdk.UnwrapSDKContext(goCtx) + cosmosTxFee, ok := tx.(sdk.FeeTx) + if !ok { + return sdkmath.ZeroInt() + } + found, coin := cosmosTxFee.GetFee().Find(recheckTestFeeDenom) + if !found { + return sdkmath.ZeroInt() + } + + gasPrice := coin.Amount.Quo(sdkmath.NewIntFromUint64(cosmosTxFee.GetGas())) + return gasPrice + }, + Compare: func(a, b sdkmath.Int) int { + return a.BigInt().Cmp(b.BigInt()) + }, + MinValue: sdkmath.ZeroInt(), + }, + TxReplacement: func(op, np sdkmath.Int, _ sdk.Tx, _ sdk.Tx) bool { + return np.GT(op) + }, + }) + bc := newMockContextProvider(ctx) + rc := newMockRechecker(ctx, noopAnteHandler) + recheckedTxs := newTestRecheckedTxs() + + mp := mempool.NewRecheckMempool(log.NewNopLogger(), pool, handle, rc, recheckedTxs, bc) + mp.Start(testHeader(0)) + t.Cleanup(func() { + require.NoError(t, mp.Close()) + }) + + key, err := crypto.GenerateKey() + require.NoError(t, err) + + tx3 := newRecheckTestTxWithGasPrice(t, key, 3, 1) + tx4 := newRecheckTestTxWithGasPrice(t, key, 4, 1) + tx5 := newRecheckTestTxWithGasPrice(t, key, 5, 1) + tx6 := newRecheckTestTxWithGasPrice(t, key, 6, 1) + replacement := newRecheckTestTxWithGasPrice(t, key, 4, 2) + + for _, tx := range []sdk.Tx{tx3, tx4, tx5, tx6} { + require.NoError(t, mp.Insert(ctx, tx)) + } + + // insert the replacement, which should invalidate the other txs in the pool with greater nonce. + require.NoError(t, mp.Insert(ctx, replacement)) + + iter := mp.RecheckedTxs(context.Background(), big.NewInt(0)) + rechecked := collectIteratorTxs(iter) + require.Len(t, rechecked, 1) + require.Equal(t, tx3, rechecked[0]) + + mp.TriggerRecheckSync(testHeader(1)) + + iter = mp.RecheckedTxs(context.Background(), big.NewInt(1)) + rechecked = collectIteratorTxs(iter) + require.Equal(t, []sdk.Tx{tx3, replacement, tx5, tx6}, rechecked) +} + // newRecheckTestTx creates a minimal sdk.Tx for unit testing RecheckMempool. func newRecheckTestTx(t *testing.T, key *ecdsa.PrivateKey) sdk.Tx { t.Helper() @@ -1047,14 +1162,44 @@ func newRecheckTestTx(t *testing.T, key *ecdsa.PrivateKey) sdk.Tx { type recheckTestTx struct { key *ecdsa.PrivateKey sequence uint64 + gas uint64 + fee sdk.Coins } +const recheckTestFeeDenom = "atest" + func (m *recheckTestTx) GetMsgs() []sdk.Msg { return nil } func (m *recheckTestTx) GetMsgsV2() ([]proto.Message, error) { return nil, nil } +func (m *recheckTestTx) GetGas() uint64 { + if m.gas == 0 { + return 100_000 + } + return m.gas +} + +func (m *recheckTestTx) GetFee() sdk.Coins { + if len(m.fee) == 0 { + return sdk.NewCoins(sdk.NewInt64Coin(recheckTestFeeDenom, 100_000)) + } + return m.fee +} + +func (m *recheckTestTx) FeePayer() []byte { + signers, err := m.GetSigners() + if err != nil || len(signers) == 0 { + return nil + } + return signers[0] +} + +func (m *recheckTestTx) FeeGranter() []byte { + return nil +} + func (m *recheckTestTx) GetSigners() ([][]byte, error) { pubKeyBytes := crypto.CompressPubkey(&m.key.PublicKey) pubKey := ðsecp256k1.PubKey{Key: pubKeyBytes} diff --git a/mempool/tx_store.go b/mempool/tx_store.go index 1c938efeb..7f01ae85d 100644 --- a/mempool/tx_store.go +++ b/mempool/tx_store.go @@ -1,6 +1,9 @@ package mempool import ( + "fmt" + "slices" + "strings" "sync" "cosmossdk.io/log/v2" @@ -14,32 +17,152 @@ import ( type CosmosTxStore struct { txs []sdk.Tx - // index maps a tx to its position in the txs slice for fast removal - index map[sdk.Tx]int + // keys is a map of -> index to txs slice. + keys map[string]int + logger log.Logger mu sync.RWMutex } // NewCosmosTxStore creates a new CosmosTxStore. -func NewCosmosTxStore(_ log.Logger) *CosmosTxStore { +func NewCosmosTxStore(l log.Logger) *CosmosTxStore { return &CosmosTxStore{ - index: make(map[sdk.Tx]int), + logger: l, + keys: make(map[string]int), } } -// AddTx adds a single tx to the store. Duplicate txs (by pointer identity) -// are ignored. +// AddTx adds a single tx to the store while constructing a validated snapshot. func (s *CosmosTxStore) AddTx(tx sdk.Tx) { s.mu.Lock() defer s.mu.Unlock() - if _, exists := s.index[tx]; exists { - return + if key, ok := cosmosTxKey(tx); ok { + if _, exists := s.keys[key]; exists { + // this should never happen. panicking for safety + s.logger.Warn("attempted to add duplicate tx to CosmosTxStore", "key", key) + return + } + s.keys[key] = len(s.txs) } - s.index[tx] = len(s.txs) + s.txs = append(s.txs, tx) } +// InvalidateFrom removes any stored tx that depends on the supplied tx's signer/nonces. +// It is used for live mempool replacements: once a tx at nonce N changes, any stored tx +// for the same signer(s) with nonce >= N can no longer be considered valid for proposal building. +func (s *CosmosTxStore) InvalidateFrom(tx sdk.Tx) int { + s.mu.Lock() + defer s.mu.Unlock() + + // first check if this tx is already here. If it isn't; no need to do anything. It's a fresh insert. + // If it is, we need to do the work of invaliding any txs from the same sender with a higher nonce. + if txKey, ok := cosmosTxKey(tx); ok { + if _, exists := s.keys[txKey]; !exists { + return 0 + } + } + + // nonce thresholds for each signer. + thresholds, ok := cosmosTxNonceMap(tx) + if !ok { + return 0 + } + + // rebuild the txs list, skipping txs that are invalidated. + removed := 0 + nextTxs := make([]sdk.Tx, 0, len(s.txs)) + for _, existing := range s.txs { + if invalidatesCosmosTx(existing, thresholds) { + removed++ + continue + } + nextTxs = append(nextTxs, existing) + } + + if removed == 0 { + return 0 + } + + // TODO: this isn't really the most optimal way to do this. but maybe fine for now + s.reindex(nextTxs) + return removed +} + +func cosmosTxKey(tx sdk.Tx) (string, bool) { + nonceMap, ok := cosmosTxNonceMap(tx) + if !ok { + return "", false + } + + var b strings.Builder + for i, sig := range sortedSignerNonces(nonceMap) { + if i > 0 { + b.WriteByte('|') + } + fmt.Fprintf(&b, "%s/%d", sig.account, sig.seq) + } + + return b.String(), true +} + +// cosmosTxNonceMap extracts the signers from the transaction +// and returns a signer -> nonce map. +func cosmosTxNonceMap(tx sdk.Tx) (map[string]uint64, bool) { + signerSeqs, err := extractSignerSequences(tx) + if err != nil || len(signerSeqs) == 0 { + return nil, false + } + + nonceMap := make(map[string]uint64, len(signerSeqs)) + for _, sig := range signerSeqs { + nonce, err := sdkmempool.ChooseNonce(sig.seq, tx) + if err != nil { + return nil, false + } + nonceMap[sig.account] = nonce + } + + return nonceMap, true +} + +func sortedSignerNonces(nonceMap map[string]uint64) []signerSequence { + signerSeqs := make([]signerSequence, 0, len(nonceMap)) + for account, seq := range nonceMap { + signerSeqs = append(signerSeqs, signerSequence{account: account, seq: seq}) + } + slices.SortFunc(signerSeqs, func(a, b signerSequence) int { + return strings.Compare(a.account, b.account) + }) + return signerSeqs +} + +func invalidatesCosmosTx(tx sdk.Tx, thresholds map[string]uint64) bool { + nonceMap, ok := cosmosTxNonceMap(tx) + if !ok { + return false + } + + for account, threshold := range thresholds { + nonce, exists := nonceMap[account] + if exists && nonce >= threshold { + return true + } + } + return false +} + +func (s *CosmosTxStore) reindex(txs []sdk.Tx) { + s.txs = txs + s.keys = make(map[string]int, len(txs)) + for i, tx := range txs { + if key, ok := cosmosTxKey(tx); ok { + s.keys[key] = i + } + } +} + // Txs returns a copy of the current set of txs in the store. func (s *CosmosTxStore) Txs() []sdk.Tx { s.mu.RLock() diff --git a/mempool/tx_store_test.go b/mempool/tx_store_test.go index 081f11149..d33e8fb26 100644 --- a/mempool/tx_store_test.go +++ b/mempool/tx_store_test.go @@ -3,14 +3,19 @@ package mempool import ( "testing" + "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" protov2 "google.golang.org/protobuf/proto" + "github.com/cosmos/evm/crypto/ethsecp256k1" "github.com/cosmos/gogoproto/proto" "cosmossdk.io/log/v2" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdk "github.com/cosmos/cosmos-sdk/types" + signingtypes "github.com/cosmos/cosmos-sdk/types/tx/signing" + authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" ) // mockTx is a minimal sdk.Tx implementation for testing. @@ -27,6 +32,46 @@ func newMockTx(id int) sdk.Tx { return &mockTx{id: id} } +type keyedMockTx struct { + pubKey cryptotypes.PubKey + sequence uint64 +} + +var ( + _ sdk.Tx = (*keyedMockTx)(nil) + _ authsigning.SigVerifiableTx = (*keyedMockTx)(nil) +) + +func newKeyedMockTx(t *testing.T, sequence uint64) sdk.Tx { + t.Helper() + + key, err := crypto.GenerateKey() + require.NoError(t, err) + + pubKeyBytes := crypto.CompressPubkey(&key.PublicKey) + return &keyedMockTx{ + pubKey: ðsecp256k1.PubKey{Key: pubKeyBytes}, + sequence: sequence, + } +} + +func (m *keyedMockTx) GetMsgs() []proto.Message { return nil } +func (m *keyedMockTx) GetMsgsV2() ([]protov2.Message, error) { return nil, nil } +func (m *keyedMockTx) GetSigners() ([][]byte, error) { + return [][]byte{m.pubKey.Address().Bytes()}, nil +} + +func (m *keyedMockTx) GetPubKeys() ([]cryptotypes.PubKey, error) { + return []cryptotypes.PubKey{m.pubKey}, nil +} + +func (m *keyedMockTx) GetSignaturesV2() ([]signingtypes.SignatureV2, error) { + return []signingtypes.SignatureV2{{ + PubKey: m.pubKey, + Sequence: m.sequence, + }}, nil +} + func TestCosmosTxStoreAddAndGet(t *testing.T) { store := NewCosmosTxStore(log.NewNopLogger()) @@ -45,7 +90,7 @@ func TestCosmosTxStoreAddAndGet(t *testing.T) { func TestCosmosTxStoreDedup(t *testing.T) { store := NewCosmosTxStore(log.NewNopLogger()) - tx := newMockTx(1) + tx := newKeyedMockTx(t, 1) store.AddTx(tx) store.AddTx(tx) diff --git a/mempool/txpool/legacypool/legacypool_test.go b/mempool/txpool/legacypool/legacypool_test.go index 08ed57e14..287f703ea 100644 --- a/mempool/txpool/legacypool/legacypool_test.go +++ b/mempool/txpool/legacypool/legacypool_test.go @@ -43,6 +43,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/trie" "github.com/holiman/uint256" + "github.com/stretchr/testify/require" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/evm/mempool/txpool" @@ -334,6 +335,56 @@ func validateEvents(events chan core.NewTxsEvent, count int) error { return nil } +// TestMarkTxRemovedInvalidatesPending tests that when you have txs with nonces 4,5,6, and you submit a replacement for 4, +// txs 5 and 6 need to be revalidated. +func TestMarkTxRemovedInvalidatesPending(t *testing.T) { + pool, _, key := setupPool() + defer pool.Close() + + addr := crypto.PubkeyToAddress(key.PublicKey) + tx4 := pricedTransaction(4, 100000, big.NewInt(1), key) + tx5 := pricedTransaction(5, 100000, big.NewInt(1), key) + tx6 := pricedTransaction(6, 100000, big.NewInt(1), key) + replacement := pricedTransaction(4, 100000, big.NewInt(2), key) + + testAddBalance(pool, addr, big.NewInt(100000000000000)) + testSetNonce(pool, addr, 4) + pool.chain.(*testBlockChain).statedb.SetNonce(addr, 4, tracing.NonceChangeUnspecified) + + pool.all.Add(tx4) + pool.priced.Put(tx4) + pool.promoteTx(addr, tx4.Hash(), tx4) + + pool.all.Add(tx5) + pool.priced.Put(tx5) + pool.promoteTx(addr, tx5.Hash(), tx5) + + pool.all.Add(tx6) + pool.priced.Put(tx6) + pool.promoteTx(addr, tx6.Hash(), tx6) + + pending := pool.Pending(context.Background(), big.NewInt(0), txpool.PendingFilter{}) + require.Len(t, pending[addr], 3) // at this point, should have txs 4,5,6. + + replaced, err := pool.add(replacement) + require.NoError(t, err) + require.True(t, replaced) + + pending = pool.Pending(context.Background(), big.NewInt(0), txpool.PendingFilter{}) + require.Empty(t, pending[addr]) // now should have nothing, since tx 4 is now a new tx, and 5,6 depended on 4. + + pool.Reset(nil, nil) // recheck + + pending = pool.Pending(context.Background(), pool.chain.CurrentBlock().Number, txpool.PendingFilter{}) + require.Len(t, pending[addr], 3) + require.Equal(t, []uint64{4, 5, 6}, []uint64{ + pending[addr][0].Resolve().Nonce(), + pending[addr][1].Resolve().Nonce(), + pending[addr][2].Resolve().Nonce(), + }) + require.Equal(t, replacement.Hash(), pending[addr][0].Resolve().Hash()) +} + func deriveSender(tx *types.Transaction) (common.Address, error) { return types.Sender(types.HomesteadSigner{}, tx) } diff --git a/mempool/txpool/legacypool/tx_store.go b/mempool/txpool/legacypool/tx_store.go index 0b58d01f0..be91f1566 100644 --- a/mempool/txpool/legacypool/tx_store.go +++ b/mempool/txpool/legacypool/tx_store.go @@ -113,6 +113,7 @@ func (t *TxStore) AddTxs(addr common.Address, txs types.Transactions) { continue } toAdd = append(toAdd, tx) + t.lookup[tx.Hash()] = struct{}{} } if existing, ok := t.txs[addr]; ok { @@ -136,25 +137,37 @@ func (t *TxStore) AddTx(addr common.Address, tx *types.Transaction) { // RemoveTx removes a tx for an address from the current set. func (t *TxStore) RemoveTx(addr common.Address, tx *types.Transaction) { + t.RemoveTxsFromNonce(addr, tx.Nonce()) +} + +// RemoveTxsFromNonce removes all txs for addr whose nonce is >= minNonce. +func (t *TxStore) RemoveTxsFromNonce(addr common.Address, minNonce uint64) { t.mu.Lock() defer t.mu.Unlock() - defer delete(t.lookup, tx.Hash()) - txs, ok := t.txs[addr] if !ok { return } - // Find and remove the tx by nonce - nonce := tx.Nonce() - for i := 0; i < len(txs); i++ { - if txs[i].Nonce() == nonce { - // Swap with last element and truncate - txs[i] = txs[len(txs)-1] - t.txs[addr] = txs[:len(txs)-1] - t.total -= 1 - return + next := txs[:0] + numRemoved := 0 + for _, existing := range txs { + if existing.Nonce() >= minNonce { + delete(t.lookup, existing.Hash()) + numRemoved++ + continue } + next = append(next, existing) + } + + // memory reclaim + clear(txs[len(next):]) + + t.total -= uint64(numRemoved) + if len(next) == 0 { + delete(t.txs, addr) + return } + t.txs[addr] = next } diff --git a/mempool/txpool/legacypool/tx_store_test.go b/mempool/txpool/legacypool/tx_store_test.go index d328c29a0..8f7c2dcd2 100644 --- a/mempool/txpool/legacypool/tx_store_test.go +++ b/mempool/txpool/legacypool/tx_store_test.go @@ -2,7 +2,6 @@ package legacypool import ( "math/big" - "sync" "testing" "cosmossdk.io/log/v2" @@ -89,46 +88,44 @@ func TestTxStoreSortedByNonce(t *testing.T) { } } -func TestTxStoreRemoveTx(t *testing.T) { +// TestTxStoreRetainsPreviousTxs tests that if you remove a middle nonce, the earlier nonce txs stay retained. +func TestTxStoreRetainsPreviousTxs(t *testing.T) { store := NewTxStore(log.NewNopLogger()) addr1 := common.HexToAddress("0x1") + tx1 := createTestTx(0, big.NewInt(1e9), big.NewInt(2e9)) tx2 := createTestTx(1, big.NewInt(1e9), big.NewInt(2e9)) + tx3 := createTestTx(2, big.NewInt(1e9), big.NewInt(2e9)) + tx4 := createTestTx(3, big.NewInt(1e9), big.NewInt(2e9)) + tx5 := createTestTx(4, big.NewInt(1e9), big.NewInt(2e9)) + txs := []*types.Transaction{tx1, tx2, tx3, tx4, tx5} + for _, tx := range txs { + store.AddTx(addr1, tx) + } - store.AddTx(addr1, tx1) - store.AddTx(addr1, tx2) - store.RemoveTx(addr1, tx1) + store.RemoveTx(addr1, tx4) result := store.Txs(txpool.PendingFilter{}) - require.Len(t, result[addr1], 1) - require.Equal(t, uint64(1), result[addr1][0].Tx.Nonce()) + require.Len(t, result[addr1], 3) // should just have 0,1,2. + for i, tx := range result[addr1] { + require.Equal(t, uint64(i), tx.Tx.Nonce()) + } } -func TestTxStoreConcurrentRemove(t *testing.T) { +func TestTxStoreRemoveTx(t *testing.T) { store := NewTxStore(log.NewNopLogger()) addr1 := common.HexToAddress("0x1") - var numTxs uint64 = 1000 - var nonce uint64 = 0 - - for ; nonce < numTxs; nonce++ { - store.AddTx(addr1, createTestTx(nonce, big.NewInt(1e9), big.NewInt(2e9))) - } + tx1 := createTestTx(0, big.NewInt(1e9), big.NewInt(2e9)) + tx2 := createTestTx(1, big.NewInt(1e9), big.NewInt(2e9)) - // concurrently remove even-nonce txs - var wg sync.WaitGroup - for nonce = 0; nonce < numTxs; nonce += 2 { - wg.Add(1) - go func(nonce uint64) { - defer wg.Done() - store.RemoveTx(addr1, createTestTx(nonce, big.NewInt(1e9), big.NewInt(2e9))) - }(nonce) - } - wg.Wait() + store.AddTx(addr1, tx1) + store.AddTx(addr1, tx2) + store.RemoveTx(addr1, tx1) result := store.Txs(txpool.PendingFilter{}) - require.Len(t, result[addr1], 500) + require.Len(t, result[addr1], 0) } func TestTxStoreBlobTxsFiltered(t *testing.T) {