From 4c0e03dd3bec01b51f8a51930b0de4f8c4ea8ee5 Mon Sep 17 00:00:00 2001 From: Calin Martinconi Date: Tue, 3 Feb 2026 10:16:15 +0200 Subject: [PATCH 1/2] feat: improve performance on syncing batches from the snapshot file --- pkg/node/snapshot.go | 4 +++ pkg/postage/batchservice/batchservice.go | 42 ++++++++++++++++++------ pkg/postage/listener/listener.go | 14 ++++++-- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/pkg/node/snapshot.go b/pkg/node/snapshot.go index aac0a6eb140..ef190c0e26c 100644 --- a/pkg/node/snapshot.go +++ b/pkg/node/snapshot.go @@ -50,6 +50,10 @@ func NewSnapshotLogFilterer(logger log.Logger, getter SnapshotGetter) *SnapshotL } } +func (f *SnapshotLogFilterer) GetBatchSnapshot() []byte { + return f.getter.GetBatchSnapshot() +} + // loadSnapshot is responsible for loading and processing the snapshot data. // It is intended to be called exactly once by initOnce.Do. func (f *SnapshotLogFilterer) loadSnapshot() error { diff --git a/pkg/postage/batchservice/batchservice.go b/pkg/postage/batchservice/batchservice.go index 6ad76b0bed4..aac224c1a89 100644 --- a/pkg/postage/batchservice/batchservice.go +++ b/pkg/postage/batchservice/batchservice.go @@ -40,6 +40,8 @@ type batchService struct { checksum hash.Hash // checksum hasher resync bool + + pendingChainState *postage.ChainState } type Interface interface { @@ -95,15 +97,22 @@ func New( } } - return &batchService{stateStore, storer, logger.WithName(loggerName).Register(), listener, owner, batchListener, sum, resync}, nil + return &batchService{stateStore: stateStore, storer: storer, logger: logger.WithName(loggerName).Register(), listener: listener, owner: owner, batchListener: batchListener, checksum: sum, resync: resync}, nil +} + +func (svc *batchService) getChainState() *postage.ChainState { + if svc.pendingChainState != nil { + return svc.pendingChainState + } + return svc.storer.GetChainState() } // Create will create a new batch with the given ID, owner value and depth and // stores it in the BatchedStore. func (svc *batchService) Create(id, owner []byte, totalAmout, normalisedBalance *big.Int, depth, bucketDepth uint8, immutable bool, txHash common.Hash) error { - // don't add batches which have value which equals total cumulative + // dont add batches which have value which equals total cumulative // payout or that are going to expire already within the next couple of blocks - val := big.NewInt(0).Add(svc.storer.GetChainState().TotalAmount, svc.storer.GetChainState().CurrentPrice) + val := big.NewInt(0).Add(svc.getChainState().TotalAmount, svc.getChainState().CurrentPrice) if normalisedBalance.Cmp(val) <= 0 { // don't do anything return fmt.Errorf("batch service: batch %x: %w", id, ErrZeroValueBatch) @@ -112,7 +121,7 @@ func (svc *batchService) Create(id, owner []byte, totalAmout, normalisedBalance ID: id, Owner: owner, Value: normalisedBalance, - Start: svc.storer.GetChainState().Block, + Start: svc.getChainState().Block, Depth: depth, BucketDepth: bucketDepth, Immutable: immutable, @@ -196,10 +205,13 @@ func (svc *batchService) UpdateDepth(id []byte, depth uint8, normalisedBalance * // UpdatePrice implements the EventUpdater interface. It sets the current // price from the chain in the service chain state. func (svc *batchService) UpdatePrice(price *big.Int, txHash common.Hash) error { - cs := svc.storer.GetChainState() + cs := svc.getChainState() cs.CurrentPrice = price - if err := svc.storer.PutChainState(cs); err != nil { - return fmt.Errorf("put chain state: %w", err) + + if svc.pendingChainState == nil { + if err := svc.storer.PutChainState(cs); err != nil { + return fmt.Errorf("put chain state: %w", err) + } } sum, err := svc.updateChecksum(txHash) @@ -212,7 +224,7 @@ func (svc *batchService) UpdatePrice(price *big.Int, txHash common.Hash) error { } func (svc *batchService) UpdateBlockNumber(blockNumber uint64) error { - cs := svc.storer.GetChainState() + cs := svc.getChainState() if blockNumber == cs.Block { return nil } @@ -223,17 +235,27 @@ func (svc *batchService) UpdateBlockNumber(blockNumber uint64) error { cs.TotalAmount.Add(cs.TotalAmount, diff.Mul(diff, cs.CurrentPrice)) cs.Block = blockNumber - if err := svc.storer.PutChainState(cs); err != nil { - return fmt.Errorf("put chain state: %w", err) + + if svc.pendingChainState == nil { + if err := svc.storer.PutChainState(cs); err != nil { + return fmt.Errorf("put chain state: %w", err) + } } svc.logger.Debug("block height updated", "new_block", blockNumber) return nil } func (svc *batchService) TransactionStart() error { + svc.pendingChainState = svc.storer.GetChainState() return svc.stateStore.Put(dirtyDBKey, true) } func (svc *batchService) TransactionEnd() error { + if svc.pendingChainState != nil { + if err := svc.storer.PutChainState(svc.pendingChainState); err != nil { + return fmt.Errorf("put chain state: %w", err) + } + svc.pendingChainState = nil + } return svc.stateStore.Delete(dirtyDBKey) } diff --git a/pkg/postage/listener/listener.go b/pkg/postage/listener/listener.go index d9d52b2c5a1..e13314601f2 100644 --- a/pkg/postage/listener/listener.go +++ b/pkg/postage/listener/listener.go @@ -29,6 +29,7 @@ const loggerName = "listener" const ( blockPage = 5000 // how many blocks to sync every time we page + blockPageSnapshot = 50000 // how many blocks to sync every time from snapshot tailSize = 4 // how many blocks to tail from the tip of the chain defaultBatchFactor = uint64(5) // minimal number of blocks to sync at once ) @@ -241,6 +242,15 @@ func (l *listener) Listen(ctx context.Context, from uint64, updater postage.Even l.logger.Debug("batch factor", "value", batchFactor) + // Type assertion to detect if backend is SnapshotLogFilterer + pageSize := uint64(blockPage) + if _, isSnapshot := l.ev.(interface{ GetBatchSnapshot() []byte }); isSnapshot { + pageSize = blockPageSnapshot + l.logger.Debug("using snapshot page size", "page_size", pageSize) + } else { + l.logger.Debug("using standard page size", "page_size", pageSize) + } + synced := make(chan error) closeOnce := new(sync.Once) paged := true @@ -321,9 +331,9 @@ func (l *listener) Listen(ctx context.Context, from uint64, updater postage.Even } // do some paging (sub-optimal) - if to-from >= blockPage { + if to-from >= pageSize { paged = true - to = from + blockPage - 1 + to = from + pageSize - 1 } else { closeOnce.Do(func() { synced <- nil }) } From d5f7a184ba75261435ec09c2922a31919d75156c Mon Sep 17 00:00:00 2001 From: Calin Martinconi Date: Mon, 9 Feb 2026 18:40:49 +0200 Subject: [PATCH 2/2] test: add real snapshot in tests --- pkg/node/snapshot_test.go | 82 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/pkg/node/snapshot_test.go b/pkg/node/snapshot_test.go index 81edbcaad5f..eab031067fd 100644 --- a/pkg/node/snapshot_test.go +++ b/pkg/node/snapshot_test.go @@ -20,6 +20,8 @@ import ( "github.com/ethersphere/bee/v2/pkg/postage/listener" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + archive "github.com/ethersphere/batch-archive" ) type mockSnapshotGetter struct { @@ -33,6 +35,12 @@ func (m mockSnapshotGetter) GetBatchSnapshot() []byte { return m.data } +type realSnapshotGetter struct{} + +func (r realSnapshotGetter) GetBatchSnapshot() []byte { + return archive.GetBatchSnapshot() +} + func makeSnapshotData(logs []types.Log) []byte { var buf bytes.Buffer gz := gzip.NewWriter(&buf) @@ -149,3 +157,77 @@ func TestNewSnapshotLogFilterer(t *testing.T) { assert.Equal(t, 0, res[3].Topics[0].Cmp(common.HexToHash("0xa4"))) }) } + +func TestSnapshotLogFilterer_RealSnapshot(t *testing.T) { + t.Parallel() + + getter := realSnapshotGetter{} + filterer := node.NewSnapshotLogFilterer(log.Noop, getter) + + t.Run("block number", func(t *testing.T) { + blockNumber, err := filterer.BlockNumber(context.Background()) + assert.NoError(t, err) + assert.Greater(t, blockNumber, uint64(0)) + }) + + t.Run("filter range", func(t *testing.T) { + // arbitrary range that should exist in the snapshot + from := big.NewInt(20000000) + to := big.NewInt(20001000) + res, err := filterer.FilterLogs(context.Background(), ethereum.FilterQuery{ + FromBlock: from, + ToBlock: to, + }) + require.NoError(t, err) + for _, l := range res { + assert.GreaterOrEqual(t, l.BlockNumber, from.Uint64()) + assert.LessOrEqual(t, l.BlockNumber, to.Uint64()) + } + }) + + t.Run("filter address mismatch", func(t *testing.T) { + // random address that should not match the postage stamp contract + addr := common.HexToAddress("0x1234567890123456789012345678901234567890") + res, err := filterer.FilterLogs(context.Background(), ethereum.FilterQuery{ + Addresses: []common.Address{addr}, + }) + require.NoError(t, err) + assert.Empty(t, res) + }) +} + +func BenchmarkNewSnapshotLogFilterer_Load(b *testing.B) { + getter := realSnapshotGetter{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + filterer := node.NewSnapshotLogFilterer(log.Noop, getter) + _, err := filterer.BlockNumber(context.Background()) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSnapshotLogFilterer(b *testing.B) { + getter := realSnapshotGetter{} + filterer := node.NewSnapshotLogFilterer(log.Noop, getter) + // ensure loaded + if _, err := filterer.BlockNumber(context.Background()); err != nil { + b.Fatal(err) + } + + b.Run("FilterLogs", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + from := big.NewInt(20000000) + to := big.NewInt(20001000) + _, err := filterer.FilterLogs(context.Background(), ethereum.FilterQuery{ + FromBlock: from, + ToBlock: to, + }) + if err != nil { + b.Fatal(err) + } + } + }) +}