From dff289059d02cb95b1480eff8d26c706c04e227c Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Tue, 28 Feb 2023 21:20:18 +0800 Subject: [PATCH 01/51] hard fork: add state expiry hard forks; types: add state epoch; --- core/types/state_epoch.go | 31 +++++++++++ core/types/state_epoch_test.go | 98 ++++++++++++++++++++++++++++++++++ params/config.go | 60 +++++++++++++++++++-- 3 files changed, 185 insertions(+), 4 deletions(-) create mode 100644 core/types/state_epoch.go create mode 100644 core/types/state_epoch_test.go diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go new file mode 100644 index 0000000000..df19b4f300 --- /dev/null +++ b/core/types/state_epoch.go @@ -0,0 +1,31 @@ +package types + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" + "math/big" +) + +var ( + // EpochPeriod indicates the state rotate epoch block length + EpochPeriod = big.NewInt(7_008_000) +) + +// GetCurrentEpoch computes the current state epoch by hard fork and block number +// state epoch will indicate if the state is accessible or expiry. +// Before ClaudeBlock indicates state epoch0. +// ClaudeBlock indicates start state epoch1. +// ElwoodBlock indicates start state epoch2 and start epoch rotate by EpochPeriod. +// When N>=2 and epochN started, epoch(N-2)'s state will expire. 
+func GetCurrentEpoch(config *params.ChainConfig, blockNumber *big.Int) *big.Int { + if config.IsElwood(blockNumber) { + ret := new(big.Int).Sub(blockNumber, config.ElwoodBlock) + ret.Div(ret, EpochPeriod) + ret.Add(ret, common.Big2) + return ret + } else if config.IsClaude(blockNumber) { + return common.Big1 + } else { + return common.Big0 + } +} diff --git a/core/types/state_epoch_test.go b/core/types/state_epoch_test.go new file mode 100644 index 0000000000..c5d988b0b6 --- /dev/null +++ b/core/types/state_epoch_test.go @@ -0,0 +1,98 @@ +package types + +import ( + "github.com/ethereum/go-ethereum/params" + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestStateForkConfig(t *testing.T) { + temp := ¶ms.ChainConfig{} + assert.NoError(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(1), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ElwoodBlock: big.NewInt(1), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(0), + ElwoodBlock: big.NewInt(0), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10000), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(2), + ElwoodBlock: big.NewInt(1), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(0), + ElwoodBlock: big.NewInt(1), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10001), + } + assert.NoError(t, temp.CheckConfigForkOrder()) +} + +func TestSimpleStateEpoch(t *testing.T) { + temp := ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(20000), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, 
big.NewInt(0))) + assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(1000))) + assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(10000))) + assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(19999))) + assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(20000))) + assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(20000), EpochPeriod))) + assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) +} + +func TestNoZeroStateEpoch(t *testing.T) { + temp := ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(1), + ElwoodBlock: big.NewInt(2), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(0))) + assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(1))) + assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(2))) + assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(10000))) + assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(2), EpochPeriod))) + assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) +} + +func TestNearestStateEpoch(t *testing.T) { + temp := ¶ms.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10001), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(0))) + assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(10000))) + assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(10001))) + assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(10001), EpochPeriod))) + assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(10001), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) +} diff --git a/params/config.go 
b/params/config.go index f4c485f227..d10e63a6af 100644 --- a/params/config.go +++ b/params/config.go @@ -115,6 +115,8 @@ var ( MoranBlock: big.NewInt(22107423), GibbsBlock: big.NewInt(23846001), PlanckBlock: big.NewInt(27281024), + //ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on mainNet + //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ Period: 3, @@ -142,6 +144,9 @@ var ( NanoBlock: big.NewInt(23482428), MoranBlock: big.NewInt(23603940), PlanckBlock: big.NewInt(28196022), + //ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on testnet + //ElwoodBlock: big.NewInt(-), + Parlia: &ParliaConfig{ Period: 3, Epoch: 200, @@ -168,6 +173,8 @@ var ( NanoBlock: nil, MoranBlock: nil, PlanckBlock: nil, + //ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on QA net + //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ Period: 3, @@ -180,16 +187,16 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. 
- AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil, nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil, nil} // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Clique consensus. // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. 
- AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), nil, nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), nil, nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil, nil} + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, new(EthashConfig), nil, nil} TestRules = TestChainConfig.Rules(new(big.Int), false) ) @@ -286,6 +293,8 @@ type ChainConfig struct { NanoBlock *big.Int `json:"nanoBlock,omitempty" toml:",omitempty"` // nanoBlock switch block (nil = no fork, 0 = already activated) 
MoranBlock *big.Int `json:"moranBlock,omitempty" toml:",omitempty"` // moranBlock switch block (nil = no fork, 0 = already activated) PlanckBlock *big.Int `json:"planckBlock,omitempty" toml:",omitempty"` // planckBlock switch block (nil = no fork, 0 = already activated) + ClaudeBlock *big.Int `json:"claudeBlock,omitempty" toml:",omitempty"` // claudeBlock switch block (nil = no fork, 0 = already activated) + ElwoodBlock *big.Int `json:"elwoodBlock,omitempty" toml:",omitempty"` // elwoodBlock switch block (nil = no fork, 0 = already activated) // Various consensus engines Ethash *EthashConfig `json:"ethash,omitempty" toml:",omitempty"` @@ -336,7 +345,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, Ramanujan: %v, Niels: %v, MirrorSync: %v, Bruno: %v, Berlin: %v, YOLO v3: %v, CatalystBlock: %v, London: %v, ArrowGlacier: %v, MergeFork:%v, Euler: %v, Gibbs: %v, Nano: %v, Moran: %v, Planck: %v, Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, Ramanujan: %v, Niels: %v, MirrorSync: %v, Bruno: %v, Berlin: %v, YOLO v3: %v, CatalystBlock: %v, London: %v, ArrowGlacier: %v, MergeFork:%v, Euler: %v, Gibbs: %v, Nano: %v, Moran: %v, Planck: %v, Claude: %v, Elwood: %v, Engine: %v}", c.ChainID, c.HomesteadBlock, c.DAOForkBlock, @@ -364,6 +373,8 @@ func (c *ChainConfig) String() string { c.NanoBlock, c.MoranBlock, c.PlanckBlock, + c.ClaudeBlock, + c.ElwoodBlock, engine, ) } @@ -527,6 +538,22 @@ func (c *ChainConfig) IsOnPlanck(num *big.Int) bool { return configNumEqual(c.PlanckBlock, num) } +func (c *ChainConfig) IsClaude(num *big.Int) bool { + return isForked(c.ClaudeBlock, num) +} + +func (c *ChainConfig) IsOnClaude(num 
*big.Int) bool { + return configNumEqual(c.ClaudeBlock, num) +} + +func (c *ChainConfig) IsElwood(num *big.Int) bool { + return isForked(c.ElwoodBlock, num) +} + +func (c *ChainConfig) IsOnElwood(num *big.Int) bool { + return configNumEqual(c.ElwoodBlock, num) +} + // CheckCompatible checks whether scheduled fork transitions have been imported // with a mismatching chain configuration. func (c *ChainConfig) CheckCompatible(newcfg *ChainConfig, height uint64) *ConfigCompatError { @@ -578,6 +605,25 @@ func (c *ChainConfig) CheckConfigForkOrder() error { lastFork = cur } } + + // check state expiry's hard forks + if c.ClaudeBlock != nil || c.ElwoodBlock != nil { + if c.ClaudeBlock == nil { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v ElwoodBlock %v", + c.ClaudeBlock, c.ElwoodBlock) + } + + if c.ClaudeBlock.Cmp(common.Big0) <= 0 { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v", + c.ClaudeBlock) + } + + if c.ElwoodBlock != nil && c.ClaudeBlock.Cmp(c.ElwoodBlock) >= 0 { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v ElwoodBlock %v", + c.ClaudeBlock, c.ElwoodBlock) + } + } + return nil } @@ -658,6 +704,12 @@ func (c *ChainConfig) checkCompatible(newcfg *ChainConfig, head *big.Int) *Confi if isForkIncompatible(c.PlanckBlock, newcfg.PlanckBlock, head) { return newCompatError("planck fork block", c.PlanckBlock, newcfg.PlanckBlock) } + if isForkIncompatible(c.ClaudeBlock, newcfg.ClaudeBlock, head) { + return newCompatError("claude fork block", c.ClaudeBlock, newcfg.ClaudeBlock) + } + if isForkIncompatible(c.ElwoodBlock, newcfg.ElwoodBlock, head) { + return newCompatError("elwood fork block", c.ElwoodBlock, newcfg.ElwoodBlock) + } return nil } From ad06cefa8805742ac8675fb9f6cd287269af2f90 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Wed, 29 Mar 2023 21:40:48 +0800 Subject: [PATCH 02/51] core/transaction: add revive state transaction; message 
support witnessList; --- accounts/abi/bind/backends/simulated.go | 23 ++--- core/state_transition.go | 1 + core/types/access_list_tx.go | 4 + core/types/dynamic_fee_tx.go | 4 + core/types/legacy_tx.go | 4 + core/types/revive_state_tx.go | 111 ++++++++++++++++++++++++ core/types/revive_witness.go | 14 +++ core/types/transaction.go | 107 +++++++++++++---------- core/types/transaction_marshalling.go | 62 +++++++++++++ core/types/transaction_signing.go | 83 ++++++++++++++++++ core/types/transaction_test.go | 99 +++++++++++++++++++++ interfaces.go | 3 +- internal/ethapi/transaction_args.go | 9 +- les/odr_test.go | 4 +- light/odr_test.go | 2 +- params/config.go | 2 +- tests/state_test_util.go | 2 +- 17 files changed, 470 insertions(+), 64 deletions(-) create mode 100644 core/types/revive_state_tx.go create mode 100644 core/types/revive_witness.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index f8ceec8838..434fac8d6a 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -803,17 +803,18 @@ type callMsg struct { ethereum.CallMsg } -func (m callMsg) From() common.Address { return m.CallMsg.From } -func (m callMsg) Nonce() uint64 { return 0 } -func (m callMsg) IsFake() bool { return true } -func (m callMsg) To() *common.Address { return m.CallMsg.To } -func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callMsg) GasFeeCap() *big.Int { return m.CallMsg.GasFeeCap } -func (m callMsg) GasTipCap() *big.Int { return m.CallMsg.GasTipCap } -func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callMsg) Value() *big.Int { return m.CallMsg.Value } -func (m callMsg) Data() []byte { return m.CallMsg.Data } -func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) IsFake() bool { return true } +func (m 
callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) GasFeeCap() *big.Int { return m.CallMsg.GasFeeCap } +func (m callMsg) GasTipCap() *big.Int { return m.CallMsg.GasTipCap } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } +func (m callMsg) WitnessList() types.WitnessList { return m.CallMsg.WitnessList } // filterBackend implements filters.Backend to support filtering for logs without // taking bloom-bits acceleration structures into account. diff --git a/core/state_transition.go b/core/state_transition.go index 8083a4ea61..1d8dfcf7de 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -80,6 +80,7 @@ type Message interface { IsFake() bool Data() []byte AccessList() types.AccessList + WitnessList() types.WitnessList } // ExecutionResult includes all output after executing given evm diff --git a/core/types/access_list_tx.go b/core/types/access_list_tx.go index 8ad5e739e9..b2c4da78f3 100644 --- a/core/types/access_list_tx.go +++ b/core/types/access_list_tx.go @@ -106,6 +106,10 @@ func (tx *AccessListTx) value() *big.Int { return tx.Value } func (tx *AccessListTx) nonce() uint64 { return tx.Nonce } func (tx *AccessListTx) to() *common.Address { return tx.To } +func (tx *AccessListTx) witnessList() WitnessList { + return nil +} + func (tx *AccessListTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git a/core/types/dynamic_fee_tx.go b/core/types/dynamic_fee_tx.go index 53f246ea1f..54de7dcfed 100644 --- a/core/types/dynamic_fee_tx.go +++ b/core/types/dynamic_fee_tx.go @@ -94,6 +94,10 @@ func (tx *DynamicFeeTx) value() *big.Int { return tx.Value } func (tx *DynamicFeeTx) nonce() uint64 { return tx.Nonce } func (tx *DynamicFeeTx) to() 
*common.Address { return tx.To } +func (tx *DynamicFeeTx) witnessList() WitnessList { + return nil +} + func (tx *DynamicFeeTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git a/core/types/legacy_tx.go b/core/types/legacy_tx.go index cb86bed772..f0857acc8c 100644 --- a/core/types/legacy_tx.go +++ b/core/types/legacy_tx.go @@ -103,6 +103,10 @@ func (tx *LegacyTx) value() *big.Int { return tx.Value } func (tx *LegacyTx) nonce() uint64 { return tx.Nonce } func (tx *LegacyTx) to() *common.Address { return tx.To } +func (tx *LegacyTx) witnessList() WitnessList { + return nil +} + func (tx *LegacyTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git a/core/types/revive_state_tx.go b/core/types/revive_state_tx.go new file mode 100644 index 0000000000..43646a465a --- /dev/null +++ b/core/types/revive_state_tx.go @@ -0,0 +1,111 @@ +package types + +import ( + "github.com/ethereum/go-ethereum/common" + "math/big" +) + +type WitnessList []ReviveWitness + +// ReviveStateTx is the transaction for revive state. +type ReviveStateTx struct { + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + WitnessList WitnessList // revive witness + + V, R, S *big.Int // signature values +} + +func (tx *ReviveStateTx) txType() byte { + return ReviveStateTxType +} + +func (tx *ReviveStateTx) copy() TxData { + cpy := &ReviveStateTx{ + Nonce: tx.Nonce, + To: copyAddressPtr(tx.To), + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. 
+ Value: new(big.Int), + GasPrice: new(big.Int), + WitnessList: make(WitnessList, len(tx.WitnessList)), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + + copy(cpy.WitnessList, tx.WitnessList) + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +func (tx *ReviveStateTx) chainID() *big.Int { + return deriveChainId(tx.V) +} + +func (tx *ReviveStateTx) accessList() AccessList { + return nil +} + +func (tx *ReviveStateTx) witnessList() WitnessList { + return tx.WitnessList +} + +func (tx *ReviveStateTx) data() []byte { + return tx.Data +} + +func (tx *ReviveStateTx) gas() uint64 { + return tx.Gas +} + +func (tx *ReviveStateTx) gasPrice() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) gasTipCap() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) gasFeeCap() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) value() *big.Int { + return tx.Value +} + +func (tx *ReviveStateTx) nonce() uint64 { + return tx.Nonce +} + +func (tx *ReviveStateTx) to() *common.Address { + return tx.To +} + +func (tx *ReviveStateTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *ReviveStateTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go new file mode 100644 index 0000000000..1900c71fbf --- /dev/null +++ b/core/types/revive_witness.go @@ -0,0 +1,14 @@ +package types + +import "github.com/ethereum/go-ethereum/common" + +type MPTProof struct { + key []byte // prefix key + proof [][]byte // list of RLP-encoded nodes +} + +type ReviveWitness struct { + witnessType byte // only support Merkle Proof for now + address *common.Address // target account address + proofList []MPTProof // revive multiple slots 
(same address) +} diff --git a/core/types/transaction.go b/core/types/transaction.go index 95a9da87af..4984eb709a 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -45,6 +45,8 @@ const ( LegacyTxType = iota AccessListTxType DynamicFeeTxType + + ReviveStateTxType = 127 // reserve for BSC revive state ) // Transaction is an Ethereum transaction. @@ -74,6 +76,7 @@ type TxData interface { chainID() *big.Int accessList() AccessList + witnessList() WitnessList data() []byte gas() uint64 gasPrice() *big.Int @@ -191,6 +194,10 @@ func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { var inner DynamicFeeTx err := rlp.DecodeBytes(b[1:], &inner) return &inner, err + case ReviveStateTxType: + var inner ReviveStateTx + err := rlp.DecodeBytes(b[1:], &inner) + return &inner, err default: return nil, ErrTxTypeNotSupported } @@ -268,6 +275,10 @@ func (tx *Transaction) Data() []byte { return tx.inner.data() } // AccessList returns the access list of the transaction. func (tx *Transaction) AccessList() AccessList { return tx.inner.accessList() } +func (tx *Transaction) WitnessList() WitnessList { + return tx.inner.witnessList() +} + // Gas returns the gas limit of the transaction. func (tx *Transaction) Gas() uint64 { return tx.inner.gas() } @@ -655,48 +666,51 @@ func (t *TransactionsByPriceAndNonce) Forward(tx *Transaction) { // // NOTE: In a future PR this will be removed. 
type Message struct { - to *common.Address - from common.Address - nonce uint64 - amount *big.Int - gasLimit uint64 - gasPrice *big.Int - gasFeeCap *big.Int - gasTipCap *big.Int - data []byte - accessList AccessList - isFake bool -} - -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice, gasFeeCap, gasTipCap *big.Int, data []byte, accessList AccessList, isFake bool) Message { + to *common.Address + from common.Address + nonce uint64 + amount *big.Int + gasLimit uint64 + gasPrice *big.Int + gasFeeCap *big.Int + gasTipCap *big.Int + data []byte + accessList AccessList + witnessList WitnessList + isFake bool +} + +func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice, gasFeeCap, gasTipCap *big.Int, data []byte, accessList AccessList, witnessList WitnessList, isFake bool) Message { return Message{ - from: from, - to: to, - nonce: nonce, - amount: amount, - gasLimit: gasLimit, - gasPrice: gasPrice, - gasFeeCap: gasFeeCap, - gasTipCap: gasTipCap, - data: data, - accessList: accessList, - isFake: isFake, + from: from, + to: to, + nonce: nonce, + amount: amount, + gasLimit: gasLimit, + gasPrice: gasPrice, + gasFeeCap: gasFeeCap, + gasTipCap: gasTipCap, + data: data, + accessList: accessList, + witnessList: witnessList, + isFake: isFake, } } // AsMessage returns the transaction as a core.Message. 
func (tx *Transaction) AsMessage(s Signer, baseFee *big.Int) (Message, error) { msg := Message{ - nonce: tx.Nonce(), - gasLimit: tx.Gas(), - gasPrice: new(big.Int).Set(tx.GasPrice()), - gasFeeCap: new(big.Int).Set(tx.GasFeeCap()), - gasTipCap: new(big.Int).Set(tx.GasTipCap()), - to: tx.To(), - amount: tx.Value(), - data: tx.Data(), - accessList: tx.AccessList(), - isFake: false, + nonce: tx.Nonce(), + gasLimit: tx.Gas(), + gasPrice: new(big.Int).Set(tx.GasPrice()), + gasFeeCap: new(big.Int).Set(tx.GasFeeCap()), + gasTipCap: new(big.Int).Set(tx.GasTipCap()), + to: tx.To(), + amount: tx.Value(), + data: tx.Data(), + accessList: tx.AccessList(), + witnessList: tx.WitnessList(), + isFake: false, } // If baseFee provided, set gasPrice to effectiveGasPrice. if baseFee != nil { @@ -716,17 +730,18 @@ func (tx *Transaction) AsMessageNoNonceCheck(s Signer) (Message, error) { return msg, err } -func (m Message) From() common.Address { return m.from } -func (m Message) To() *common.Address { return m.to } -func (m Message) GasPrice() *big.Int { return m.gasPrice } -func (m Message) GasFeeCap() *big.Int { return m.gasFeeCap } -func (m Message) GasTipCap() *big.Int { return m.gasTipCap } -func (m Message) Value() *big.Int { return m.amount } -func (m Message) Gas() uint64 { return m.gasLimit } -func (m Message) Nonce() uint64 { return m.nonce } -func (m Message) Data() []byte { return m.data } -func (m Message) AccessList() AccessList { return m.accessList } -func (m Message) IsFake() bool { return m.isFake } +func (m Message) From() common.Address { return m.from } +func (m Message) To() *common.Address { return m.to } +func (m Message) GasPrice() *big.Int { return m.gasPrice } +func (m Message) GasFeeCap() *big.Int { return m.gasFeeCap } +func (m Message) GasTipCap() *big.Int { return m.gasTipCap } +func (m Message) Value() *big.Int { return m.amount } +func (m Message) Gas() uint64 { return m.gasLimit } +func (m Message) Nonce() uint64 { return m.nonce } +func (m Message) 
Data() []byte { return m.data } +func (m Message) AccessList() AccessList { return m.accessList } +func (m Message) WitnessList() WitnessList { return m.witnessList } +func (m Message) IsFake() bool { return m.isFake } // copyAddressPtr copies an address. func copyAddressPtr(a *common.Address) *common.Address { diff --git a/core/types/transaction_marshalling.go b/core/types/transaction_marshalling.go index aad31a5a97..346fbb10f3 100644 --- a/core/types/transaction_marshalling.go +++ b/core/types/transaction_marshalling.go @@ -46,6 +46,9 @@ type txJSON struct { ChainID *hexutil.Big `json:"chainId,omitempty"` AccessList *AccessList `json:"accessList,omitempty"` + // Revive State transaction fields: + WitnessList *WitnessList `json:"witnessList,omitempty"` + // Only used for encoding: Hash common.Hash `json:"hash"` } @@ -94,6 +97,17 @@ func (t *Transaction) MarshalJSON() ([]byte, error) { enc.V = (*hexutil.Big)(tx.V) enc.R = (*hexutil.Big)(tx.R) enc.S = (*hexutil.Big)(tx.S) + case *ReviveStateTx: + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.WitnessList = &tx.WitnessList + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) } return json.Marshal(&enc) } @@ -263,6 +277,54 @@ func (t *Transaction) UnmarshalJSON(input []byte) error { } } + case ReviveStateTxType: + var itx ReviveStateTx + inner = &itx + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == 
nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.WitnessList == nil { + return errors.New("missing required field 'WitnessList' in transaction") + } + itx.WitnessList = *dec.WitnessList + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, true); err != nil { + return err + } + } default: return ErrTxTypeNotSupported } diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 1d0d2a4c75..4499c7438f 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -329,6 +329,89 @@ func (s eip2930Signer) Hash(tx *Transaction) common.Hash { } } +type BEP215Signer struct { + EIP155Signer +} + +func NewBEP215Signer(chainId *big.Int) BEP215Signer { + return BEP215Signer{NewEIP155Signer(chainId)} +} +func (s BEP215Signer) ChainID() *big.Int { + return s.chainId +} + +func (s BEP215Signer) Equal(s2 Signer) bool { + bep215, ok := s2.(BEP215Signer) + return ok && bep215.chainId.Cmp(s.chainId) == 0 +} + +func (s BEP215Signer) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType && tx.Type() != ReviveStateTxType { + return common.Address{}, ErrTxTypeNotSupported + } + + if !tx.Protected() { + return HomesteadSigner{}.Sender(tx) + } + if tx.ChainId().Cmp(s.chainId) != 0 { + return common.Address{}, ErrInvalidChainId + } + V, R, S := tx.RawSignatureValues() + V = 
new(big.Int).Sub(V, s.chainIdMul) + V.Sub(V, big8) + return recoverPlain(s.Hash(tx), R, S, V, true) +} + +// SignatureValues returns signature values. This signature +// needs to be in the [R || S || V] format where V is 0 or 1. +func (s BEP215Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { + if tx.Type() != LegacyTxType && tx.Type() != ReviveStateTxType { + return nil, nil, nil, ErrTxTypeNotSupported + } + R, S, V = decodeSignature(sig) + if s.chainId.Sign() != 0 { + V = big.NewInt(int64(sig[64] + 35)) + V.Add(V, s.chainIdMul) + } + return R, S, V, nil +} + +// Hash returns the hash to be signed by the sender. +// It does not uniquely identify the transaction. +func (s BEP215Signer) Hash(tx *Transaction) common.Hash { + switch tx.Type() { + case LegacyTxType: + return rlpHash([]interface{}{ + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + s.chainId, uint(0), uint(0), + }) + case ReviveStateTxType: + return prefixedRlpHash( + tx.Type(), + []interface{}{ + s.chainId, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + tx.WitnessList(), + }) + default: + // This _should_ not happen, but in case someone sends in a bad + // json struct via RPC, it's probably more prudent to return an + // empty hash instead of killing the node with a panic + //panic("Unsupported transaction type: %d", tx.typ) + return common.Hash{} + } +} + // EIP155Signer implements Signer using the EIP-155 rules. This accepts transactions which // are replay-protected as well as unprotected homestead transactions. 
type EIP155Signer struct { diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 3177a04d45..3ad58a024b 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -21,6 +21,7 @@ import ( "crypto/ecdsa" "encoding/json" "fmt" + "github.com/stretchr/testify/assert" "math/big" "math/rand" "reflect" @@ -563,6 +564,104 @@ func TestTransactionCoding(t *testing.T) { } } +// TestTransactionCoding tests serializing/de-serializing to/from rlp and JSON, signer sender & chainId. +func TestReviveStateTxAndSigner(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("could not generate key: %v", err) + } + var ( + signer = NewBEP215Signer(common.Big1) + from = crypto.PubkeyToAddress(key.PublicKey) + addr = common.HexToAddress("0x0000000000000000000000000000000000000001") + recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + witness = WitnessList{{ + witnessType: 0, + address: &addr, + proofList: []MPTProof{{ + key: common.Hex2Bytes("095e7baea6a6c7c4c2"), + proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, + }}, + }} + ) + for i := uint64(0); i < 500; i++ { + var txdata TxData + switch i % 5 { + case 0: + // Legacy tx. + txdata = &LegacyTx{ + Nonce: i, + To: &recipient, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 1: + // Legacy tx contract creation. + txdata = &LegacyTx{ + Nonce: i, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 2: + // Tx with non-zero witness in revive state. + txdata = &ReviveStateTx{ + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + WitnessList: witness, + Data: []byte("abcdef"), + } + case 3: + // Tx with empty revive state. + txdata = &ReviveStateTx{ + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + Data: []byte("abcdef"), + } + case 4: + // Contract creation with revive state. 
+ txdata = &ReviveStateTx{ + Nonce: i, + GasPrice: big.NewInt(10), + Gas: 123457, + Data: []byte("abcdef"), + WitnessList: witness, + } + } + tx, err := SignNewTx(key, signer, txdata) + if err != nil { + t.Fatalf("could not sign transaction: %v", err) + } + // RLP + parsedTx, err := encodeDecodeBinary(tx) + if err != nil { + t.Fatal(err) + } + if err := assertEqual(parsedTx, tx); err != nil { + t.Fatal(err) + } + + // JSON + parsedTx, err = encodeDecodeJSON(tx) + if err != nil { + t.Fatal(err) + } + if err := assertEqual(parsedTx, tx); err != nil { + t.Fatal(err) + } + + assert.Equal(t, common.Big1, tx.ChainId()) + sender, err := signer.Sender(tx) + assert.NoError(t, err) + assert.Equal(t, from, sender) + } +} + func encodeDecodeJSON(tx *Transaction) (*Transaction, error) { data, err := json.Marshal(tx) if err != nil { diff --git a/interfaces.go b/interfaces.go index 76c1ef6908..9d69dc98b8 100644 --- a/interfaces.go +++ b/interfaces.go @@ -141,7 +141,8 @@ type CallMsg struct { Value *big.Int // amount of wei sent along with the call Data []byte // input data, usually an ABI-encoded contract method invocation - AccessList types.AccessList // EIP-2930 access list. + AccessList types.AccessList // EIP-2930 access list. + WitnessList types.WitnessList // EIP-2930 access list. } // A ContractCaller provides contract calls, essentially transactions that are executed by diff --git a/internal/ethapi/transaction_args.go b/internal/ethapi/transaction_args.go index 9284439162..b6de1a599c 100644 --- a/internal/ethapi/transaction_args.go +++ b/internal/ethapi/transaction_args.go @@ -52,6 +52,9 @@ type TransactionArgs struct { // Introduced by AccessListTxType transaction. AccessList *types.AccessList `json:"accessList,omitempty"` ChainID *hexutil.Big `json:"chainId,omitempty"` + + // Introduced by ReviveStateTxType transaction. + WitnessList *types.WitnessList `json:"witnessList,omitempty"` } // from retrieves the transaction sender address. 
@@ -245,7 +248,11 @@ func (args *TransactionArgs) ToMessage(globalGasCap uint64, baseFee *big.Int) (t if args.AccessList != nil { accessList = *args.AccessList } - msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, gasFeeCap, gasTipCap, data, accessList, true) + var witnessList types.WitnessList + if args.WitnessList != nil { + witnessList = *args.WitnessList + } + msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, gasFeeCap, gasTipCap, data, accessList, witnessList, true) return msg, nil } diff --git a/les/odr_test.go b/les/odr_test.go index ad77abf5b9..b77bb30d9b 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -135,7 +135,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai from := statedb.GetOrNewStateObject(bankAddr) from.SetBalance(math.MaxBig256) - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} context := core.NewEVMBlockContext(header, bc, nil) txContext := core.NewEVMTxContext(msg) @@ -150,7 +150,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai header := lc.GetHeaderByHash(bhash) state := light.NewState(ctx, header, lc.Odr()) state.SetBalance(bankAddr, math.MaxBig256) - msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} context := core.NewEVMBlockContext(header, lc, nil) txContext := 
core.NewEVMTxContext(msg) vmenv := vm.NewEVM(context, txContext, state, config, vm.Config{NoBaseFee: true}) diff --git a/light/odr_test.go b/light/odr_test.go index fdf657a82e..1a91662484 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -194,7 +194,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain // Perform read-only call. st.SetBalance(testBankAddress, math.MaxBig256) - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} txContext := core.NewEVMTxContext(msg) context := core.NewEVMBlockContext(header, chain, nil) vmenv := vm.NewEVM(context, txContext, st, config, vm.Config{NoBaseFee: true}) diff --git a/params/config.go b/params/config.go index d10e63a6af..b123bd187f 100644 --- a/params/config.go +++ b/params/config.go @@ -187,7 +187,7 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. 
- AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil, nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil, nil} // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Clique consensus. 
diff --git a/tests/state_test_util.go b/tests/state_test_util.go index e71cd2bf3c..bca38048ab 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -364,7 +364,7 @@ func (tx *stTransaction) toMessage(ps stPostState, baseFee *big.Int) (core.Messa } msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, gasPrice, - tx.MaxFeePerGas, tx.MaxPriorityFeePerGas, data, accessList, false) + tx.MaxFeePerGas, tx.MaxPriorityFeePerGas, data, accessList, nil, false) return msg, nil } From 6ae9ec781348968cfba1ff7d01b3e05560573e32 Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 28 Mar 2023 11:43:49 +0800 Subject: [PATCH 03/51] trie: implement expire subtree by prefix --- trie/trie.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/trie/trie.go b/trie/trie.go index 878f4cd3a3..007408a87f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -492,6 +492,64 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } } +func (t *Trie) ExpireByPrefix(prefixKey []byte) { + _, err := t.expireByPrefix(t.root, prefixKey) + if err != nil { + log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + } +} + +func (t *Trie) expireByPrefix(n node, prefixKey []byte) (node, error) { + // Loop through prefix key + // When prefix key is empty, generate the hash node of the current node + // Replace current node with the hash node + + // If length of prefix key is empty + if len(prefixKey) == 0 { + hasher := newHasher(false) + defer returnHasherToPool(hasher) + var hn node + _, hn = hasher.proofHash(n) + + return hn, nil + } + + switch n := n.(type) { + case *shortNode: + matchLen := prefixLen(prefixKey, n.Key) + if matchLen == len(prefixKey) { + return nil, fmt.Errorf("")// Found the node to expire + } + + hn, err := t.expireByPrefix(n.Val, prefixKey[matchLen:]) + if err != nil { + return nil, err + } + + // Replace child node with hash node + if hn != nil { + n.Val = hn + } + + return nil, err + case 
*fullNode: + hn, err := t.expireByPrefix(n.Children[prefixKey[0]], prefixKey[1:]) + if err != nil { + return nil, err + } + + // Replace child node with hash node + if hn != nil { + n.Children[prefixKey[0]] = hn + } + + return nil, err + default: + return nil, fmt.Errorf("invalid node type") + } +} + + func concat(s1 []byte, s2 ...byte) []byte { r := make([]byte, len(s1)+len(s2)) copy(r, s1) From 7b0530ef4c6a99045a9c8742ebd92e3866b2b5a1 Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 28 Mar 2023 11:45:10 +0800 Subject: [PATCH 04/51] witness: add test to check proof with expired sister node --- trie/proof_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/trie/proof_test.go b/trie/proof_test.go index 29866714c2..9895d56f53 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -892,6 +892,54 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { } } +func TestProofWithExpiredSisterNode(t *testing.T) { + trie := new(Trie) + + expiredData := map[string]string{ + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + + unexpiredData := map[string]string{ + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + // Loop through the data and insert it into the trie + for k, v := range expiredData { + updateString(trie, k, v) + } + + for k, v := range unexpiredData { + updateString(trie, k, v) + } + + prefixKey := keybytesToHex([]byte("abcd"))[:2] + + trie.ExpireByPrefix(prefixKey) + + for i, prover := range makeProvers(trie){ + for k, v := range unexpiredData { + proof := prover([]byte(k)) + if proof == nil { + t.Fatalf("proof %d is nil", i) + } + val, err := VerifyProof(trie.Hash(), []byte("degi"), proof) + if err != nil { + t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) + } + if !bytes.Equal(val, []byte("H")) { + t.Fatalf("prover %d: verified value mismatch: have %x, want %v", i, val, v) + } + } + } +} + // mutateByte changes one byte in b. 
func mutateByte(b []byte) { for r := mrand.Intn(len(b)); ; { From f05e43f3667b628d5b43ec8339e0e176277ff70a Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 28 Mar 2023 13:57:06 +0800 Subject: [PATCH 05/51] witness: decouple traverse nodes --- trie/proof.go | 65 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/trie/proof.go b/trie/proof.go index b413eeaf1a..062f88d756 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -39,32 +39,11 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e key = keybytesToHex(key) var nodes []node tn := t.root - for len(key) > 0 && tn != nil { - switch n := tn.(type) { - case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - // The trie doesn't contain the key. - tn = nil - } else { - tn = n.Val - key = key[len(n.Key):] - } - nodes = append(nodes, n) - case *fullNode: - tn = n.Children[key[0]] - key = key[1:] - nodes = append(nodes, n) - case hashNode: - var err error - tn, err = t.resolveHash(n, nil) - if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - return err - } - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) - } + _, err := t.traverseNodes(tn, key, &nodes) + if err != nil { + return err } + hasher := newHasher(false) defer returnHasherToPool(hasher) @@ -128,6 +107,42 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } +// TODO (asyukii): Write function comment +func (t *Trie) traverseNodes(tn node, key []byte, nodes *[]node) (node, error) { + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. 
+ tn = nil + } else { + tn = n.Val + key = key[len(n.Key):] + } + if nodes != nil{ + *nodes = append(*nodes, n) + } + case *fullNode: + tn = n.Children[key[0]] + key = key[1:] + if nodes != nil{ + *nodes = append(*nodes, n) + } + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return tn, err + } + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + + return tn, nil +} + // proofToPath converts a merkle proof to trie node path. The main purpose of // this function is recovering a node path from the merkle proof stream. All // necessary nodes will be resolved and leave the remaining as hashnode. From b1322b0befe21beaad4a5f71f53164888bf34516 Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 29 Mar 2023 12:36:10 +0800 Subject: [PATCH 06/51] witness: add MPTProof --- trie/proof.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/trie/proof.go b/trie/proof.go index 062f88d756..81aaafa39b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -27,6 +27,11 @@ import ( "github.com/ethereum/go-ethereum/log" ) +type MPTProof struct { + prefixKey []byte // prefix key + proof [][]byte // list of RLP-encoded nodes +} + // Prove constructs a merkle proof for key. The result contains all encoded nodes // on the path to the value at key. The value itself is also included in the last // node and can be retrieved by verifying the proof. 
From 0b9a12be6598962a9b1da5b79715765f58fbe22d Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 29 Mar 2023 17:01:12 +0800 Subject: [PATCH 07/51] witness: generate storage proof --- trie/proof.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/trie/proof.go b/trie/proof.go index 81aaafa39b..332823b86d 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -83,6 +83,53 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri return t.trie.Prove(key, fromLevel, proofDb) } +// ProveStorage constructs a merkle proof for a storage key. The storage key should +// already be converted to nibbles. If the prefix key is specified, the proof will +// start from the node that contains the prefix key to get the partial proof. +// The result contains all encoded nodes from the starting node to the node that contains +// the value. The value itself is also included in the last node and can be retrieved by +// verifying the proof. +func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { + + if len(key) == 0 { + return fmt.Errorf("key is empty") + } + + // traverse down using the prefixKey + var nodes []node + tn := t.root + startNode, err := t.traverseNodes(tn, prefixKey, nil) // obtain the node that contains the prefixKey + if err != nil { + return err + } + + key = key[len(prefixKey):] // obtain the suffix key + + // traverse through the suffix key + _, err = t.traverseNodes(startNode, key, &nodes) + if err != nil{ + return err + } + + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + // construct the proof + for i, n := range nodes { + var hn node + n, hn = hasher.proofHash(n) + if hash, ok := hn.(hashNode); ok || i == 0 { + enc := nodeToBytes(n) + if !ok { + hash = hasher.hashData(enc) + } + proofDb.Put(hash, enc) + } + } + + return nil +} + // VerifyProof checks merkle proofs. 
The given proof must contain the value for // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. From ce66824e7c7b3e5d1b136926a59b96144959e586 Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 29 Mar 2023 17:02:08 +0800 Subject: [PATCH 08/51] witness: verify storage proof --- trie/proof.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/trie/proof.go b/trie/proof.go index 332823b86d..b41659327f 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -159,11 +159,63 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } -// TODO (asyukii): Write function comment +// VerifyStorageProof checks a merkle proof for a storage key. The storage key should +// already be converted to nibbles. If the prefix key is specified, it will traverse +// down to the node that contains the prefix key. From there, proof will be verified. +// VerifyStorageProof returns an error if the proof contains invalid trie nodes. 
+func (t *Trie) VerifyStorageProof(key []byte, prefixKey []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { + + if len(key) == 0 { + return nil, fmt.Errorf("empty key provided") + } + + tn := t.root + startNode, err := t.traverseNodes(tn, prefixKey, nil) + if err != nil { + return nil, err + } + + key = key[len(prefixKey):] // obtain the suffix key + + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + _, hn := hasher.proofHash(startNode) + wantHash, ok := hn.(hashNode) + + if !ok { // node is not hashed + return nil, nil + } + for i := 0; ; i++ { + buf, _ := proofDb.Get(wantHash[:]) + if buf == nil { + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) + } + n, err := decodeNode(wantHash[:], buf) + if err != nil { + return nil, fmt.Errorf("bad proof node %d: %v", i, err) + } + keyrest, cld := get(n, key, true) + switch cld := cld.(type) { + case nil: + return nil, nil + case hashNode: + key = keyrest + copy(wantHash[:], cld) + case valueNode: + return cld, nil + } + } +} + +// traverseNodes traverses the trie with the given key starting at the given node. +// If the trie contains the key, the returned node is the node that contains the +// value for the key. If nodes is specified, the traversed nodes are appended to +// it. func (t *Trie) traverseNodes(tn node, key []byte, nodes *[]node) (node, error) { for len(key) > 0 && tn != nil { switch n := tn.(type) { - case *shortNode: + case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { // The trie doesn't contain the key. 
tn = nil From 2ea256b5398a42b3f68808a8c09bcd2d8a1f6e2a Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 29 Mar 2023 17:02:47 +0800 Subject: [PATCH 09/51] witness: add test cases --- trie/proof_test.go | 213 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 208 insertions(+), 5 deletions(-) diff --git a/trie/proof_test.go b/trie/proof_test.go index 9895d56f53..3bde1b1786 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -892,14 +892,40 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { } } -func TestProofWithExpiredSisterNode(t *testing.T) { +// TestStorageProof tests the storage proof generation and verification. +// This test will also test for partial proof generation and verification. +func TestStorageProof(t *testing.T){ + trie, vals := randomTrie(500) + for _, kv := range vals { + prefixKeys := getPrefixKeys(trie, []byte(kv.k)) + for _, prefixKey := range prefixKeys { + proof := memorydb.New() + key := keybytesToHex([]byte(kv.k)) + err := trie.ProveStorage(key, prefixKey, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", kv.k) + } + val, err := trie.VerifyStorageProof(key, prefixKey, proof) + if err != nil { + t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) + } + if val != nil && !bytes.Equal(val, []byte(kv.v)) { + t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) + } + } + } +} + +// TestStorageProofWithExpiredSubTree tests the storage proof with expired sub tree. +// The prover is expected to give valid proof for the unexpired data. 
+func TestStorageProofWithExpiredSubTree(t *testing.T) { trie := new(Trie) expiredData := map[string]string{ - "defg": "E", - "defh": "F", - "degh": "G", - "degi": "H", + "abcd": "A", + "abce": "B", + "abde": "C", + "abdf": "D", } @@ -940,6 +966,141 @@ func TestProofWithExpiredSisterNode(t *testing.T) { } } +// TestOneElementStorageProof tests the storage proof generation and verification +// for a trie with only one element. +func TestOneElementStorageProof(t *testing.T){ + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := keybytesToHex([]byte("k")) + err := trie.ProveStorage(key, nil, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", key) + } + + if proof.Len() != 1 { + t.Errorf("proof should have one element") + } + + val, err := VerifyProof(trie.Hash(), []byte("k"), proof) + if err != nil { + t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proof) + } + + if !bytes.Equal(val, []byte("v")) { + t.Fatalf("verified value mismatch: have %x, want 'v'", val) + } +} + +// TestEmptyStorageProof tests storage verification with empty proof. +// The verifier should nil for both value and error. +func TestEmptyStorageProof(t *testing.T){ + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := keybytesToHex([]byte("k")) + + val, err := trie.VerifyStorageProof(key, nil, proof) + if val != nil && err != nil{ + t.Fatalf("expected nil value and error for empty proof") + } +} + +// TestEmptyKeyStorageProof tests the storage proof with empty key. +// The prover is expected to return +func TestEmptyKeyStorageProof(t *testing.T){ + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + err := trie.ProveStorage([]byte(""), nil, proof) + if err == nil { + t.Fatalf("expected error for empty key") + } +} + +// TestEmptyPrefixKeyStorageProof tests the storage proof with empty prefix key, +// which means that all proofs generated are from the root node. 
+func TestEmptyPrefixKeyStorageProof(t *testing.T){ + trie, vals := randomTrie(500) + for _, kv := range vals { + proof := memorydb.New() + key := keybytesToHex(kv.k) + err := trie.ProveStorage(key, nil, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", key) + } + val, err := trie.VerifyStorageProof(key, nil, proof) + if err != nil { + t.Fatalf("failed to verify proof for key %x: %v\nraw proof: %x", key, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("verified value mismatch for key %x: have %x, want %x", key, val, kv.v) + } + } +} + +// TestBadStorageProof tests a few cases which the proof is wrong. +// The proof is expected to detect the error. +func TestBadStorageProof(t *testing.T){ + + trie, vals := randomTrie(500) + for _, kv := range vals { + prefixKeys := getPrefixKeys(trie, []byte(kv.k)) + for _, prefixKey := range prefixKeys { + proof := memorydb.New() + key := keybytesToHex([]byte(kv.k)) + err := trie.ProveStorage(key, prefixKey, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", key) + } + + if proof.Len() == 0 { + continue + } + + it := proof.NewIterator(nil, nil) + for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ { + it.Next() + } + itKey := it.Key() + itVal, _ := proof.Get(itKey) + proof.Delete(itKey) + it.Release() + + mutateByte(itVal) + proof.Put(crypto.Keccak256(itVal), itVal) + + if val, err := trie.VerifyStorageProof(key, prefixKey, proof); err == nil && val != nil{ + t.Fatalf("expected proof to fail for key: %x, prefix: %x", key, prefixKey) + } + } + } +} + +// TODO +func TestBadKeyStorageProof(t *testing.T){ + return +} + +// TODO +func TestBadPrefixKeyStorageProof(t *testing.T){ + return +} + +// TODO +func TestKeyPrefixKeySame(t *testing.T){ + return +} + +// TODO: get the proof of unexpired tree, then expire it, then get proof of expired tree. +// compare them +func TestExpiredProof(t *testing.T){ + +} + // mutateByte changes one byte in b. 
func mutateByte(b []byte) { for r := mrand.Intn(len(b)); ; { @@ -1148,3 +1309,45 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { t.Error("expected more to be false") } } + +func getPrefixKeys(t *Trie, key []byte) ([][]byte){ + var prefixKeys [][]byte + key = keybytesToHex(key) + tn := t.root + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + tn = nil + } else { + tn = n.Val + // Check if there is a previous key in prefixKeys + if len(prefixKeys) == 0 { + prefixKeys = append(prefixKeys, n.Key) + } else{ + prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], n.Key...)) + } + key = key[len(n.Key):] + } + case *fullNode: + tn = n.Children[key[0]] + if len(prefixKeys) == 0 { + prefixKeys = append(prefixKeys, key[:1]) + } else{ + prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], key[:1]...)) + } + key = key[1:] + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + return nil + } + default: + return nil + } + } + + return prefixKeys +} \ No newline at end of file From 7c94fa88cc2ff7f3d71b262b5286ae64623f3516 Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 09:35:09 +0800 Subject: [PATCH 10/51] witness: add test cases --- trie/proof_test.go | 149 +++++++++++++++++++++++++++------------------ 1 file changed, 90 insertions(+), 59 deletions(-) diff --git a/trie/proof_test.go b/trie/proof_test.go index 3bde1b1786..99cf3b8a9b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -916,56 +916,6 @@ func TestStorageProof(t *testing.T){ } } -// TestStorageProofWithExpiredSubTree tests the storage proof with expired sub tree. -// The prover is expected to give valid proof for the unexpired data. 
-func TestStorageProofWithExpiredSubTree(t *testing.T) { - trie := new(Trie) - - expiredData := map[string]string{ - "abcd": "A", - "abce": "B", - "abde": "C", - "abdf": "D", - } - - - unexpiredData := map[string]string{ - "defg": "E", - "defh": "F", - "degh": "G", - "degi": "H", - } - - // Loop through the data and insert it into the trie - for k, v := range expiredData { - updateString(trie, k, v) - } - - for k, v := range unexpiredData { - updateString(trie, k, v) - } - - prefixKey := keybytesToHex([]byte("abcd"))[:2] - - trie.ExpireByPrefix(prefixKey) - - for i, prover := range makeProvers(trie){ - for k, v := range unexpiredData { - proof := prover([]byte(k)) - if proof == nil { - t.Fatalf("proof %d is nil", i) - } - val, err := VerifyProof(trie.Hash(), []byte("degi"), proof) - if err != nil { - t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) - } - if !bytes.Equal(val, []byte("H")) { - t.Fatalf("prover %d: verified value mismatch: have %x, want %v", i, val, v) - } - } - } -} - // TestOneElementStorageProof tests the storage proof generation and verification // for a trie with only one element. func TestOneElementStorageProof(t *testing.T){ @@ -1080,25 +1030,106 @@ func TestBadStorageProof(t *testing.T){ } } -// TODO +// TestBadKeyStorageProof tests the storage proof with a bad key. +// The verifier is expected to return nil for both value and error. func TestBadKeyStorageProof(t *testing.T){ - return + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := keybytesToHex([]byte("x")) + trie.ProveStorage(key, nil, proof) + + val, err := trie.VerifyStorageProof(key, nil, proof) + if val != nil && err != nil{ + t.Fatalf("expected nil value and error for bad key") + } } -// TODO +// TestBadPrefixKeyStorageProof tests the storage proof with a bad prefix key. +// The verifier is expected to return nil for both value and error. 
func TestBadPrefixKeyStorageProof(t *testing.T){ - return + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := keybytesToHex([]byte("k")) + + prefixKey := keybytesToHex([]byte("x")) + + trie.ProveStorage(key, prefixKey, proof) + + val, err := trie.VerifyStorageProof(key, prefixKey, proof) + if val != nil && err != nil{ + t.Fatalf("expected nil value and error for bad prefix key") + } } -// TODO +// TestKeyPrefixKeySame tests the storage proof with the same key and prefix key. +// The proof size should be 0 and the verifier should return nil for both value and error. func TestKeyPrefixKeySame(t *testing.T){ - return + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := keybytesToHex([]byte("k")) + + trie.ProveStorage(key, key, proof) + if proof.Len() != 0 { + t.Fatalf("expected proof size to be 0 for same key and prefix key") + } + + val, err := trie.VerifyStorageProof(key, key, proof) + if val != nil && err != nil{ + t.Fatalf("expected nil value and error for same key and prefix key") + } } -// TODO: get the proof of unexpired tree, then expire it, then get proof of expired tree. -// compare them -func TestExpiredProof(t *testing.T){ +// TestUnexpiredStorageProof tests the storage proof with a trie containing +// both expired and unexpired data. The prover is expected to give valid proof +// for the unexpired data. 
+func TestUnexpiredStorageProof(t *testing.T) { + trie := new(Trie) + expiredData := map[string]string{ + "abcd": "A", + "abce": "B", + "abde": "C", + "abdf": "D", + } + + + unexpiredData := map[string]string{ + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + // Loop through the data and insert it into the trie + for k, v := range expiredData { + updateString(trie, k, v) + } + + for k, v := range unexpiredData { + updateString(trie, k, v) + } + + prefixKey := keybytesToHex([]byte("abcd"))[:2] + + trie.ExpireByPrefix(prefixKey) + + proof := memorydb.New() + key := keybytesToHex([]byte("degi")) + + trie.ProveStorage(key, nil, proof) + val, err := trie.VerifyStorageProof(key, nil, proof) + if err != nil { + t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proof) + } + if !bytes.Equal(val, []byte("H")) { + t.Fatalf("verified value mismatch: have %x, want %v", val, "H") + } } // mutateByte changes one byte in b. From 2925a0ab0e81f0875d6ad5703a0289ec07cac1cf Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 10:16:28 +0800 Subject: [PATCH 11/51] core/state: add ProveStorage to Trie interface --- core/state/database.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/state/database.go b/core/state/database.go index 0f31bc9139..1198d1b64b 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -126,6 +126,9 @@ type Trie interface { // nodes of the longest existing prefix of the key (at least the root), ending // with the node that proves the absence of the key. Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error + + ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error + } // NewDatabase creates a backing store for state. 
The returned database is safe for From def43ee883c0d071e4e2362dd87bc728f04cc95a Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 10:18:06 +0800 Subject: [PATCH 12/51] trie: implement ProveStorage --- trie/dummy_trie.go | 4 ++++ trie/proof.go | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 42e79a3719..12fd342e56 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -32,6 +32,10 @@ func (t *EmptyTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWrit return nil } +func (t *EmptyTrie) ProveStorage(key, from []byte, proofDb ethdb.KeyValueWriter) error { + return nil +} + // NewSecure creates a dummy trie func NewEmptyTrie() *EmptyTrie { return &EmptyTrie{} diff --git a/trie/proof.go b/trie/proof.go index b41659327f..1d29c78685 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -130,6 +130,10 @@ func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValue return nil } +func (t *SecureTrie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { + return t.trie.ProveStorage(key, prefixKey, proofDb) +} + // VerifyProof checks merkle proofs. The given proof must contain the value for // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. 
From c44da023be44b83dc3321aa83d38458b8fa7ee23 Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 10:18:57 +0800 Subject: [PATCH 13/51] witness: convert key to hex inside function --- trie/proof.go | 14 ++++++++------ trie/proof_test.go | 19 ++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/trie/proof.go b/trie/proof.go index 1d29c78685..8d8b2df79e 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -83,9 +83,8 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri return t.trie.Prove(key, fromLevel, proofDb) } -// ProveStorage constructs a merkle proof for a storage key. The storage key should -// already be converted to nibbles. If the prefix key is specified, the proof will -// start from the node that contains the prefix key to get the partial proof. +// ProveStorage constructs a merkle proof for a storage key. If the prefix key is specified, +// the proof will start from the node that contains the prefix key to get the partial proof. // The result contains all encoded nodes from the starting node to the node that contains // the value. The value itself is also included in the last node and can be retrieved by // verifying the proof. @@ -95,6 +94,8 @@ func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValue return fmt.Errorf("key is empty") } + key = keybytesToHex(key) + // traverse down using the prefixKey var nodes []node tn := t.root @@ -163,15 +164,16 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } -// VerifyStorageProof checks a merkle proof for a storage key. The storage key should -// already be converted to nibbles. If the prefix key is specified, it will traverse -// down to the node that contains the prefix key. From there, proof will be verified. +// VerifyStorageProof checks a merkle proof for a storage key. If the prefix key is specified, +// it will traverse down to the node that contains the prefix key. 
From there, proof will be verified. // VerifyStorageProof returns an error if the proof contains invalid trie nodes. func (t *Trie) VerifyStorageProof(key []byte, prefixKey []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { if len(key) == 0 { return nil, fmt.Errorf("empty key provided") } + + key = keybytesToHex(key) tn := t.root startNode, err := t.traverseNodes(tn, prefixKey, nil) diff --git a/trie/proof_test.go b/trie/proof_test.go index 99cf3b8a9b..a81f5fba5c 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -900,7 +900,7 @@ func TestStorageProof(t *testing.T){ prefixKeys := getPrefixKeys(trie, []byte(kv.k)) for _, prefixKey := range prefixKeys { proof := memorydb.New() - key := keybytesToHex([]byte(kv.k)) + key := kv.k err := trie.ProveStorage(key, prefixKey, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", kv.k) @@ -923,7 +923,7 @@ func TestOneElementStorageProof(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - key := keybytesToHex([]byte("k")) + key := []byte("k") err := trie.ProveStorage(key, nil, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) @@ -950,7 +950,7 @@ func TestEmptyStorageProof(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - key := keybytesToHex([]byte("k")) + key := []byte("k") val, err := trie.VerifyStorageProof(key, nil, proof) if val != nil && err != nil{ @@ -977,7 +977,8 @@ func TestEmptyPrefixKeyStorageProof(t *testing.T){ trie, vals := randomTrie(500) for _, kv := range vals { proof := memorydb.New() - key := keybytesToHex(kv.k) + key := kv.k + err := trie.ProveStorage(key, nil, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) @@ -1001,7 +1002,7 @@ func TestBadStorageProof(t *testing.T){ prefixKeys := getPrefixKeys(trie, []byte(kv.k)) for _, prefixKey := range prefixKeys { proof := memorydb.New() - key := keybytesToHex([]byte(kv.k)) + key := []byte(kv.k) err := trie.ProveStorage(key, 
prefixKey, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) @@ -1037,7 +1038,7 @@ func TestBadKeyStorageProof(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - key := keybytesToHex([]byte("x")) + key := []byte("x") trie.ProveStorage(key, nil, proof) val, err := trie.VerifyStorageProof(key, nil, proof) @@ -1053,7 +1054,7 @@ func TestBadPrefixKeyStorageProof(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - key := keybytesToHex([]byte("k")) + key := []byte("k") prefixKey := keybytesToHex([]byte("x")) @@ -1072,7 +1073,7 @@ func TestKeyPrefixKeySame(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - key := keybytesToHex([]byte("k")) + key := []byte("k") trie.ProveStorage(key, key, proof) if proof.Len() != 0 { @@ -1120,7 +1121,7 @@ func TestUnexpiredStorageProof(t *testing.T) { trie.ExpireByPrefix(prefixKey) proof := memorydb.New() - key := keybytesToHex([]byte("degi")) + key := []byte("degi") trie.ProveStorage(key, nil, proof) val, err := trie.VerifyStorageProof(key, nil, proof) From 1397d453f4293c0e11525eb6cce9a13469f14eb6 Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 10:19:19 +0800 Subject: [PATCH 14/51] core/state: modify GetStorageProof --- core/state/statedb.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/state/statedb.go b/core/state/statedb.go index 617dbfa1b7..58665350ef 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -497,13 +497,13 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { } // GetStorageProof returns the Merkle proof for given storage slot. 
-func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { +func (s *StateDB) GetStorageProof(a common.Address, prefixKey []byte, key common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) if trie == nil { return proof, errors.New("storage trie for requested address does not exist") } - err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) + err := trie.ProveStorage(crypto.Keccak256(key.Bytes()), prefixKey, &proof) return proof, err } From 9d2bcaa5ba957edd3267e199d73f188275f526ad Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 10:22:10 +0800 Subject: [PATCH 15/51] witness: remove MPTProof struct --- trie/proof.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/trie/proof.go b/trie/proof.go index 8d8b2df79e..2fa40e91b3 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -27,11 +27,6 @@ import ( "github.com/ethereum/go-ethereum/log" ) -type MPTProof struct { - prefixKey []byte // prefix key - proof [][]byte // list of RLP-encoded nodes -} - // Prove constructs a merkle proof for key. The result contains all encoded nodes // on the path to the value at key. The value itself is also included in the last // node and can be retrieved by verifying the proof. From ea96468a9d37276695dcb0a8393ba091aa303efc Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 17:14:25 +0800 Subject: [PATCH 16/51] core/state: add GetStorageWitness() and revert GetStorageProof --- core/state/statedb.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/core/state/statedb.go b/core/state/statedb.go index 58665350ef..1163e0d45f 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -496,14 +496,25 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { return proof, err } -// GetStorageProof returns the Merkle proof for given storage slot. 
-func (s *StateDB) GetStorageProof(a common.Address, prefixKey []byte, key common.Hash) ([][]byte, error) { +// GetStorageWitness returns only the Merkle proof for given storage slot. +func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, slotKey common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) if trie == nil { return proof, errors.New("storage trie for requested address does not exist") } - err := trie.ProveStorage(crypto.Keccak256(key.Bytes()), prefixKey, &proof) + err := trie.ProveStorage(crypto.Keccak256(slotKey.Bytes()), prefixKeyHex, &proof) + return proof, err +} + +// TODO: GetStorageProof returns the combined Merkle proof and Shadow Tree proof for given storage slot. +func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { + var proof proofList + trie := s.StorageTrie(a) + if trie == nil { + return proof, errors.New("storage trie for requested address does not exist") + } + err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) return proof, err } From 1b14e57a21a1e4bf8052713b454763efcedcb83e Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 30 Mar 2023 17:20:54 +0800 Subject: [PATCH 17/51] refactor: rename function and parameter names --- core/state/database.go | 2 +- core/state/statedb.go | 4 ++-- trie/dummy_trie.go | 2 +- trie/proof.go | 22 +++++++++++----------- trie/proof_test.go | 34 +++++++++++++++++----------------- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 1198d1b64b..37430062e8 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -127,7 +127,7 @@ type Trie interface { // with the node that proves the absence of the key. 
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error - ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error + ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error } diff --git a/core/state/statedb.go b/core/state/statedb.go index 1163e0d45f..c232537f5e 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -497,13 +497,13 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { } // GetStorageWitness returns only the Merkle proof for given storage slot. -func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, slotKey common.Hash) ([][]byte, error) { +func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, key common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) if trie == nil { return proof, errors.New("storage trie for requested address does not exist") } - err := trie.ProveStorage(crypto.Keccak256(slotKey.Bytes()), prefixKeyHex, &proof) + err := trie.ProveStorageWitness(crypto.Keccak256(key.Bytes()), prefixKeyHex, &proof) return proof, err } diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 12fd342e56..be0c3a01b6 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -32,7 +32,7 @@ func (t *EmptyTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWrit return nil } -func (t *EmptyTrie) ProveStorage(key, from []byte, proofDb ethdb.KeyValueWriter) error { +func (t *EmptyTrie) ProveStorageWitness(key, from []byte, proofDb ethdb.KeyValueWriter) error { return nil } diff --git a/trie/proof.go b/trie/proof.go index 2fa40e91b3..63c3d22760 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -78,12 +78,12 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri return t.trie.Prove(key, fromLevel, proofDb) } -// ProveStorage constructs a merkle proof for a storage key. 
If the prefix key is specified, +// ProveStorageWitness constructs a merkle proof for a storage key. If the prefix key is specified, // the proof will start from the node that contains the prefix key to get the partial proof. // The result contains all encoded nodes from the starting node to the node that contains // the value. The value itself is also included in the last node and can be retrieved by // verifying the proof. -func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { +func (t *Trie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { if len(key) == 0 { return fmt.Errorf("key is empty") @@ -91,15 +91,15 @@ func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValue key = keybytesToHex(key) - // traverse down using the prefixKey + // traverse down using the prefixKeyHex var nodes []node tn := t.root - startNode, err := t.traverseNodes(tn, prefixKey, nil) // obtain the node that contains the prefixKey + startNode, err := t.traverseNodes(tn, prefixKeyHex, nil) // obtain the node that contains the prefixKeyHex if err != nil { return err } - key = key[len(prefixKey):] // obtain the suffix key + key = key[len(prefixKeyHex):] // obtain the suffix key // traverse through the suffix key _, err = t.traverseNodes(startNode, key, &nodes) @@ -126,8 +126,8 @@ func (t *Trie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValue return nil } -func (t *SecureTrie) ProveStorage(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { - return t.trie.ProveStorage(key, prefixKey, proofDb) +func (t *SecureTrie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { + return t.trie.ProveStorageWitness(key, prefixKeyHex, proofDb) } // VerifyProof checks merkle proofs. 
The given proof must contain the value for @@ -159,10 +159,10 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } -// VerifyStorageProof checks a merkle proof for a storage key. If the prefix key is specified, +// VerifyStorageWitness checks a merkle proof for a storage key. If the prefix key is specified, // it will traverse down to the node that contains the prefix key. From there, proof will be verified. // VerifyStorageProof returns an error if the proof contains invalid trie nodes. -func (t *Trie) VerifyStorageProof(key []byte, prefixKey []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { +func (t *Trie) VerifyStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { if len(key) == 0 { return nil, fmt.Errorf("empty key provided") @@ -171,12 +171,12 @@ func (t *Trie) VerifyStorageProof(key []byte, prefixKey []byte, proofDb ethdb.Ke key = keybytesToHex(key) tn := t.root - startNode, err := t.traverseNodes(tn, prefixKey, nil) + startNode, err := t.traverseNodes(tn, prefixKeyHex, nil) if err != nil { return nil, err } - key = key[len(prefixKey):] // obtain the suffix key + key = key[len(prefixKeyHex):] // obtain the suffix key hasher := newHasher(false) defer returnHasherToPool(hasher) diff --git a/trie/proof_test.go b/trie/proof_test.go index a81f5fba5c..4435821cd3 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -901,11 +901,11 @@ func TestStorageProof(t *testing.T){ for _, prefixKey := range prefixKeys { proof := memorydb.New() key := kv.k - err := trie.ProveStorage(key, prefixKey, proof) + err := trie.ProveStorageWitness(key, prefixKey, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", kv.k) } - val, err := trie.VerifyStorageProof(key, prefixKey, proof) + val, err := trie.VerifyStorageWitness(key, prefixKey, proof) if err != nil { t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, 
proof) } @@ -924,7 +924,7 @@ func TestOneElementStorageProof(t *testing.T){ proof := memorydb.New() key := []byte("k") - err := trie.ProveStorage(key, nil, proof) + err := trie.ProveStorageWitness(key, nil, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) } @@ -952,7 +952,7 @@ func TestEmptyStorageProof(t *testing.T){ proof := memorydb.New() key := []byte("k") - val, err := trie.VerifyStorageProof(key, nil, proof) + val, err := trie.VerifyStorageWitness(key, nil, proof) if val != nil && err != nil{ t.Fatalf("expected nil value and error for empty proof") } @@ -965,7 +965,7 @@ func TestEmptyKeyStorageProof(t *testing.T){ updateString(trie, "k", "v") proof := memorydb.New() - err := trie.ProveStorage([]byte(""), nil, proof) + err := trie.ProveStorageWitness([]byte(""), nil, proof) if err == nil { t.Fatalf("expected error for empty key") } @@ -979,11 +979,11 @@ func TestEmptyPrefixKeyStorageProof(t *testing.T){ proof := memorydb.New() key := kv.k - err := trie.ProveStorage(key, nil, proof) + err := trie.ProveStorageWitness(key, nil, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) } - val, err := trie.VerifyStorageProof(key, nil, proof) + val, err := trie.VerifyStorageWitness(key, nil, proof) if err != nil { t.Fatalf("failed to verify proof for key %x: %v\nraw proof: %x", key, err, proof) } @@ -1003,7 +1003,7 @@ func TestBadStorageProof(t *testing.T){ for _, prefixKey := range prefixKeys { proof := memorydb.New() key := []byte(kv.k) - err := trie.ProveStorage(key, prefixKey, proof) + err := trie.ProveStorageWitness(key, prefixKey, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) } @@ -1024,7 +1024,7 @@ func TestBadStorageProof(t *testing.T){ mutateByte(itVal) proof.Put(crypto.Keccak256(itVal), itVal) - if val, err := trie.VerifyStorageProof(key, prefixKey, proof); err == nil && val != nil{ + if val, err := trie.VerifyStorageWitness(key, prefixKey, proof); err == nil && val != 
nil{ t.Fatalf("expected proof to fail for key: %x, prefix: %x", key, prefixKey) } } @@ -1039,9 +1039,9 @@ func TestBadKeyStorageProof(t *testing.T){ proof := memorydb.New() key := []byte("x") - trie.ProveStorage(key, nil, proof) + trie.ProveStorageWitness(key, nil, proof) - val, err := trie.VerifyStorageProof(key, nil, proof) + val, err := trie.VerifyStorageWitness(key, nil, proof) if val != nil && err != nil{ t.Fatalf("expected nil value and error for bad key") } @@ -1058,9 +1058,9 @@ func TestBadPrefixKeyStorageProof(t *testing.T){ prefixKey := keybytesToHex([]byte("x")) - trie.ProveStorage(key, prefixKey, proof) + trie.ProveStorageWitness(key, prefixKey, proof) - val, err := trie.VerifyStorageProof(key, prefixKey, proof) + val, err := trie.VerifyStorageWitness(key, prefixKey, proof) if val != nil && err != nil{ t.Fatalf("expected nil value and error for bad prefix key") } @@ -1075,12 +1075,12 @@ func TestKeyPrefixKeySame(t *testing.T){ proof := memorydb.New() key := []byte("k") - trie.ProveStorage(key, key, proof) + trie.ProveStorageWitness(key, key, proof) if proof.Len() != 0 { t.Fatalf("expected proof size to be 0 for same key and prefix key") } - val, err := trie.VerifyStorageProof(key, key, proof) + val, err := trie.VerifyStorageWitness(key, key, proof) if val != nil && err != nil{ t.Fatalf("expected nil value and error for same key and prefix key") } @@ -1123,8 +1123,8 @@ func TestUnexpiredStorageProof(t *testing.T) { proof := memorydb.New() key := []byte("degi") - trie.ProveStorage(key, nil, proof) - val, err := trie.VerifyStorageProof(key, nil, proof) + trie.ProveStorageWitness(key, nil, proof) + val, err := trie.VerifyStorageWitness(key, nil, proof) if err != nil { t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proof) } From 9fee384c3f1fa769cf79f504bdcf6a55c3d560da Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 31 Mar 2023 14:10:51 +0800 Subject: [PATCH 18/51] core/transaction: add revive 
state gas computation; --- cmd/evm/internal/t8ntool/transaction.go | 2 +- core/bench_test.go | 2 +- core/state_transition.go | 12 +++- core/state_transition_test.go | 85 +++++++++++++++++++++++++ core/tx_pool.go | 2 +- core/types/revive_state_tx.go | 13 ++++ core/types/revive_witness.go | 41 ++++++++++-- core/types/transaction_test.go | 10 +-- light/trie.go | 4 ++ light/txpool.go | 2 +- params/protocol_params.go | 11 ++-- tests/transaction_test_util.go | 2 +- 12 files changed, 165 insertions(+), 21 deletions(-) create mode 100644 core/state_transition_test.go diff --git a/cmd/evm/internal/t8ntool/transaction.go b/cmd/evm/internal/t8ntool/transaction.go index 6f1c964ada..57057089cb 100644 --- a/cmd/evm/internal/t8ntool/transaction.go +++ b/cmd/evm/internal/t8ntool/transaction.go @@ -139,7 +139,7 @@ func Transaction(ctx *cli.Context) error { r.Address = sender } // Check intrinsic gas - if gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, + if gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, chainConfig.IsHomestead(new(big.Int)), chainConfig.IsIstanbul(new(big.Int))); err != nil { r.Error = err results = append(results, r) diff --git a/core/bench_test.go b/core/bench_test.go index 06333033c4..e78a926836 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -85,7 +85,7 @@ func genValueTx(nbytes int) func(int, *BlockGen) { return func(i int, gen *BlockGen) { toaddr := common.Address{} data := make([]byte, nbytes) - gas, _ := IntrinsicGas(data, nil, false, false, false) + gas, _ := IntrinsicGas(data, nil, nil, false, false, false) signer := types.MakeSigner(gen.config, big.NewInt(int64(i))) gasPrice := big.NewInt(0) if gen.header.BaseFee != nil { diff --git a/core/state_transition.go b/core/state_transition.go index 1d8dfcf7de..a85edd5cf0 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -119,7 +119,7 @@ func (result *ExecutionResult) Revert() []byte { } // IntrinsicGas 
computes the 'intrinsic gas' for a message with the given data. -func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation bool, isHomestead, isEIP2028 bool) (uint64, error) { +func IntrinsicGas(data []byte, accessList types.AccessList, witnessList types.WitnessList, isContractCreation bool, isHomestead, isEIP2028 bool) (uint64, error) { // Set the starting gas for the raw transaction var gas uint64 if isContractCreation && isHomestead { @@ -156,6 +156,14 @@ func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation b gas += uint64(len(accessList)) * params.TxAccessListAddressGas gas += uint64(accessList.StorageKeys()) * params.TxAccessListStorageKeyGas } + + if witnessList != nil { + witGas := types.WitnessIntrinsicGas(witnessList) + if (math.MaxUint64 - gas) < witGas { + return 0, ErrGasUintOverflow + } + gas += witGas + } return gas, nil } @@ -316,7 +324,7 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { } } // Check clauses 4-5, subtract intrinsic gas if everything is correct - gas, err := IntrinsicGas(st.data, st.msg.AccessList(), contractCreation, rules.IsHomestead, rules.IsIstanbul) + gas, err := IntrinsicGas(st.data, st.msg.AccessList(), st.msg.WitnessList(), contractCreation, rules.IsHomestead, rules.IsIstanbul) if err != nil { return nil, err } diff --git a/core/state_transition_test.go b/core/state_transition_test.go new file mode 100644 index 0000000000..57f5bd7da1 --- /dev/null +++ b/core/state_transition_test.go @@ -0,0 +1,85 @@ +package core + +import ( + "bytes" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" + "testing" +) + +func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, proofLen int) types.ReviveWitness { + proofList := make([]types.MPTProof, witSize) + for i := range proofList { + proof := make([][]byte, proofCount) + for j := range proof { + proof[j] = 
bytes.Repeat([]byte{'p'}, proofLen) + } + proofList[i] = types.MPTProof{ + Key: bytes.Repeat([]byte{'k'}, keyLen), + Proof: proof, + } + } + return types.ReviveWitness{ + WitnessType: types.MPTWitnessType, + Address: addr, + ProofList: proofList, + } +} + +func TestIntrinsicGas_WitnessList(t *testing.T) { + address := common.HexToAddress("d4584b5f6229b7be90727b0fc8c6b91bb427821f") + + test_data := []struct { + // input + data []byte + accessList types.AccessList + witnessList types.WitnessList + isContractCreation bool + isHomestead bool + isEIP2028 bool + // expect + gas uint64 + }{ + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 0, 100, 512), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 53416, + }, + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 1, 0, 512), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 55016, + }, + { + data: nil, + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 30, 2, 2, 32), + makeMerkleProofWitness(&address, 20, 1, 1, 36), + }, + isContractCreation: false, + isHomestead: true, + isEIP2028: true, + gas: 25948, + }, + } + + for _, item := range test_data { + gas, err := IntrinsicGas(item.data, item.accessList, item.witnessList, item.isContractCreation, item.isHomestead, item.isEIP2028) + assert.NoError(t, err) + assert.Equal(t, item.gas, gas) + } +} diff --git a/core/tx_pool.go b/core/tx_pool.go index b5f1d3a2fd..8eb16dbb54 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -706,7 +706,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } // Ensure the transaction has more gas than the basic tx fee. 
- intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true, pool.istanbul) + intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, true, pool.istanbul) if err != nil { return err } diff --git a/core/types/revive_state_tx.go b/core/types/revive_state_tx.go index 43646a465a..8cef4d2c1b 100644 --- a/core/types/revive_state_tx.go +++ b/core/types/revive_state_tx.go @@ -2,6 +2,7 @@ package types import ( "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" "math/big" ) @@ -109,3 +110,15 @@ func (tx *ReviveStateTx) rawSignatureValues() (v, r, s *big.Int) { func (tx *ReviveStateTx) setSignatureValues(chainID, v, r, s *big.Int) { tx.V, tx.R, tx.S = v, r, s } + +func WitnessIntrinsicGas(wits WitnessList) uint64 { + totalGas := uint64(0) + for i := 0; i < len(wits); i++ { + // witness size cost + totalGas += wits[i].Size() * params.TxWitnessListStorageGasPerByte + // witness verify cost + count, words := wits[i].ProofWords() + totalGas += count*params.TxWitnessListVerifyBaseGas + words*params.TxWitnessListVerifyGasPerWord + } + return totalGas +} diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go index 1900c71fbf..7cd6298af5 100644 --- a/core/types/revive_witness.go +++ b/core/types/revive_witness.go @@ -2,13 +2,44 @@ package types import "github.com/ethereum/go-ethereum/common" +const ( + MPTWitnessType = iota +) + type MPTProof struct { - key []byte // prefix key - proof [][]byte // list of RLP-encoded nodes + Key []byte // prefix key + Proof [][]byte // list of RLP-encoded nodes } type ReviveWitness struct { - witnessType byte // only support Merkle Proof for now - address *common.Address // target account address - proofList []MPTProof // revive multiple slots (same address) + WitnessType byte // only support Merkle Proof for now + Address *common.Address // target account address + ProofList []MPTProof // revive multiple slots (same address) +} + +// Size 
estimate witness byte size +func (r *ReviveWitness) Size() uint64 { + size := uint64(21) + for i := range r.ProofList { + size += uint64(len(r.ProofList[i].Key)) + for j := range r.ProofList[i].Proof { + size += uint64(len(r.ProofList[i].Proof[j])) + } + } + + return size +} + +// ProofWords get proof count and words +func (r *ReviveWitness) ProofWords() (uint64, uint64) { + count := uint64(0) + words := uint64(0) + for i := range r.ProofList { + for j := range r.ProofList[i].Proof { + count++ + words += uint64((len(r.ProofList[i].Proof[j]) + 31) / 32) + } + } + + return count, words } diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 3ad58a024b..90d0181701 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -576,11 +576,11 @@ func TestReviveStateTxAndSigner(t *testing.T) { addr = common.HexToAddress("0x0000000000000000000000000000000000000001") recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") witness = WitnessList{{ - witnessType: 0, - address: &addr, - proofList: []MPTProof{{ - key: common.Hex2Bytes("095e7baea6a6c7c4c2"), - proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, + WitnessType: 0, + Address: &addr, + ProofList: []MPTProof{{ + Key: common.Hex2Bytes("095e7baea6a6c7c4c2"), + Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, }}, }} ) diff --git a/light/trie.go b/light/trie.go index d41536e069..adf7f1d810 100644 --- a/light/trie.go +++ b/light/trie.go @@ -112,6 +112,10 @@ type odrTrie struct { trie *trie.Trie } +func (t *odrTrie) ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { + panic("implement me") +} + func (t *odrTrie) TryGet(key []byte) ([]byte, error) { key = crypto.Keccak256(key) var res []byte diff --git a/light/txpool.go b/light/txpool.go index d12694d8f9..8382514b13 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -385,7 +385,7 @@ func (pool 
*TxPool) validateTx(ctx context.Context, tx *types.Transaction) error } // Should supply enough intrinsic gas - gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true, pool.istanbul) + gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, true, pool.istanbul) if err != nil { return err } diff --git a/params/protocol_params.go b/params/protocol_params.go index e244c24231..67bc3aa4a2 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -86,10 +86,13 @@ const ( SelfdestructRefundGas uint64 = 24000 // Refunded following a selfdestruct operation. MemoryGas uint64 = 3 // Times the address of the (highest referenced byte in memory + 1). NOTE: referencing happens on read, write and in instructions such as RETURN and CALL. - TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. - TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) - TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list - TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. 
+ TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) + TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list + TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxWitnessListStorageGasPerByte uint64 = 16 // Per byte gas in BEP-215 witness list + TxWitnessListVerifyBaseGas uint64 = 60 // Base gas in BEP-215 witness list verify + TxWitnessListVerifyGasPerWord uint64 = 12 // Per-word price in BEP-215 witness list verify // These have been changed during the course of the chain CallGasFrontier uint64 = 40 // Once per CALL operation & message call transaction. diff --git a/tests/transaction_test_util.go b/tests/transaction_test_util.go index 82ee01de15..f63ed97dca 100644 --- a/tests/transaction_test_util.go +++ b/tests/transaction_test_util.go @@ -55,7 +55,7 @@ func (tt *TransactionTest) Run(config *params.ChainConfig) error { return nil, nil, err } // Intrinsic gas - requiredGas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, isHomestead, isIstanbul) + requiredGas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, isHomestead, isIstanbul) if err != nil { return nil, nil, err } From ce285c0b333e4df585a24f83feaaf03eb95b2e7a Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:35:05 +0800 Subject: [PATCH 19/51] core/ReviveWitness: refactor witness definition, add verify witness; txpool: add hard fork & witness check; trie: add MPTProofCache, include verify proof, MPTProofNubs; trie: add trie node type; core/StateTransition: add witness verify; --- core/state_transition.go | 57 ++++++++++- core/state_transition_test.go | 36 +++++-- core/tx_pool.go | 17 +++- core/types/revive_state_tx.go | 17 ++-- core/types/revive_witness.go | 142 ++++++++++++++++++++++----- core/types/revive_witness_test.go | 126 
++++++++++++++++++++++++ core/types/transaction_test.go | 24 +++-- light/txpool.go | 18 +++- params/config.go | 4 + params/protocol_params.go | 14 +-- trie/database.go | 12 +++ trie/node.go | 33 ++++++- trie/proof.go | 155 ++++++++++++++++++++++++++++-- trie/proof_test.go | 87 ++++++++++++----- 14 files changed, 655 insertions(+), 87 deletions(-) create mode 100644 core/types/revive_witness_test.go diff --git a/core/state_transition.go b/core/state_transition.go index a85edd5cf0..2f1554bc08 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -17,7 +17,9 @@ package core import ( + "errors" "fmt" + "github.com/ethereum/go-ethereum/trie" "math" "math/big" @@ -158,7 +160,10 @@ func IntrinsicGas(data []byte, accessList types.AccessList, witnessList types.Wi } if witnessList != nil { - witGas := types.WitnessIntrinsicGas(witnessList) + witGas, err := types.WitnessIntrinsicGas(witnessList) + if err != nil { + return 0, err + } if (math.MaxUint64 - gas) < witGas { return 0, ErrGasUintOverflow } @@ -268,6 +273,20 @@ func (st *StateTransition) preCheck() error { } } } + + // check witness and hard fork + if st.msg.WitnessList() != nil { + if !st.evm.ChainConfig().IsElwood(st.evm.Context.BlockNumber) { + return errors.New("cannot allow witness before Elwood fork") + } + witnessList := st.msg.WitnessList() + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } + return st.buyGas() } @@ -342,6 +361,42 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { if rules.IsBerlin { st.state.PrepareAccessList(msg.From(), msg.To(), vm.ActivePrecompiles(rules), msg.AccessList()) } + + // revive state before execution + if rules.IsElwood { + witnessList := msg.WitnessList() + for i := range witnessList { + wit := witnessList[i] + // got specify witness, verify proof and check if revive success + switch wit.WitnessType { + case types.StorageTrieWitnessType: + data, err := wit.WitnessData() + if 
err != nil { + return nil, err + } + stWit, ok := data.(*types.StorageTrieWitness) + if !ok { + return nil, errors.New("got StorageTrieWitnessType data error") + } + proofCaches := make([]trie.MPTProofCache, len(stWit.ProofList)) + for j := range stWit.ProofList { + proofCaches[j] = trie.MPTProofCache{ + MPTProof: stWit.ProofList[j], + } + if err := proofCaches[j].VerifyProof(); err != nil { + return nil, err + } + + // TODO revive trie nodes by witness in the same contract storage trie + // 1. check expired hash; + // 2. append trie nodes & rebuild shadow nodes; + } + default: + return nil, errors.New("unsupported WitnessType") + } + } + } + var ( ret []byte vmerr error // vm errors do not effect consensus and are therefore not assigned to err diff --git a/core/state_transition_test.go b/core/state_transition_test.go index 57f5bd7da1..a96b5ddcba 100644 --- a/core/state_transition_test.go +++ b/core/state_transition_test.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" "testing" ) @@ -16,14 +17,22 @@ func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, p proof[j] = bytes.Repeat([]byte{'p'}, proofLen) } proofList[i] = types.MPTProof{ - Key: bytes.Repeat([]byte{'k'}, keyLen), - Proof: proof, + RootKey: bytes.Repeat([]byte{'k'}, keyLen), + Proof: proof, } } + wit := types.StorageTrieWitness{ + Address: *addr, + ProofList: proofList, + } + + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } return types.ReviveWitness{ - WitnessType: types.MPTWitnessType, - Address: addr, - ProofList: proofList, + WitnessType: types.StorageTrieWitnessType, + Data: enc, } } @@ -50,7 +59,7 @@ func TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: true, isHomestead: true, isEIP2028: true, - gas: 53416, + gas: 53464, }, { data: common.Hex2Bytes("1234567890"), @@ -61,7 +70,18 @@ func 
TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: true, isHomestead: true, isEIP2028: true, - gas: 55016, + gas: 55176, + }, + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 1, 1, 0), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 55252, }, { data: nil, @@ -73,7 +93,7 @@ func TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: false, isHomestead: true, isEIP2028: true, - gas: 25948, + gas: 26412, }, } diff --git a/core/tx_pool.go b/core/tx_pool.go index 8eb16dbb54..83e206c597 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -263,6 +263,7 @@ type TxPool struct { istanbul bool // Fork indicator whether we are in the istanbul stage. eip2718 bool // Fork indicator whether we are using EIP-2718 type transactions. eip1559 bool // Fork indicator whether we are using EIP-1559 type transactions. + isElwood bool // Fork indicator whether we are using BEP-216 type transactions. currentState *state.StateDB // Current state in the blockchain head pendingNonces *txNoncer // Pending state tracking virtual nonces @@ -706,13 +707,26 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } // Ensure the transaction has more gas than the basic tx fee. 
- intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, true, pool.istanbul) + witnessList := tx.WitnessList() + intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), witnessList, tx.To() == nil, true, pool.istanbul) if err != nil { return err } if tx.Gas() < intrGas { return ErrIntrinsicGas } + + // check witness and hard fork + if witnessList != nil { + if !pool.isElwood { + return errors.New("cannot allow witness before Elwood fork") + } + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } return nil } @@ -1433,6 +1447,7 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { pool.istanbul = pool.chainconfig.IsIstanbul(next) pool.eip2718 = pool.chainconfig.IsBerlin(next) pool.eip1559 = pool.chainconfig.IsLondon(next) + pool.isElwood = pool.chainconfig.IsElwood(next) } // promoteExecutables moves transactions that have become processable from the diff --git a/core/types/revive_state_tx.go b/core/types/revive_state_tx.go index 8cef4d2c1b..c9b9f6af8d 100644 --- a/core/types/revive_state_tx.go +++ b/core/types/revive_state_tx.go @@ -40,7 +40,9 @@ func (tx *ReviveStateTx) copy() TxData { S: new(big.Int), } - copy(cpy.WitnessList, tx.WitnessList) + for i := range tx.WitnessList { + cpy.WitnessList[i] = tx.WitnessList[i].Copy() + } if tx.Value != nil { cpy.Value.Set(tx.Value) } @@ -111,14 +113,15 @@ func (tx *ReviveStateTx) setSignatureValues(chainID, v, r, s *big.Int) { tx.V, tx.R, tx.S = v, r, s } -func WitnessIntrinsicGas(wits WitnessList) uint64 { +func WitnessIntrinsicGas(wits WitnessList) (uint64, error) { totalGas := uint64(0) for i := 0; i < len(wits); i++ { - // witness size cost totalGas += wits[i].Size() * params.TxWitnessListStorageGasPerByte - // witness verify cost - count, words := wits[i].ProofWords() - totalGas += count*params.TxWitnessListVerifyBaseGas + words*params.TxWitnessListVerifyGasPerWord + addGas, err := wits[i].AdditionalIntrinsicGas() 
+ if err != nil { + return 0, err + } + totalGas += addGas } - return totalGas + return totalGas, nil } diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go index 7cd6298af5..cbb08ca6f9 100644 --- a/core/types/revive_witness.go +++ b/core/types/revive_witness.go @@ -1,45 +1,143 @@ package types -import "github.com/ethereum/go-ethereum/common" +import ( + "errors" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" +) + +var ( + ErrUnknownWitnessType = errors.New("unknown revive witness type") + ErrWitnessEmptyData = errors.New("witness data is empty") + ErrStorageTrieWitnessEmptyProofList = errors.New("StorageTrieWitness: empty proof list") + ErrStorageTrieWitnessEmptyInnerProof = errors.New("StorageTrieWitness: empty inner proof") + ErrStorageTrieWitnessWrongProofSize = errors.New("StorageTrieWitness: wrong proof size") +) const ( - MPTWitnessType = iota + StorageTrieWitnessType = iota ) +// MPTProof in order to degrade revive partial trie node, +// only allow on path proof, not support tree path, +// will verify the whole path later +// Attention: The proof could revive multi-vals, although it's a single trie path witness type MPTProof struct { - Key []byte // prefix key - Proof [][]byte // list of RLP-encoded nodes + RootKey []byte // root key, target the revival path root, max 32bytes + Proof [][]byte // list of RLP-encoded nodes +} + +type StorageTrieWitness struct { + Address common.Address // target account address + ProofList []MPTProof // revive multiple slots (same address) +} + +func (s *StorageTrieWitness) AdditionalIntrinsicGas() (uint64, error) { + count := 0 + words := 0 + for i := range s.ProofList { + for j := range s.ProofList[i].Proof { + count++ + words += (len(s.ProofList[i].Proof[j]) + 31) / 32 + } + } + + return uint64(count)*params.TxWitnessListVerifyMPTBaseGas + uint64(words)*params.TxWitnessListVerifyMPTGasPerWord, nil +} + +// VerifyWitness only 
check format, merkle proof check later +func (s *StorageTrieWitness) VerifyWitness() error { + if len(s.ProofList) == 0 { + return ErrStorageTrieWitnessEmptyProofList + } + for i := range s.ProofList { + if len(s.ProofList[i].Proof) == 0 { + return ErrStorageTrieWitnessEmptyInnerProof + } + for j := range s.ProofList[i].Proof { + // The smallest size is a valueNode, The largest size is the full fullNode + if len(s.ProofList[i].Proof[j]) < 32 || len(s.ProofList[i].Proof[j]) > 544 { + return ErrStorageTrieWitnessWrongProofSize + } + } + } + + return nil +} + +// ReviveWitnessData the common method of witness +type ReviveWitnessData interface { + // AdditionalIntrinsicGas got additional gas consumption + AdditionalIntrinsicGas() (uint64, error) + // VerifyWitness check if valid witness format, used before state revival + VerifyWitness() error } +// ReviveWitness for revive witness +// Attention: it's not thread safe type ReviveWitness struct { - WitnessType byte // only support Merkle Proof for now - Address *common.Address // target account address - ProofList []MPTProof // revive multiple slots (same address) + WitnessType byte // only support Merkle Proof for now + Data []byte // witness data, it's rlp format + cache ReviveWitnessData `rlp:"-" json:"-"` // cache if not encode to rlp or json, it used for performance } // Size estimate witness byte size func (r *ReviveWitness) Size() uint64 { - size := uint64(21) - for i := range r.ProofList { - size += uint64(len(r.ProofList[i].Key)) - for j := range r.ProofList[i].Proof { - size += uint64(len(r.ProofList[i].Proof[j])) + return uint64(len(r.Data) + 1) +} + +// Copy deep copy +func (r *ReviveWitness) Copy() ReviveWitness { + witness := ReviveWitness{ + WitnessType: r.WitnessType, + Data: make([]byte, len(r.Data)), + } + copy(witness.Data, r.Data) + return witness +} + +func (r *ReviveWitness) WitnessData() (ReviveWitnessData, error) { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return nil, err 
} } + return r.cache, nil +} - return size +func (r *ReviveWitness) AdditionalIntrinsicGas() (uint64, error) { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return 0, err + } + } + return r.cache.AdditionalIntrinsicGas() } -// ProofWords get proof count and words -func (r *ReviveWitness) ProofWords() (uint64, uint64) { - count := uint64(0) - words := uint64(0) - for i := range r.ProofList { - for j := range r.ProofList[i].Proof { - count++ - words += uint64((len(r.ProofList[i].Proof[j]) + 31) / 32) +func (r *ReviveWitness) VerifyWitness() error { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return err + } + } + return r.cache.VerifyWitness() +} + +func (r *ReviveWitness) parseWitness() error { + if len(r.Data) == 0 { + return ErrWitnessEmptyData + } + switch r.WitnessType { + case StorageTrieWitnessType: + var cache StorageTrieWitness + if err := rlp.DecodeBytes(r.Data, &cache); err != nil { + return err } + r.cache = &cache + default: + return ErrUnknownWitnessType } - return count, words + return nil } diff --git a/core/types/revive_witness_test.go b/core/types/revive_witness_test.go new file mode 100644 index 0000000000..ad7a508675 --- /dev/null +++ b/core/types/revive_witness_test.go @@ -0,0 +1,126 @@ +package types + +import ( + "bytes" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" + "testing" +) + +func makeSimpleReviveWitness(witType byte, data []byte) ReviveWitness { + return ReviveWitness{ + WitnessType: witType, + Data: data, + } +} + +func makeReviveWitnessFromStorageTrieWitness(wit StorageTrieWitness) ReviveWitness { + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } + return ReviveWitness{ + WitnessType: StorageTrieWitnessType, + Data: enc, + } +} + +func makeStorageTrieWitness(addr common.Address, proofCount int, proofLen ...int) StorageTrieWitness { + proofList := make([]MPTProof, proofCount) + for i := 0; i < 
proofCount; i++ { + proof := make([][]byte, len(proofLen)) + for j := range proofLen { + proof[j] = bytes.Repeat([]byte{'f'}, proofLen[j]) + } + proofList[i] = MPTProof{ + RootKey: nil, + Proof: proof, + } + } + wit := StorageTrieWitness{ + Address: addr, + ProofList: proofList, + } + + return wit +} + +func TestVerifyWitness(t *testing.T) { + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + testData := []struct { + wit StorageTrieWitness + expect error + }{ + { + wit: makeStorageTrieWitness(addr, 0, 0), + expect: ErrStorageTrieWitnessEmptyProofList, + }, + { + wit: makeStorageTrieWitness(addr, 1, 31), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeStorageTrieWitness(addr, 1, 545), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeStorageTrieWitness(addr, 1, 32), + expect: nil, + }, + { + wit: makeStorageTrieWitness(addr, 1, 544), + expect: nil, + }, + { + wit: makeStorageTrieWitness(addr, 10, 32, 544, 33, 543, 128, 99), + expect: nil, + }, + } + + for i := range testData { + assert.Equal(t, testData[i].expect, testData[i].wit.VerifyWitness(), i) + } +} + +func TestReviveWitness_VerifyWitness(t *testing.T) { + + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + testData := []struct { + wit ReviveWitness + expect error + }{ + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 0, 0)), + expect: ErrStorageTrieWitnessEmptyProofList, + }, + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 1, 31)), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeSimpleReviveWitness(0, nil), + expect: ErrWitnessEmptyData, + }, + { + wit: makeSimpleReviveWitness(1, bytes.Repeat([]byte{'e'}, 10)), + expect: ErrUnknownWitnessType, + }, + { + wit: makeSimpleReviveWitness(0, bytes.Repeat([]byte{'e'}, 10)), + expect: nil, + }, + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 10, 33, 544, 99)), + 
expect: nil, + }, + } + + for i := range testData { + if i == 4 { + assert.Error(t, testData[i].wit.VerifyWitness()) + continue + } + assert.Equal(t, testData[i].expect, testData[i].wit.VerifyWitness(), i) + } +} diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 90d0181701..b48f30f6d6 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -575,15 +575,23 @@ func TestReviveStateTxAndSigner(t *testing.T) { from = crypto.PubkeyToAddress(key.PublicKey) addr = common.HexToAddress("0x0000000000000000000000000000000000000001") recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") - witness = WitnessList{{ - WitnessType: 0, - Address: &addr, - ProofList: []MPTProof{{ - Key: common.Hex2Bytes("095e7baea6a6c7c4c2"), - Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, - }}, - }} ) + wit := StorageTrieWitness{ + Address: addr, + ProofList: []MPTProof{{ + RootKey: common.Hex2Bytes("095e7baea6a6c7c4c2"), + Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, + }}, + } + + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } + witness := WitnessList{{ + WitnessType: 0, + Data: enc, + }} for i := uint64(0); i < 500; i++ { var txdata TxData switch i % 5 { diff --git a/light/txpool.go b/light/txpool.go index 8382514b13..1b819794c1 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -18,6 +18,7 @@ package light import ( "context" + "errors" "fmt" "math/big" "sync" @@ -69,6 +70,7 @@ type TxPool struct { istanbul bool // Fork indicator whether we are in the istanbul stage. eip2718 bool // Fork indicator whether we are in the eip2718 stage. + isElwood bool // Fork indicator whether we are in the BEP-216 stage. 
} // TxRelayBackend provides an interface to the mechanism that forwards transacions @@ -318,6 +320,7 @@ func (pool *TxPool) setNewHead(head *types.Header) { next := new(big.Int).Add(head.Number, big.NewInt(1)) pool.istanbul = pool.config.IsIstanbul(next) pool.eip2718 = pool.config.IsBerlin(next) + pool.isElwood = pool.config.IsElwood(next) } // Stop stops the light transaction pool @@ -385,13 +388,26 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error } // Should supply enough intrinsic gas - gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, true, pool.istanbul) + witnessList := tx.WitnessList() + gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), witnessList, tx.To() == nil, true, pool.istanbul) if err != nil { return err } if tx.Gas() < gas { return core.ErrIntrinsicGas } + + // check witness and hard fork + if witnessList != nil { + if !pool.isElwood { + return errors.New("cannot allow witness before Elwood fork") + } + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } return currentState.Error() } diff --git a/params/config.go b/params/config.go index b123bd187f..22812c7ec3 100644 --- a/params/config.go +++ b/params/config.go @@ -782,6 +782,8 @@ type Rules struct { IsNano bool IsMoran bool IsPlanck bool + IsClaude bool + IsElwood bool } // Rules ensures c's ChainID is not nil. @@ -806,5 +808,7 @@ func (c *ChainConfig) Rules(num *big.Int, isMerge bool) Rules { IsNano: c.IsNano(num), IsMoran: c.IsMoran(num), IsPlanck: c.IsPlanck(num), + IsClaude: c.IsClaude(num), + IsElwood: c.IsElwood(num), } } diff --git a/params/protocol_params.go b/params/protocol_params.go index 67bc3aa4a2..e3f3aa0ae5 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -86,13 +86,13 @@ const ( SelfdestructRefundGas uint64 = 24000 // Refunded following a selfdestruct operation. 
MemoryGas uint64 = 3 // Times the address of the (highest referenced byte in memory + 1). NOTE: referencing happens on read, write and in instructions such as RETURN and CALL. - TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. - TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) - TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list - TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list - TxWitnessListStorageGasPerByte uint64 = 16 // Per byte gas in BEP-215 witness list - TxWitnessListVerifyBaseGas uint64 = 60 // Base gas in BEP-215 witness list verify - TxWitnessListVerifyGasPerWord uint64 = 12 // Per-word price in BEP-215 witness list verify + TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. + TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) + TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list + TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxWitnessListStorageGasPerByte uint64 = 16 // Per byte gas in BEP-215 witness list + TxWitnessListVerifyMPTBaseGas uint64 = 60 // Base gas in BEP-215 witness list verify, refer to EVM Sha256 gas + TxWitnessListVerifyMPTGasPerWord uint64 = 12 // Per-word price in BEP-215 witness list verify, refer to EVM Sha256 gas // These have been changed during the course of the chain CallGasFrontier uint64 = 40 // Once per CALL operation & message call transaction. 
diff --git a/trie/database.go b/trie/database.go index db465d4e9e..52469352cf 100644 --- a/trie/database.go +++ b/trie/database.go @@ -109,6 +109,10 @@ func (n rawNode) EncodeRLP(w io.Writer) error { return err } +func (n rawNode) nodeType() int { + return rawNodeType +} + // rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. @@ -117,6 +121,10 @@ type rawFullNode [17]node func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawFullNode) nodeType() int { + return rawFullNodeType +} + func (n rawFullNode) EncodeRLP(w io.Writer) error { eb := rlp.NewEncoderBuffer(w) n.encode(eb) @@ -134,6 +142,10 @@ type rawShortNode struct { func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawShortNode) nodeType() int { + return rawShortNodeType +} + // cachedNode is all the information we know about a single cached trie node // in the memory database write layer. 
type cachedNode struct { diff --git a/trie/node.go b/trie/node.go index 6ce6551ded..b50324294f 100644 --- a/trie/node.go +++ b/trie/node.go @@ -25,17 +25,32 @@ import ( "github.com/ethereum/go-ethereum/rlp" ) +const ( + BranchNodeLength = 17 +) + +const ( + shortNodeType = iota + fullNodeType + hashNodeType + valueNodeType + rawNodeType + rawShortNodeType + rawFullNodeType +) + var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} type node interface { cache() (hashNode, bool) encode(w rlp.EncoderBuffer) fstring(string) string + nodeType() int } type ( fullNode struct { - Children [17]node // Actual trie node data to encode/decode (needs custom encoder) + Children [BranchNodeLength]node // Actual trie node data to encode/decode (needs custom encoder) flags nodeFlag } shortNode struct { @@ -99,6 +114,22 @@ func (n valueNode) fstring(ind string) string { return fmt.Sprintf("%x ", []byte(n)) } +func (n *shortNode) nodeType() int { + return shortNodeType +} + +func (n *fullNode) nodeType() int { + return fullNodeType +} + +func (n hashNode) nodeType() int { + return hashNodeType +} + +func (n valueNode) nodeType() int { + return valueNodeType +} + // mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered. func mustDecodeNode(hash, buf []byte) node { n, err := decodeNode(hash, buf) diff --git a/trie/proof.go b/trie/proof.go index 63c3d22760..f532b8d34b 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -20,6 +20,7 @@ import ( "bytes" "errors" "fmt" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" @@ -78,7 +79,7 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri return t.trie.Prove(key, fromLevel, proofDb) } -// ProveStorageWitness constructs a merkle proof for a storage key. If the prefix key is specified, +// ProveStorageWitness constructs a merkle proof for a storage key. 
If the prefix key is specified, // the proof will start from the node that contains the prefix key to get the partial proof. // The result contains all encoded nodes from the starting node to the node that contains // the value. The value itself is also included in the last node and can be retrieved by @@ -103,7 +104,7 @@ func (t *Trie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethd // traverse through the suffix key _, err = t.traverseNodes(startNode, key, &nodes) - if err != nil{ + if err != nil { return err } @@ -159,7 +160,147 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } -// VerifyStorageWitness checks a merkle proof for a storage key. If the prefix key is specified, +// MPTProofNub include fullNode shortNode, revive n1 first if exist, +// revive n2 later if exist, include node hash +type MPTProofNub struct { + RootHexKey []byte // root hex key, max 64bytes + n1 node + n2 node +} + +type MPTProofCache struct { + types.MPTProof + + cacheHexPath [][]byte // cache path for performance + cacheHashes [][]byte // cache hash for performance + cacheNodes []node // cache node for performance + cacheNubs []*MPTProofNub // cache proof nubs to check revive duplicate +} + +// VerifyProof verify proof in MPT witness +// 1. calculate hash +// 2. decode trie node +// 3. verify partial merkle proof of the witness +// 4. 
split to partial witness +// TODO later revive state could revive KV from fullNode[0-15] or fullNode[16] shortNode.Val, return KVs for cache & snap +// another easy method is that revive direct to Trie/Trie cache, query later and set to KV cache +func (m *MPTProofCache) VerifyProof() error { + m.cacheHashes = make([][]byte, len(m.Proof)) + m.cacheNodes = make([]node, len(m.Proof)) + m.cacheHexPath = make([][]byte, len(m.Proof)) + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + var child []byte + for i := len(m.Proof) - 1; i >= 0; i-- { + m.cacheHashes[i] = hasher.hashData(m.Proof[i]) + n, err := decodeNode(m.cacheHashes[i], m.Proof[i]) + if err != nil { + return err + } + m.cacheNodes[i] = n + + switch t := n.(type) { + case *shortNode: + m.cacheHexPath[i] = t.Key + if err := matchHashNodeInShortNode(child, t); err != nil { + return err + } + case *fullNode: + index, err := matchHashNodeInFullNode(child, t) + if err != nil { + return err + } + if index >= 0 { + m.cacheHexPath[i] = []byte{byte(index)} + } + case valueNode: + if child != nil { + return errors.New("proof wrong child in valueNode") + } + default: + return fmt.Errorf("proof got wrong trie node: %v", t.nodeType()) + } + + child = m.cacheHashes[i] + } + + // cache proof nubs + m.cacheNubs = make([]*MPTProofNub, 0, len(m.Proof)) + prefix := keybytesToHex(m.RootKey) + prefix = prefix[:len(prefix)-1] + for i := 0; i < len(m.cacheNodes); i++ { + prefix = append(prefix, m.cacheHexPath[i]...) + n1 := m.cacheNodes[i] + nub := MPTProofNub{ + RootHexKey: prefix, + n1: n1, + n2: nil, + } + if needMergeNextNode(m.cacheNodes, i) { + i++ + prefix = append(prefix, m.cacheHexPath[i]...) 
+ nub.n2 = m.cacheNodes[i] + } + m.cacheNubs = append(m.cacheNubs, &nub) + } + + return nil +} + +func needMergeNextNode(nodes []node, i int) bool { + if i >= len(nodes) || i+1 >= len(nodes) { + return false + } + + n1 := nodes[i] + n2 := nodes[i+1] + + if n2.nodeType() == valueNodeType { + return true + } + + // check extended node + if n1.nodeType() == shortNodeType { + return true + } + + return false +} + +func matchHashNodeInFullNode(child []byte, n *fullNode) (int, error) { + if child == nil { + return -1, nil + } + + for i := 0; i < BranchNodeLength-1; i++ { + switch v := n.Children[i].(type) { + case hashNode: + if bytes.Equal(child, v) { + return i, nil + } + } + } + return -1, errors.New("proof cannot find target child in fullNode") +} + +func matchHashNodeInShortNode(child []byte, n *shortNode) error { + if child == nil { + return nil + } + + switch v := n.Val.(type) { + case *hashNode: + if !bytes.Equal(child, *v) { + return errors.New("proof wrong child in shortNode") + } + default: + return errors.New("proof must hashNode when meet shortNode") + } + return nil +} + +// VerifyStorageWitness checks a merkle proof for a storage key. If the prefix key is specified, // it will traverse down to the node that contains the prefix key. From there, proof will be verified. // VerifyStorageProof returns an error if the proof contains invalid trie nodes. 
func (t *Trie) VerifyStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { @@ -167,7 +308,7 @@ func (t *Trie) VerifyStorageWitness(key []byte, prefixKeyHex []byte, proofDb eth if len(key) == 0 { return nil, fmt.Errorf("empty key provided") } - + key = keybytesToHex(key) tn := t.root @@ -216,7 +357,7 @@ func (t *Trie) VerifyStorageWitness(key []byte, prefixKeyHex []byte, proofDb eth func (t *Trie) traverseNodes(tn node, key []byte, nodes *[]node) (node, error) { for len(key) > 0 && tn != nil { switch n := tn.(type) { - case *shortNode: + case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { // The trie doesn't contain the key. tn = nil @@ -224,13 +365,13 @@ func (t *Trie) traverseNodes(tn node, key []byte, nodes *[]node) (node, error) { tn = n.Val key = key[len(n.Key):] } - if nodes != nil{ + if nodes != nil { *nodes = append(*nodes, n) } case *fullNode: tn = n.Children[key[0]] key = key[1:] - if nodes != nil{ + if nodes != nil { *nodes = append(*nodes, n) } case hashNode: diff --git a/trie/proof_test.go b/trie/proof_test.go index 4435821cd3..34b81b837b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -20,8 +20,11 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" mrand "math/rand" "sort" + "strings" "testing" "time" @@ -894,7 +897,7 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { // TestStorageProof tests the storage proof generation and verification. // This test will also test for partial proof generation and verification. 
-func TestStorageProof(t *testing.T){ +func TestStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { prefixKeys := getPrefixKeys(trie, []byte(kv.k)) @@ -918,7 +921,7 @@ func TestStorageProof(t *testing.T){ // TestOneElementStorageProof tests the storage proof generation and verification // for a trie with only one element. -func TestOneElementStorageProof(t *testing.T){ +func TestOneElementStorageProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") @@ -945,7 +948,7 @@ func TestOneElementStorageProof(t *testing.T){ // TestEmptyStorageProof tests storage verification with empty proof. // The verifier should nil for both value and error. -func TestEmptyStorageProof(t *testing.T){ +func TestEmptyStorageProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") @@ -953,14 +956,14 @@ func TestEmptyStorageProof(t *testing.T){ key := []byte("k") val, err := trie.VerifyStorageWitness(key, nil, proof) - if val != nil && err != nil{ + if val != nil && err != nil { t.Fatalf("expected nil value and error for empty proof") } } // TestEmptyKeyStorageProof tests the storage proof with empty key. -// The prover is expected to return -func TestEmptyKeyStorageProof(t *testing.T){ +// The prover is expected to return +func TestEmptyKeyStorageProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") @@ -973,7 +976,7 @@ func TestEmptyKeyStorageProof(t *testing.T){ // TestEmptyPrefixKeyStorageProof tests the storage proof with empty prefix key, // which means that all proofs generated are from the root node. -func TestEmptyPrefixKeyStorageProof(t *testing.T){ +func TestEmptyPrefixKeyStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { proof := memorydb.New() @@ -995,7 +998,7 @@ func TestEmptyPrefixKeyStorageProof(t *testing.T){ // TestBadStorageProof tests a few cases which the proof is wrong. // The proof is expected to detect the error. 
-func TestBadStorageProof(t *testing.T){ +func TestBadStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { @@ -1020,11 +1023,11 @@ func TestBadStorageProof(t *testing.T){ itVal, _ := proof.Get(itKey) proof.Delete(itKey) it.Release() - + mutateByte(itVal) proof.Put(crypto.Keccak256(itVal), itVal) - - if val, err := trie.VerifyStorageWitness(key, prefixKey, proof); err == nil && val != nil{ + + if val, err := trie.VerifyStorageWitness(key, prefixKey, proof); err == nil && val != nil { t.Fatalf("expected proof to fail for key: %x, prefix: %x", key, prefixKey) } } @@ -1033,7 +1036,7 @@ func TestBadStorageProof(t *testing.T){ // TestBadKeyStorageProof tests the storage proof with a bad key. // The verifier is expected to return nil for both value and error. -func TestBadKeyStorageProof(t *testing.T){ +func TestBadKeyStorageProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") @@ -1042,14 +1045,14 @@ func TestBadKeyStorageProof(t *testing.T){ trie.ProveStorageWitness(key, nil, proof) val, err := trie.VerifyStorageWitness(key, nil, proof) - if val != nil && err != nil{ + if val != nil && err != nil { t.Fatalf("expected nil value and error for bad key") } } // TestBadPrefixKeyStorageProof tests the storage proof with a bad prefix key. // The verifier is expected to return nil for both value and error. -func TestBadPrefixKeyStorageProof(t *testing.T){ +func TestBadPrefixKeyStorageProof(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") @@ -1061,19 +1064,19 @@ func TestBadPrefixKeyStorageProof(t *testing.T){ trie.ProveStorageWitness(key, prefixKey, proof) val, err := trie.VerifyStorageWitness(key, prefixKey, proof) - if val != nil && err != nil{ + if val != nil && err != nil { t.Fatalf("expected nil value and error for bad prefix key") } } // TestKeyPrefixKeySame tests the storage proof with the same key and prefix key. // The proof size should be 0 and the verifier should return nil for both value and error. 
-func TestKeyPrefixKeySame(t *testing.T){ +func TestKeyPrefixKeySame(t *testing.T) { trie := new(Trie) updateString(trie, "k", "v") proof := memorydb.New() - key := []byte("k") + key := []byte("k") trie.ProveStorageWitness(key, key, proof) if proof.Len() != 0 { @@ -1081,7 +1084,7 @@ func TestKeyPrefixKeySame(t *testing.T){ } val, err := trie.VerifyStorageWitness(key, key, proof) - if val != nil && err != nil{ + if val != nil && err != nil { t.Fatalf("expected nil value and error for same key and prefix key") } } @@ -1099,7 +1102,6 @@ func TestUnexpiredStorageProof(t *testing.T) { "abdf": "D", } - unexpiredData := map[string]string{ "defg": "E", "defh": "F", @@ -1342,13 +1344,13 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { } } -func getPrefixKeys(t *Trie, key []byte) ([][]byte){ +func getPrefixKeys(t *Trie, key []byte) [][]byte { var prefixKeys [][]byte key = keybytesToHex(key) tn := t.root for len(key) > 0 && tn != nil { switch n := tn.(type) { - case *shortNode: + case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { // The trie doesn't contain the key. 
tn = nil @@ -1357,7 +1359,7 @@ func getPrefixKeys(t *Trie, key []byte) ([][]byte){ // Check if there is a previous key in prefixKeys if len(prefixKeys) == 0 { prefixKeys = append(prefixKeys, n.Key) - } else{ + } else { prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], n.Key...)) } key = key[len(n.Key):] @@ -1366,7 +1368,7 @@ func getPrefixKeys(t *Trie, key []byte) ([][]byte){ tn = n.Children[key[0]] if len(prefixKeys) == 0 { prefixKeys = append(prefixKeys, key[:1]) - } else{ + } else { prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], key[:1]...)) } key = key[1:] @@ -1382,4 +1384,41 @@ func getPrefixKeys(t *Trie, key []byte) ([][]byte){ } return prefixKeys -} \ No newline at end of file +} + +func TestMPTProofCache_VerifyProof_normalCase(t *testing.T) { + cache := makeMPTProofCache(nil, []string{ + "0xf90211a03697534056039e03300557bd69fe16e18ce4a6ccd5522db4dfa97dfe1fad3d3aa0b1bf1f230b98b9034738d599177ae817c08143b9395a47f300636b0dd2fb3c5ea0aa04a4966751d4c50063fe13a96a6c7924f665819733f556849b5eb9fa1d6839a0e162e080d1c12c59dc984fb2246d8ad61209264bee40d3fdd07c4ea4ff411b6aa0e5c3f2dde71bf303423f34674748567dcdf8379129653b8213f698468738d492a068a3e3059b6e7115a055a7874f81c5a1e84ddc1967527973f8c78cd86a1c9f8fa0d734bd63b7be8e8471091b792f5bbcbc7b0ce582f6d985b7a15a3c0155242c56a00143c06f57a65c8485dbae750aa51df5dff1bf7bdf28060129a20de9e51364eda07b416f79b3f4e39d0159efff351009d44002d9e83530fb5a5778eb55f5f4432ca036706b52196fa0b73feb2e7ff8f1379c7176d427dd44ad63c7b65e66693904a1a0fd6c8b815e2769ce379a20eaccdba1f145fb11f77c280553f15ee4f1ee135375a02f5233009f082177e5ed2bfa6e180bf1a7310e6bc3c079cb85a4ac6fee4ae379a03f07f1bb33fa26ebd772fa874914dc7a08581095e5159fdcf9221be6cbeb6648a097557eec1ac08c3bfe45ce8e34cd329164a33928ac83fef1009656536ef6907fa028196bfb31aa7f14a0a8000b00b0aa5d09450c32d537e45eebee70b14313ff1ca0126ce265ca7bbb0e0b01f068d1edef1544cbeb2f048c99829713c18d7abc049a80", + 
"0xf90211a01f7c019858f447dbff8ed8e4329e88600bab8f17fec9594c664e25acc95da3dba00916866833439bc250b5e08edc2b7c041634ca3b33a013f79eb299a3e33a056fa0a8fae921061f3bc81154b5d2c149d3860a3cdef00d02173ee3b5837de6b26f56a097583621a54e74994619a97cf82f823005a35ad1bf4795047726619eccd11ad4a0be39789c4abdb2a185cce40f2c77575eee2a1eab38d3168395560da15307614aa0c46a7fcea5501656d70508178f0731460edcce1c01e32ed08c1468f2593db277a0460f030038b09b8461d834ded79d77fd3ed25c4e248775752bf1c830a530ef2ba03d887abec623c6b4d93be75d1608dcacb291bf30406fbbc944a94aa4203fb1eba0475eeb07d471044af74313093cfbb0201e17a405d7af13a7ae6b245e18915515a0ea794035230f90e14f4f601d0e9f04217ed02710df363417bc3dd7dda5a84608a08aa9f0c44e5a9e359b65d4f0e40937662f07af10642a626e19ca7bc56c329706a05677c9b342c1cbcfd491cd491800d44423d1bdc08ab00eb59376a7118f7c4e23a0b761f22d67328e6c90caf8be65affb701b045f4f1f581472fc5f724e0c61328da0b6727240643009e59bba78249918a83ca8faeaf1d5b47fe0b41c356e8ae300dda0d7cbeb12faf439126bdb86f94b6262c3a70e88e4d8a47b49735f0b9d632a5df8a0428fe930556ce5ea94bcc4d092fe0f05a2a9b360175857711729ab3a7092a81780", + 
"0xf90211a0a81ff74945250ba9925753e379aa8815a3fe77926aad2d02db4d78a3da7ecb48a0e795c64d2b738b34ebb77a3307df800c9f1fe324b442cd71ffa5fd268ec12ffca020a5f968a1c8292d08135cb451daed999a44eb0cdd04526a89c38e753398c50ca0cc7165f6b984f2f2569d101d70f72af94eb79b18c126895da2d3cf557b5aca51a00a10f4dc5851e71f195cca5bd7a2a62a6c3ca03d95e91c7dc55e55f4cb726903a09ecaaf877b18a55ca4001fe46cc10389c94104abc2994cdfe5843f28814b119ca099f59b2b52ee9d9b44ab7123a86775dcfe6c50301dd7cf9c6fc6ea968c1d2a01a0fccda5c1489dd3268fcf2471f6a372fc3afe5eebecc0ba9fc3c023d85da26aeda0d730ea7aabdad2e5451e826726e53c86e6e805e46220bc7edd80bf3f7e467f96a074f3d767a84557aa9559b08b2d1f93d8205827c042d1ac616b8cf37e22de2beca03ca1fe479ea1e64c3bfeae9c603ef4f55c270de0bb4c79d045f67d3481ca2852a0bdcbc3d154db40b5faa1f6875f46d485b96bdfffb4513da631f9b302768faae7a096923f4f559c7b7f912587292cc865aba5934e9e75c9aa88f20de73e928c6c51a0f52ca9710977327dd407ba9d9f00b0075d7e312ee19686f1b53579585ff50132a076adb4cf98f9af261d1b2a147b9b2cbace1647eab537ca1a55e8d40d35775b74a0723fa2f72d6a983939bc9bf9ee22e05fc142437726450e5707c60155bc34ec0580", + 
"0xf90211a039a64fd9b31f3ea3bf9a991ca828eadb67cf9ae0ebbcdd195d297454699580e8a0c95d85a63beb9a02b56d032116d86fdf9a64065dd5d44e33acef674ae3ceb6f9a0d83ef07a99302abccdce1289d353a20494713f45d8edbd3c2e87f08788a878d9a063a19aee41fec40a98edcf4d60c759c509b512bfc5d9feae7de50c8c2d00eee7a029ce5c8c1ac9939cbf481274b8e6f5f24a430136dc5aab7deb489a9ce7db5a95a03b45c53d2e4f54e49a53eb298aee828f4803581ee60ca52940533b77c3e3fadfa082b8084dfc0337a49fd5d53ce107fa0bdcdf25ac7bbac017d7ac250b77c8764da0d4dab90004fd3bb36b2fa5f6e914a7c90820d119e5ba3ba8439270c6cfbd0a18a0bd9286c9248ae8a953d00aa9906f06b6f0364d0bb9fb615de04c9f8d5b5fd346a00955eccaa41a17fafb0ed66272b5183b3a30a973af7a43e0f25f0b640fd5df0ca0a7ba773602ace05991211770fc5555dd9c55e2518bcb005d1db022584a89132da0450b066588b44701992ab4a926d5dd185dd465abf1e51a3f60ab4a9f924cf85aa05eda0687636512339d0db99eade61c0d44358bb8027185296ad59b4cedd2935aa0cd18604dee296e5443b598e470a04efde04fbb0213c0fda8cc3af20dae2a1c34a01c0474c1bf15e4732d1c22a1303573c7ec8643f711354e853ab26038b2eebe25a081a0b2d329a9519375f71384a7130faafc3a1379db59bfb7db7135c9d27d9cc080", + 
"0xf901f1a0e14d9caa85464966f0371f427250e7bf2d86f6d41535b08ea79044391d6f0fe7a08bea233c92e4c1d05bd03d62cd93974cc29f6df54e6fa2574bf8dfb3af936a85a0dd4f4af9ab72d1bbbf59c286a731614f48bf3cdcef572cd819426fc2a30ae5d9a058465ea8e97b2f3a873646f9504aff111ab967abaea8ebb2fe91bf058ea162f3a06466c41eb770bae5c07c26cd692540e7f9af70fee2bc164e12f29083ebd2cca1a0fe0925da033ffb967ca1e1d6505c2ca740553916c3b9a205131d719b7921fb00a0973ecee958d1305f1b6f8159a9732e47f5fbd4121f60254ea44a8f028308150780a0f81717c7f8702ed39d89a34fc86827aab31db26b22d22ee399667e4a44081c95a0e276ec24ee74ee61bdbe92e8afa57813d59f26e0ad34f24d71f0627719f8a11ea00a8ec2f6922480890f9e67c62c1654dcccf1710a01a378f614e02a3785d42d7ea08c43cc0c6c690dd87cf42691e9ead799c5ddbcfc9458ebd054a46acecedbe9f7a0d3727ed077c60ed6ff1d9d5d20fdf2912a513942d0efee9ec2d97d33ab9b7f56a03f293b0c0f25b9aa2735c4e42b132008d8af2ecfaaefdc940dc94cf312f5a1dda04860bc5f19829bb277554927c5b6dbc0af93025aca49aed3be020a3080996a26a0e47475bf4f62364d8a0dd87f2c21d0b64ef4d9aaf5fc46d9b1e3af5b83f56d4580", + "0xf89180a0dd3524693059f48ce47c5dd82d0462953a5141c7abc5a63981637b71e0100bcf80a033ca98826ea65c1c4be16e1df12e78142d992bf1fc0189401632ceff3fcbbe7880a0e5daf803b0890d164e287fa43b51332286cddeb9b62280426037fc02dc2b4d5f8080a0bfa907c2e30720a07347d23cfa573dec8c9b3afa51a4a0e0783a481da0fcb1b38080808080808080", + "0xe79e208246cec5810061f4ff7efe1dcd6cb407d59abc3478830df04484584c868786d647b234389e", + }) + + err := cache.VerifyProof() + assert.NoError(t, err) + assert.Equal(t, common.Hex2Bytes("3310913fe74cfb66dbde8fe8557b48e8e65617a17c2375a581c32d49f812cde4"), cache.cacheHashes[0]) + key := common.Hex2Bytes("95eea00c49d14a895954837cd876ffa8cfad96cbaacc40fc31d6df2c902528a8") + hash := make([]byte, common.HashLength) + h := newHasher(false) + h.sha.Reset() + h.sha.Write(key) + h.sha.Read(hash) + assert.Equal(t, hash, hexToKeybytes(cache.cacheNubs[6].RootHexKey)) +} + +func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { + + proof := make([][]byte, len(proofs)) + for i := range 
proofs { + proof[i] = common.Hex2Bytes(strings.TrimPrefix(proofs[i], "0x")) + } + return MPTProofCache{ + MPTProof: types.MPTProof{ + RootKey: key, + Proof: proof, + }, + } +} From 63190702b504be6426f406753f190bfaee78445b Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 12 Apr 2023 16:25:59 +0800 Subject: [PATCH 20/51] witness: MPT revive refactor(trie): change prefixKey to prefixKeyHex feat(trie): add ReviveTrie() refactor(MPTProof): change RootKey to RootKeyHex refactor: fix VerifyProof and ExpireByPrefix bugs feat(revive): add ReviveTrie and UTs minor: edit some comments refactor(trie): add resolveHash and comments --- core/state/database.go | 1 + core/state/statedb.go | 37 +++ core/state_transition.go | 33 +-- core/state_transition_test.go | 13 +- core/types/revive_witness.go | 2 +- core/types/revive_witness_test.go | 2 +- core/types/transaction_test.go | 2 +- core/vm/interface.go | 1 + light/trie.go | 4 + trie/dummy_trie.go | 4 + trie/proof.go | 14 +- trie/proof_test.go | 8 +- trie/secure_trie.go | 4 + trie/trie.go | 122 ++++++++-- trie/trie_test.go | 363 ++++++++++++++++++++++++++++++ 15 files changed, 550 insertions(+), 60 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 37430062e8..4c071e2589 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -129,6 +129,7 @@ type Trie interface { ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error + ReviveTrie(trie.MPTProofCache) error } // NewDatabase creates a backing store for state. 
The returned database is safe for diff --git a/core/state/statedb.go b/core/state/statedb.go index c232537f5e..c31d6961be 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -1753,3 +1753,40 @@ func (s *StateDB) GetDirtyAccounts() []common.Address { func (s *StateDB) GetStorage(address common.Address) *sync.Map { return s.storagePool.getStorage(address) } + +// ReviveTrie revive a trie with a given witness list +func (s *StateDB) ReviveTrie(witnessList types.WitnessList) error{ + for i := range witnessList { + wit := witnessList[i] + // got specify witness, verify proof and check if revive success + switch wit.WitnessType { + case types.StorageTrieWitnessType: + data, err := wit.WitnessData() + if err != nil { + return err + } + stWit, ok := data.(*types.StorageTrieWitness) + if !ok { + return errors.New("got StorageTrieWitnessType data error") + } + proofCaches := make([]trie.MPTProofCache, len(stWit.ProofList)) + for j := range stWit.ProofList { + proofCaches[j] = trie.MPTProofCache{ + MPTProof: stWit.ProofList[j], + } + if err := proofCaches[j].VerifyProof(); err != nil { + return err + } + + trie := s.StorageTrie(stWit.Address) + if err := trie.ReviveTrie(proofCaches[j]); err != nil { + return err + } + } + default: + return errors.New("unsupported WitnessType") + } + } + + return nil +} \ No newline at end of file diff --git a/core/state_transition.go b/core/state_transition.go index 2f1554bc08..3699be1fb1 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -19,7 +19,6 @@ package core import ( "errors" "fmt" - "github.com/ethereum/go-ethereum/trie" "math" "math/big" @@ -364,37 +363,7 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { // revive state before execution if rules.IsElwood { - witnessList := msg.WitnessList() - for i := range witnessList { - wit := witnessList[i] - // got specify witness, verify proof and check if revive success - switch wit.WitnessType { - case types.StorageTrieWitnessType: - 
data, err := wit.WitnessData() - if err != nil { - return nil, err - } - stWit, ok := data.(*types.StorageTrieWitness) - if !ok { - return nil, errors.New("got StorageTrieWitnessType data error") - } - proofCaches := make([]trie.MPTProofCache, len(stWit.ProofList)) - for j := range stWit.ProofList { - proofCaches[j] = trie.MPTProofCache{ - MPTProof: stWit.ProofList[j], - } - if err := proofCaches[j].VerifyProof(); err != nil { - return nil, err - } - - // TODO revive trie nodes by witness in the same contract storage trie - // 1. check expired hash; - // 2. append trie nodes & rebuild shadow nodes; - } - default: - return nil, errors.New("unsupported WitnessType") - } - } + st.state.ReviveTrie(msg.WitnessList()) } var ( diff --git a/core/state_transition_test.go b/core/state_transition_test.go index a96b5ddcba..6d399b84ef 100644 --- a/core/state_transition_test.go +++ b/core/state_transition_test.go @@ -9,6 +9,17 @@ import ( "testing" ) +func keybytesToHex(str []byte) []byte { + l := len(str)*2 + 1 + var nibbles = make([]byte, l) + for i, b := range str { + nibbles[i*2] = b / 16 + nibbles[i*2+1] = b % 16 + } + nibbles[l-1] = 16 + return nibbles +} + func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, proofLen int) types.ReviveWitness { proofList := make([]types.MPTProof, witSize) for i := range proofList { @@ -17,7 +28,7 @@ func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, p proof[j] = bytes.Repeat([]byte{'p'}, proofLen) } proofList[i] = types.MPTProof{ - RootKey: bytes.Repeat([]byte{'k'}, keyLen), + RootKeyHex: keybytesToHex(bytes.Repeat([]byte{'k'}, keyLen)), Proof: proof, } } diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go index cbb08ca6f9..17672970d5 100644 --- a/core/types/revive_witness.go +++ b/core/types/revive_witness.go @@ -24,7 +24,7 @@ const ( // will verify the whole path later // Attention: The proof could revive multi-vals, although it's a single trie path witness type 
MPTProof struct { - RootKey []byte // root key, target the revival path root, max 32bytes + RootKeyHex []byte // prefix key in nibbles format, max 65 bytes. TODO: optimize witness size Proof [][]byte // list of RLP-encoded nodes } diff --git a/core/types/revive_witness_test.go b/core/types/revive_witness_test.go index ad7a508675..255a73e5f7 100644 --- a/core/types/revive_witness_test.go +++ b/core/types/revive_witness_test.go @@ -34,7 +34,7 @@ func makeStorageTrieWitness(addr common.Address, proofCount int, proofLen ...int proof[j] = bytes.Repeat([]byte{'f'}, proofLen[j]) } proofList[i] = MPTProof{ - RootKey: nil, + RootKeyHex: nil, Proof: proof, } } diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index b48f30f6d6..cebdc56951 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -579,7 +579,7 @@ func TestReviveStateTxAndSigner(t *testing.T) { wit := StorageTrieWitness{ Address: addr, ProofList: []MPTProof{{ - RootKey: common.Hex2Bytes("095e7baea6a6c7c4c2"), + RootKeyHex: []byte{0x09, 0x5e, 0x7b, 0xae, 0xa6, 0xa6, 0xc7, 0xc4, 0xc2}, Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, }}, } diff --git a/core/vm/interface.go b/core/vm/interface.go index ad9b05d666..877283384e 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -74,6 +74,7 @@ type StateDB interface { AddPreimage(common.Hash, []byte) ForEachStorage(common.Address, func(common.Hash, common.Hash) bool) error + ReviveTrie(witnessList types.WitnessList) error } // CallContext provides a basic interface for the EVM calling conventions. 
The EVM diff --git a/light/trie.go b/light/trie.go index adf7f1d810..9c64ef40a0 100644 --- a/light/trie.go +++ b/light/trie.go @@ -202,6 +202,10 @@ func (db *odrTrie) NoTries() bool { return false } +func (t *odrTrie) ReviveTrie(proof trie.MPTProofCache) error { + return t.trie.ReviveTrie(proof) +} + type nodeIterator struct { trie.NodeIterator t *odrTrie diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index be0c3a01b6..6d10d0defe 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -103,3 +103,7 @@ func (t *EmptyTrie) NodeIterator(start []byte) NodeIterator { func (t *EmptyTrie) TryUpdateAccount(key []byte, account *types.StateAccount) error { return nil } + +func (t *EmptyTrie) ReviveTrie(proof MPTProofCache) error { + return nil +} \ No newline at end of file diff --git a/trie/proof.go b/trie/proof.go index f532b8d34b..bac2a44f0e 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -227,10 +227,12 @@ func (m *MPTProofCache) VerifyProof() error { // cache proof nubs m.cacheNubs = make([]*MPTProofNub, 0, len(m.Proof)) - prefix := keybytesToHex(m.RootKey) - prefix = prefix[:len(prefix)-1] + prefix := m.RootKeyHex for i := 0; i < len(m.cacheNodes); i++ { - prefix = append(prefix, m.cacheHexPath[i]...) + if i - 1 >= 0 { + prefix = append(prefix, m.cacheHexPath[i-1]...) + } + // prefix = append(prefix, m.cacheHexPath[i]...) n1 := m.cacheNodes[i] nub := MPTProofNub{ RootHexKey: prefix, @@ -239,7 +241,7 @@ func (m *MPTProofCache) VerifyProof() error { } if needMergeNextNode(m.cacheNodes, i) { i++ - prefix = append(prefix, m.cacheHexPath[i]...) + prefix = append(prefix, m.cacheHexPath[i-1]...) 
nub.n2 = m.cacheNodes[i] } m.cacheNubs = append(m.cacheNubs, &nub) @@ -290,8 +292,8 @@ func matchHashNodeInShortNode(child []byte, n *shortNode) error { } switch v := n.Val.(type) { - case *hashNode: - if !bytes.Equal(child, *v) { + case hashNode: + if !bytes.Equal(child, v) { return errors.New("proof wrong child in shortNode") } default: diff --git a/trie/proof_test.go b/trie/proof_test.go index 34b81b837b..ac81b46ff0 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -900,7 +900,7 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { - prefixKeys := getPrefixKeys(trie, []byte(kv.k)) + prefixKeys := getPrefixKeysHex(trie, []byte(kv.k)) for _, prefixKey := range prefixKeys { proof := memorydb.New() key := kv.k @@ -1002,7 +1002,7 @@ func TestBadStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { - prefixKeys := getPrefixKeys(trie, []byte(kv.k)) + prefixKeys := getPrefixKeysHex(trie, []byte(kv.k)) for _, prefixKey := range prefixKeys { proof := memorydb.New() key := []byte(kv.k) @@ -1344,7 +1344,7 @@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { } } -func getPrefixKeys(t *Trie, key []byte) [][]byte { +func getPrefixKeysHex(t *Trie, key []byte) [][]byte { var prefixKeys [][]byte key = keybytesToHex(key) tn := t.root @@ -1417,7 +1417,7 @@ func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { } return MPTProofCache{ MPTProof: types.MPTProof{ - RootKey: key, + RootKeyHex: key, Proof: proof, }, } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index dd7598d893..9b55108933 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -228,3 +228,7 @@ func (t *SecureTrie) getSecKeyCache() map[string][]byte { } return t.secKeyCache } + +func (t *SecureTrie) ReviveTrie(proof MPTProofCache) error { + return t.trie.ReviveTrie(proof) +} \ No newline at end of file diff --git a/trie/trie.go b/trie/trie.go 
index 007408a87f..f82974a616 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -492,36 +492,40 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } } -func (t *Trie) ExpireByPrefix(prefixKey []byte) { - _, err := t.expireByPrefix(t.root, prefixKey) +func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { + hn, err := t.expireByPrefix(t.root, prefixKeyHex) + if prefixKeyHex == nil && hn != nil { + t.root = hn + } if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return err } + return nil } -func (t *Trie) expireByPrefix(n node, prefixKey []byte) (node, error) { +func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { // Loop through prefix key // When prefix key is empty, generate the hash node of the current node // Replace current node with the hash node // If length of prefix key is empty - if len(prefixKey) == 0 { + if len(prefixKeyHex) == 0 { hasher := newHasher(false) defer returnHasherToPool(hasher) var hn node _, hn = hasher.proofHash(n) - return hn, nil + if _, ok := hn.(hashNode); ok { + return hn, nil + } + + return nil, nil } switch n := n.(type) { case *shortNode: - matchLen := prefixLen(prefixKey, n.Key) - if matchLen == len(prefixKey) { - return nil, fmt.Errorf("")// Found the node to expire - } - - hn, err := t.expireByPrefix(n.Val, prefixKey[matchLen:]) + matchLen := prefixLen(prefixKeyHex, n.Key) + hn, err := t.expireByPrefix(n.Val, prefixKeyHex[matchLen:]) if err != nil { return nil, err } @@ -533,14 +537,14 @@ func (t *Trie) expireByPrefix(n node, prefixKey []byte) (node, error) { return nil, err case *fullNode: - hn, err := t.expireByPrefix(n.Children[prefixKey[0]], prefixKey[1:]) + hn, err := t.expireByPrefix(n.Children[prefixKeyHex[0]], prefixKeyHex[1:]) if err != nil { return nil, err } // Replace child node with hash node if hn != nil { - n.Children[prefixKey[0]] = hn + n.Children[prefixKeyHex[0]] = hn } return nil, err @@ -566,6 +570,9 @@ func (t *Trie) resolve(n node, 
prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { hash := common.BytesToHash(n) + if t.db == nil { + return nil, fmt.Errorf("empty trie database") + } if node := t.db.node(hash); node != nil { return node, nil } @@ -649,3 +656,90 @@ func (t *Trie) Reset() { func (t *Trie) Size() int { return estimateSize(t.root) } + +// ReviveTrie revives the trie from the proof cache +func (t *Trie) ReviveTrie(proof MPTProofCache) error { + + var parent node + var childIndex int // If parent is a fullNode, childIndex is the index of the child node + + cacheHashIndex := 0 // Keep track of the index of the cachedHash + nubs := proof.cacheNubs + +loopNubs: + for _, nub := range nubs { + key := nub.RootHexKey + startNode := t.root + parent = nil // TODO (asyukii): When RootNode is introduced, parent node will be the RootNode instead of nil + childIndex = -1 + // Traverse through the trie using RootHexKey + + // Loop through the key to find hash node + for len(key) > 0 { + switch n := startNode.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { + return fmt.Errorf("key %v not found", key) + } else { + parent = n + startNode = n.Val + key = key[len(n.Key):] + } + case *fullNode: + startNode = n.Children[key[0]] + parent = n + childIndex = int(key[0]) + key = key[1:] + case hashNode: + tn, err := t.resolveHash(n, nil) + if err == nil { + startNode = tn + } else { + continue loopNubs + } + default: + continue loopNubs + } + } + + // TODO (asyukii): check if the node is expired + // Attach node to parent + if _, ok := startNode.(hashNode); ok { + cachedHash := proof.cacheHashes[cacheHashIndex] + if bytes.Equal(cachedHash, startNode.(hashNode)) { + // Attach n1 to the trie + switch n := parent.(type) { + case *shortNode: + n.Val = nub.n1 + parent = n.Val + case *fullNode: + n.Children[childIndex] = nub.n1 + parent = n.Children[childIndex] + // TODO (asyukii): build shadow node + } + + // Attach 
n2 to the trie if exists + if nub.n2 != nil { + switch n := parent.(type) { + case *shortNode: + n.Val = nub.n2 + default: + return fmt.Errorf("n2 should only be attached to a shortNode") + } + // TODO (asyukii): build shadow node if n2 is a fullNode + } + } + } + + // Increment cacheHashIndex + if nub.n1 != nil { + cacheHashIndex++ + } + if nub.n2 != nil { + cacheHashIndex++ + } + } + + return nil +} + diff --git a/trie/trie_test.go b/trie/trie_test.go index 63aed333db..cd792ffb34 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -19,6 +19,7 @@ package trie import ( "bytes" "encoding/binary" + // "encoding/hex" "errors" "fmt" "hash" @@ -39,6 +40,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" + "github.com/stretchr/testify/assert" ) func init() { @@ -473,6 +475,367 @@ func TestRandom(t *testing.T) { } } + +// TestExpireByPrefix tests that the trie is not corrupted after +// expiring a key by prefix. 
+func TestExpireByPrefix(t *testing.T){ + data := map[string]string{ + "abcd": "A", + "abce": "B", + "abde": "C", + "abdf": "D", + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + trie := createCustomTrie(data) + rootHash := trie.Hash() + + for k := range data { + prefixKeys := getPrefixKeysHex(trie, []byte(k)) + for _, prefixKey := range prefixKeys { + trie.ExpireByPrefix(prefixKey) + currHash := trie.Hash() + // Validate root hash + assert.Equal(t, rootHash, currHash, "Root hash mismatch, got %x, expected %x", currHash, rootHash) + + // Reset trie + trie = createCustomTrie(data) + } + } +} + +func createCustomTrie(data map[string]string) *Trie{ + trie := new(Trie) + for k, v := range data { + trie.Update([]byte(k), []byte(v)) + } + + return trie +} + +type proofList [][]byte + +func (n *proofList) Put(key []byte, value []byte) error { + *n = append(*n, value) + return nil +} + +func (n *proofList) Delete(key []byte) error { + panic("not supported") +} + +func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { + return MPTProofCache{ + MPTProof: types.MPTProof{ + RootKeyHex: rootKeyHex, + Proof: proof, + }, + } +} + +// TestReviveTrie tests that a trie can be revived from a proof +func TestReviveTrie(t *testing.T){ + + trie, vals := nonRandomTrie(500) + + oriRootHash := trie.Hash() + + for _, kv := range vals { + key := []byte(kv.k) + val := []byte(kv.v) + prefixKeys := getPrefixKeysHex(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + // Verify value exists after revive + v := trie.Get(key) + assert.Equal(t, 
val, v) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) + + // Reset trie + trie, _ = nonRandomTrie(500) + } + } +} + +// TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when +// the whole trie is expired. This test will fail because the parent node in +// ReviveTrie is nil, set to RootNode when available +// func TestReviveAtRoot(t *testing.T) { +// trie, vals := nonRandomTrie(500) + +// oriRootHash := trie.Hash() + +// for _, kv := range vals { +// key := []byte(kv.k) +// val := []byte(kv.v) + +// fmt.Printf("key: %x, val: %x", key, val) +// var proof proofList + +// err := trie.ProveStorageWitness(key, nil, &proof) +// assert.NoError(t, err) + +// // Expire trie +// trie.ExpireByPrefix(nil) + +// // Construct MPTProofCache +// proofCache := makeRawMPTProofCache(nil, proof) + +// // VerifyProof +// err = proofCache.VerifyProof() +// assert.NoError(t, err) + +// // Revive trie +// err = trie.ReviveTrie(proofCache) +// assert.NoError(t, err) + + +// // Verify value exists after revive +// v := trie.Get(key) +// assert.Equal(t, val, v) + +// // Verify root hash +// currRootHash := trie.Hash() +// assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) + +// // Reset trie +// trie, _ = nonRandomTrie(500) +// } + +// } + +// TestReviveBadProof tests that a trie cannot be revived from a bad proof +func TestReviveBadProof(t *testing.T) { + + dataA := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + dataB := map[string]string{ + "qwer": "A", "qwet": "B", "qwrt": "C", "qwry": "D", + "abcd": "E", "abce": "F", "abde": "G", "abdf": "H", + } + + trieA := createCustomTrie(dataA) + trieB := createCustomTrie(dataB) + + var proofB proofList + + err := trieB.ProveStorageWitness([]byte("abcd"), 
nil, &proofB) + assert.NoError(t, err) + + // Expire trie A + trieA.ExpireByPrefix(nil) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(nil, proofB) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + err = trieA.ReviveTrie(proofCache) + assert.NoError(t, err) + + // Verify value does not exist after revive + _, err = trieA.TryGet([]byte("abcd")) + assert.Error(t, err) + +} + +// TestReviveOneElement tests that a trie with a single element +// can be revived from a proof +func TestReviveOneElement(t *testing.T) { + trie := new(Trie) + key := []byte("k") + val := []byte("v") + trie.Update(key, val) + + // Generate proof + var proof proofList + + err := trie.ProveStorageWitness(key, nil, &proof) + assert.NoError(t, err) + + err = trie.ExpireByPrefix(nil) + assert.NoError(t, err) + + proofCache := makeRawMPTProofCache(nil, proof) + + err = proofCache.VerifyProof() + assert.NoError(t, err) + + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + v := trie.Get(key) + assert.Equal(t, val, v) +} + +// TestReviveBadProofAfterUpdate tests that after reviving a path and +// then update the value, old proof should be invalid +func TestReviveBadProofAfterUpdate(t *testing.T) { + trie, vals := nonRandomTrie(500) + for _, kv := range vals { + key := []byte(kv.k) + prefixKeys := getPrefixKeysHex(trie, key) + for _, prefixKey := range prefixKeys { + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix(prefixKey) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + err = proofCache.VerifyProof() + assert.NoError(t, err) + + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + trie.Update(key, []byte("new value")) + + // Revive again with old proof + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + // Validate trie + resVal, err := 
trie.TryGet(key) + assert.NoError(t, err) + assert.Equal(t, []byte("new value"), resVal) + } + } +} + +// TestPartialReviveFullProof tests that a path can be revived +// with full proof even when the trie is partially expired +func TestPartialReviveFullProof(t *testing.T) { + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data) + + // Get proof + var proof proofList + err := trie.ProveStorageWitness([]byte("abcd"), nil, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix([]byte{6,1}) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(nil, proof) + + // Verify proof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + // Validate trie + resVal, err := trie.TryGet([]byte("abcd")) + assert.NoError(t, err) + assert.Equal(t, []byte("A"), resVal) +} + +// TestReviveValueAtFullNode tests that a value that is at +// full node can be revived properly +func TestReviveValueAtFullNode(t *testing.T) { + hexKeys := [][]byte{ + {6,1,6,2,6,3,6,4,16}, + {6,1,6,2,6,3,6,5,16}, + {6,1,6,2,6,4,6,5,16}, + {6,1,6,2,6,4,6,6,16}, + {6,4,6,5,6,6,6,7,16}, + {6,4,6,5,6,6,6,8,16}, + {6,4,6,5,6,7,6,8,16}, + {6,4,6,5,6,7,6,9,16}, + {6,1,6,2,6,3,6,4,16}, + {6,1,6,2,16}, // This is the key that has a value at a full node + } + + byteKeys := make([][]byte, len(hexKeys)) + + vals := []string{ + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", + } + + trie := new(Trie) + for i, hexKey := range hexKeys { + hexKey = hexToKeybytes(hexKey) + byteKeys[i] = hexKey + } + + // Insert keys into trie + for i, hexKey := range byteKeys { + trie.Update(hexKey, []byte(vals[i])) + } + + key := byteKeys[9] + val := vals[9] + + prefixKeys := getPrefixKeysHex(trie, key) + + for _, prefixKey := range prefixKeys { + var proof proofList + err := 
trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix(prefixKey) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + err = proofCache.VerifyProof() + assert.NoError(t, err) + + err = trie.ReviveTrie(proofCache) + assert.NoError(t, err) + + // Validate trie + resVal, err := trie.TryGet(key) + assert.NoError(t, err) + assert.Equal(t, []byte(val), resVal) + } +} + func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } From b9a713cff955dc66ab956d46a089cd211f7147bd Mon Sep 17 00:00:00 2001 From: cryyl <1226241521@qq.com> Date: Wed, 12 Apr 2023 17:46:00 +0800 Subject: [PATCH 21/51] implement MPT and shadow node Signed-off-by: cryyl <1226241521@qq.com> --- trie/database.go | 10 ++++++++ trie/errors.go | 10 ++++++++ trie/node.go | 61 +++++++++++++++++++++++++++++++++++++++------ trie/shadownodes.go | 18 +++++++++++++ trie/trie.go | 45 +++++++++++++++++++-------------- 5 files changed, 117 insertions(+), 27 deletions(-) create mode 100644 trie/shadownodes.go diff --git a/trie/database.go b/trie/database.go index 52469352cf..002af87498 100644 --- a/trie/database.go +++ b/trie/database.go @@ -103,6 +103,8 @@ type rawNode []byte func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawNode) setEpoch(epcoh uint16) { panic("this should never end up in a live trie") } +func (n rawNode) getEpoch() uint16 { panic("this should never end up in a live trie") } func (n rawNode) EncodeRLP(w io.Writer) error { _, err := w.Write(n) @@ -120,6 +122,8 @@ type rawFullNode [17]node func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n 
rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawFullNode) setEpoch(epcoh uint16) { panic("this should never end up in a live trie") } +func (n rawFullNode) getEpoch() uint16 { panic("this should never end up in a live trie") } func (n rawFullNode) nodeType() int { return rawFullNodeType @@ -141,6 +145,8 @@ type rawShortNode struct { func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawShortNode) setEpoch(epoch uint16) { panic("this should never end up in a live trie") } +func (n rawShortNode) getEpoch() uint16 { panic("this should never end up in a live trie") } func (n rawShortNode) nodeType() int { return rawShortNodeType @@ -421,6 +427,10 @@ func (db *Database) node(hash common.Hash) node { return mustDecodeNodeUnsafe(hash[:], enc) } +func (db *Database) RootNode(hash common.Hash) *RootNode { + return nil +} + // Node retrieves an encoded cached trie node from memory. If it cannot be found // cached, the method queries the persistent database for the content. 
func (db *Database) Node(hash common.Hash) ([]byte, error) { diff --git a/trie/errors.go b/trie/errors.go index 567b80078c..425f7154bc 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -33,3 +33,13 @@ type MissingNodeError struct { func (err *MissingNodeError) Error() string { return fmt.Sprintf("missing trie node %x (path %x)", err.NodeHash, err.Path) } + +type ExpiredNodeError struct { + ExpiredNode node // node of the expired node + Path []byte // hex-encoded path to the expired node + Epoch uint16 +} + +func (err *ExpiredNodeError) Error() string { + return fmt.Sprintf("expired trie node ") +} diff --git a/trie/node.go b/trie/node.go index b50324294f..d4d84360f1 100644 --- a/trie/node.go +++ b/trie/node.go @@ -46,22 +46,34 @@ type node interface { encode(w rlp.EncoderBuffer) fstring(string) string nodeType() int + setEpoch(epoch uint16) + getEpoch() uint16 } type ( fullNode struct { - Children [BranchNodeLength]node // Actual trie node data to encode/decode (needs custom encoder) - flags nodeFlag + Children [BranchNodeLength]node // Actual trie node data to encode/decode (needs custom encoder) + flags nodeFlag + epoch uint16 `rlp:"-" json:"-"` + shadowNode *shadowBranchNode `rlp:"-" json:"-"` } shortNode struct { - Key []byte - Val node - flags nodeFlag + Key []byte + Val node + flags nodeFlag + epoch uint16 `rlp:"-" json:"-"` + shadowNode *shadowExtensionNode `rlp:"-" json:"-"` } hashNode []byte valueNode []byte ) +type RootNode struct { + Epoch uint16 + TrieHash common.Hash + ShadowHash common.Hash +} + // nilValueNode is used when collapsing internal trie nodes for hashing, since // unset children need to serialize correctly. 
var nilValueNode = valueNode(nil) @@ -73,6 +85,26 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { return eb.Flush() } +func (n *fullNode) GetShadowNode() *shadowBranchNode { + return &shadowBranchNode{} +} + +func (n *fullNode) IsChildExpired(pos int) (bool, error) { + return false, nil +} + +func (n *fullNode) GetChildEpoch(pos int) uint16 { + return n.GetShadowNode().EpochMap[pos] +} + +func (n *fullNode) UpdateChildEpoch(pos int, epoch uint16) { + n.GetShadowNode().EpochMap[pos] = epoch +} + +func (n *shortNode) GetShadowNode() *shadowExtensionNode { + return &shadowExtensionNode{} +} + func (n *fullNode) copy() *fullNode { copy := *n; return &copy } func (n *shortNode) copy() *shortNode { copy := *n; return &copy } @@ -93,6 +125,16 @@ func (n *shortNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } +func (n *fullNode) setEpoch(epoch uint16) { n.epoch = epoch } +func (n *shortNode) setEpoch(epoch uint16) { n.epoch = epoch } +func (n hashNode) setEpoch(epoch uint16) {} +func (n valueNode) setEpoch(epoch uint16) {} + +func (n *fullNode) getEpoch() uint16 { return n.epoch } +func (n *shortNode) getEpoch() uint16 { return n.epoch } +func (n hashNode) getEpoch() uint16 { return 0 } +func (n valueNode) getEpoch() uint16 { return 0 } + func (n *fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) for i, node := range &n.Children { @@ -186,21 +228,24 @@ func decodeShort(hash, elems []byte) (node, error) { if err != nil { return nil, err } - flag := nodeFlag{hash: hash} + n := &shortNode{flags: nodeFlag{hash: hash}} key := compactToHex(kbuf) + n.Key = key if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { return nil, fmt.Errorf("invalid value node: %v", err) } - return &shortNode{key, valueNode(val), flag}, nil + n.Val = valueNode(val) + return n, nil } r, _, err := decodeRef(rest) if err != nil { return 
nil, wrapError(err, "val") } - return &shortNode{key, r, flag}, nil + n.Val = r + return n, nil } func decodeFull(hash, elems []byte) (*fullNode, error) { diff --git a/trie/shadownodes.go b/trie/shadownodes.go new file mode 100644 index 0000000000..819b405f6a --- /dev/null +++ b/trie/shadownodes.go @@ -0,0 +1,18 @@ +package trie + +import ( + "github.com/ethereum/go-ethereum/common" +) + +//type shadowNode interface { +// encode(encoder rlp.EncoderBuffer) +//} + +type shadowExtensionNode struct { + ShadowHash *common.Hash +} + +type shadowBranchNode struct { + ShadowHash *common.Hash + EpochMap [16]uint16 +} diff --git a/trie/trie.go b/trie/trie.go index f82974a616..f27f228ae0 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -86,8 +86,13 @@ func New(root common.Hash, db *Database) (*Trie, error) { trie := &Trie{ db: db, } + epoch := uint16(0) + if rootNode := db.RootNode(root); rootNode != nil { + root = rootNode.TrieHash + epoch = rootNode.Epoch + } if root != (common.Hash{}) && root != emptyRoot { - rootnode, err := trie.resolveHash(root[:], nil) + rootnode, err := trie.resolveHash(root[:], nil, epoch) if err != nil { return nil, err } @@ -116,14 +121,14 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. 
func (t *Trie) TryGet(key []byte) ([]byte, error) { - value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) + value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0, t.root.getEpoch()) if err == nil && didResolve { t.root = newroot } return value, err } -func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { +func (t *Trie) tryGet(origNode node, key []byte, pos int, epoch uint16) (value []byte, newnode node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -134,25 +139,28 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode // key not found in trie return nil, n, false, nil } - value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) + value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key), n.epoch) if err == nil && didResolve { n = n.copy() n.Val = newnode } return value, n, didResolve, err case *fullNode: - value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) + if expired, err := n.IsChildExpired(pos); expired { + return nil, n, false, err + } + value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(pos)) if err == nil && didResolve { n = n.copy() n.Children[key[pos]] = newnode } return value, n, didResolve, err case hashNode: - child, err := t.resolveHash(n, key[:pos]) + child, err := t.resolveHash(n, key[:pos], epoch) if err != nil { return nil, n, true, err } - value, newnode, _, err := t.tryGet(child, key, pos) + value, newnode, _, err := t.tryGet(child, key, pos, epoch) return value, newnode, true, err default: panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) @@ -300,7 +308,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if !dirty || err != nil { return false, n, err } - return true, &shortNode{n.Key, nn, t.newFlag()}, nil + return 
true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag()}, nil } // Otherwise branch out at the index where they differ. branch := &fullNode{flags: t.newFlag()} @@ -318,7 +326,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil + return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag()}, nil case *fullNode: dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) @@ -331,7 +339,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error return true, n, nil case nil: - return true, &shortNode{key, value, t.newFlag()}, nil + return true, &shortNode{Key: key, Val: value, flags: t.newFlag()}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -401,9 +409,9 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil + return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag()}, nil default: - return true, &shortNode{n.Key, child, t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag()}, nil } case *fullNode: @@ -457,12 +465,12 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } if cnode, ok := cnode.(*shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return true, &shortNode{k, cnode.Val, t.newFlag()}, nil + return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag()}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. 
- return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil + return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag()}, nil } // n still contains at least two values and cannot be reduced. return true, n, nil @@ -507,19 +515,18 @@ func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { // Loop through prefix key // When prefix key is empty, generate the hash node of the current node // Replace current node with the hash node - + // If length of prefix key is empty if len(prefixKeyHex) == 0 { hasher := newHasher(false) defer returnHasherToPool(hasher) var hn node _, hn = hasher.proofHash(n) - if _, ok := hn.(hashNode); ok { return hn, nil } - return nil, nil + return nil, nil } switch n := n.(type) { @@ -553,7 +560,6 @@ func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { } } - func concat(s1 []byte, s2 ...byte) []byte { r := make([]byte, len(s1)+len(s2)) copy(r, s1) @@ -568,12 +574,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { return n, nil } -func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { +func (t *Trie) resolveHash(n hashNode, prefix []byte, epoch uint16) (node, error) { hash := common.BytesToHash(n) if t.db == nil { return nil, fmt.Errorf("empty trie database") } if node := t.db.node(hash); node != nil { + node.setEpoch(epoch) return node, nil } return nil, &MissingNodeError{NodeHash: hash, Path: prefix} From 8e9e1a8b71bb9381caa8a9dd44ffa15ec017f40e Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 14 Apr 2023 11:13:40 +0800 Subject: [PATCH 22/51] state/state_object: add dirty trie & pending trie; state/journal: add revive journal; state/stateDB: opt revive state; state/state_object: record AccessedState, opt pending revive trie; state/journal: add access state journal; --- core/state/database.go | 2 +- core/state/journal.go | 31 +++++++++++++ core/state/state_object.go | 91 
++++++++++++++++++++++++++++++++++++-- core/state/statedb.go | 17 ++++--- core/state_transition.go | 2 +- core/vm/interface.go | 2 +- trie/dummy_trie.go | 3 +- trie/proof.go | 7 +-- trie/proof_test.go | 13 +++++- trie/trie.go | 22 ++++----- trie/trie_test.go | 46 +++++++++---------- 11 files changed, 183 insertions(+), 53 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 4c071e2589..38e2f46ca8 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -126,7 +126,7 @@ type Trie interface { // nodes of the longest existing prefix of the key (at least the root), ending // with the node that proves the absence of the key. Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error - + ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error ReviveTrie(trie.MPTProofCache) error diff --git a/core/state/journal.go b/core/state/journal.go index 4f1fe2bf48..10b00a9228 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -139,6 +139,13 @@ type ( address *common.Address slot *common.Hash } + reviveStorageTrieNodeChange struct { + address *common.Address + } + accessedStorageStateChange struct { + address *common.Address + slot *common.Hash + } ) func (ch createObjectChange) revert(s *StateDB) { @@ -272,3 +279,27 @@ func (ch accessListAddSlotChange) revert(s *StateDB) { func (ch accessListAddSlotChange) dirtied() *common.Address { return nil } + +func (ch reviveStorageTrieNodeChange) revert(s *StateDB) { + s.getStateObject(*ch.address).dirtyReviveState = make(Storage) + s.getStateObject(*ch.address).dirtyReviveTrie = nil +} + +func (ch reviveStorageTrieNodeChange) dirtied() *common.Address { + return ch.address +} + +func (ch accessedStorageStateChange) revert(s *StateDB) { + if count, ok := s.getStateObject(*ch.address).dirtyAccessedState[*ch.slot]; ok { + if count > 1 { + s.getStateObject(*ch.address).dirtyAccessedState[*ch.slot] = count - 1 + } else { + 
delete(s.getStateObject(*ch.address).dirtyAccessedState, *ch.slot) + } + } + +} + +func (ch accessedStorageStateChange) dirtied() *common.Address { + return ch.address +} diff --git a/core/state/state_object.go b/core/state/state_object.go index 1ede96ec63..e4b44e538d 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -19,6 +19,7 @@ package state import ( "bytes" "fmt" + "github.com/ethereum/go-ethereum/trie" "io" "math/big" "sync" @@ -79,7 +80,7 @@ type StateObject struct { dbErr error // Write caches. - trie Trie // storage trie, which becomes non-nil on first access + trie Trie // storage trie, which becomes non-nil on first access, it's committed trie code Code // contract bytecode, which gets set when code is loaded sharedOriginStorage *sync.Map // Point to the entry of the stateObject in sharedPool @@ -89,6 +90,17 @@ type StateObject struct { dirtyStorage Storage // Storage entries that have been modified in the current transaction execution fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. + // revive state + pendingReviveTrie Trie // pendingReviveTrie it contains pending revive trie nodes, could update & commit later + dirtyReviveTrie Trie // dirtyReviveTrie for tx + + // TODO when R&W, access revive state first + pendingReviveState Storage // pendingReviveState for block, it cannot flush to trie, just cache + dirtyReviveState Storage // dirtyReviveState for tx, for cache dirtyReviveTrie + + pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed, it will update epoch index late + dirtyAccessedState map[common.Hash]int // dirtyAccessedState record which state is accessed, it will update epoch index later, attention: don't record revive state + // Cache flags. // When an object is marked suicided it will be delete from the trie // during the "update" phase of the state transition. 
@@ -131,6 +143,8 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *St originStorage: make(Storage), pendingStorage: make(Storage), dirtyStorage: make(Storage), + dirtyReviveState: make(Storage), + pendingReviveState: make(Storage), } } @@ -183,6 +197,20 @@ func (s *StateObject) getTrie(db Database) Trie { return s.trie } +func (s *StateObject) getPendingReviveTrie(db Database) Trie { + if s.pendingReviveTrie == nil { + s.pendingReviveTrie = s.db.db.CopyTrie(s.getTrie(db)) + } + return s.pendingReviveTrie +} + +func (s *StateObject) getDirtyReviveTrie(db Database) Trie { + if s.dirtyReviveTrie == nil { + s.dirtyReviveTrie = s.db.db.CopyTrie(s.getPendingReviveTrie(db)) + } + return s.dirtyReviveTrie +} + // GetState retrieves a value from the account storage trie. func (s *StateObject) GetState(db Database, key common.Hash) common.Hash { // If the fake storage is set, only lookup the state here(in the debugging mode) @@ -338,6 +366,14 @@ func (s *StateObject) finalise(prefetch bool) { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure } } + for key, value := range s.dirtyReviveState { + s.pendingReviveState[key] = value + } + for key, value := range s.dirtyAccessedState { + count := s.pendingAccessedState[key] + count += value + s.pendingAccessedState[key] = count + } prefetcher := s.db.prefetcher if prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != emptyRoot { @@ -346,6 +382,16 @@ func (s *StateObject) finalise(prefetch bool) { if len(s.dirtyStorage) > 0 { s.dirtyStorage = make(Storage) } + if len(s.dirtyReviveState) > 0 { + s.dirtyReviveState = make(Storage) + } + if len(s.dirtyAccessedState) > 0 { + s.dirtyAccessedState = make(map[common.Hash]int) + } + if s.dirtyReviveTrie != nil { + s.pendingReviveTrie = s.dirtyReviveTrie + s.dirtyReviveTrie = nil + } } // updateTrie writes cached storage modifications into the object's storage trie. 
@@ -364,8 +410,8 @@ func (s *StateObject) updateTrie(db Database) Trie { s.db.MetricsMux.Unlock() }(time.Now()) } - // Insert all the pending updates into the trie - tr := s.getTrie(db) + // Insert all the pending updates into the pending trie + tr := s.getPendingReviveTrie(db) usedStorage := make([][]byte, 0, len(s.pendingStorage)) dirtyStorage := make(map[common.Hash][]byte) @@ -395,6 +441,16 @@ func (s *StateObject) updateTrie(db Database) Trie { usedStorage = append(usedStorage, common.CopyBytes(key[:])) } }() + wg.Add(1) + go func() { + defer wg.Done() + for key := range s.pendingAccessedState { + if _, ok := dirtyStorage[key]; ok { + continue + } + // TODO update accessed state epoch index + } + }() if s.db.snap != nil { // If state snapshotting is active, cache the data til commit wg.Add(1) @@ -423,6 +479,14 @@ func (s *StateObject) updateTrie(db Database) Trie { if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) } + if len(s.pendingAccessedState) > 0 { + s.pendingAccessedState = make(map[common.Hash]int) + } + + // reset trie as pending trie, will commit later + if tr != nil { + s.trie = s.db.db.CopyTrie(tr) + } return tr } @@ -524,6 +588,15 @@ func (s *StateObject) deepCopy(db *StateDB) *StateObject { stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted + + if s.dirtyReviveTrie != nil { + stateObject.dirtyReviveTrie = db.db.CopyTrie(s.dirtyReviveTrie) + } + if s.pendingReviveTrie != nil { + stateObject.pendingReviveTrie = db.db.CopyTrie(s.pendingReviveTrie) + } + stateObject.dirtyReviveState = s.dirtyReviveState.Copy() + stateObject.pendingReviveState = s.pendingReviveState.Copy() return stateObject } @@ -615,3 +688,15 @@ func (s *StateObject) Nonce() uint64 { func (s *StateObject) Value() *big.Int { panic("Value on StateObject should never be called") } + +func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { + dr := s.getDirtyReviveTrie(s.db.db) + if err := 
dr.ReviveTrie(proofCache); err != nil { + s.dirtyReviveTrie = nil + return err + } + s.db.journal.append(reviveStorageTrieNodeChange{ + address: &s.address, + }) + return nil +} diff --git a/core/state/statedb.go b/core/state/statedb.go index c31d6961be..acdcfce606 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -473,6 +473,7 @@ func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { } // GetState retrieves a value from the given account's storage trie. +// TODO access in shadow node func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { stateObject := s.getStateObject(addr) if stateObject != nil { @@ -507,7 +508,7 @@ func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, key c return proof, err } -// TODO: GetStorageProof returns the combined Merkle proof and Shadow Tree proof for given storage slot. +// TODO: GetStorageProof returns the combined Merkle proof and Shadow Tree proof for given storage slot. func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) @@ -593,6 +594,7 @@ func (s *StateDB) SetCode(addr common.Address, code []byte) { } } +// TODO access state and check insert duplicated func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { @@ -1456,6 +1458,7 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er tasks <- func() { // Write any storage changes in the state object to its storage trie if !s.noTrie { + // TODO commit revive state cache to Trie if _, err := obj.CommitTrie(s.db); err != nil { taskResults <- err return @@ -1568,6 +1571,7 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er diffLayer.Destructs, diffLayer.Accounts, diffLayer.Storages = s.SnapToDiffLayer() // Only update if there's a state transition (skip empty Clique blocks) if parent := 
s.snap.Root(); parent != s.expectedRoot { + // TODO snap support epoch index err := s.snaps.Update(s.expectedRoot, parent, s.snapDestructs, s.snapAccounts, s.snapStorage, verified) if err != nil { @@ -1755,7 +1759,7 @@ func (s *StateDB) GetStorage(address common.Address) *sync.Map { } // ReviveTrie revive a trie with a given witness list -func (s *StateDB) ReviveTrie(witnessList types.WitnessList) error{ +func (s *StateDB) ReviveStorageTrie(witnessList types.WitnessList) error { for i := range witnessList { wit := witnessList[i] // got specify witness, verify proof and check if revive success @@ -1778,8 +1782,11 @@ func (s *StateDB) ReviveTrie(witnessList types.WitnessList) error{ return err } - trie := s.StorageTrie(stWit.Address) - if err := trie.ReviveTrie(proofCaches[j]); err != nil { + stateObject := s.getStateObject(stWit.Address) + if stateObject == nil { + return errors.New("contract object not found") + } + if err := stateObject.ReviveStorageTrie(proofCaches[j]); err != nil { return err } } @@ -1789,4 +1796,4 @@ func (s *StateDB) ReviveTrie(witnessList types.WitnessList) error{ } return nil -} \ No newline at end of file +} diff --git a/core/state_transition.go b/core/state_transition.go index 3699be1fb1..90f8e1610b 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -363,7 +363,7 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { // revive state before execution if rules.IsElwood { - st.state.ReviveTrie(msg.WitnessList()) + st.state.ReviveStorageTrie(msg.WitnessList()) } var ( diff --git a/core/vm/interface.go b/core/vm/interface.go index 877283384e..86dde81c80 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -74,7 +74,7 @@ type StateDB interface { AddPreimage(common.Hash, []byte) ForEachStorage(common.Address, func(common.Hash, common.Hash) bool) error - ReviveTrie(witnessList types.WitnessList) error + ReviveStorageTrie(witnessList types.WitnessList) error } // CallContext provides a basic 
interface for the EVM calling conventions. The EVM diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 6d10d0defe..0a27cd24f8 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -18,7 +18,6 @@ package trie import ( "fmt" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" @@ -106,4 +105,4 @@ func (t *EmptyTrie) TryUpdateAccount(key []byte, account *types.StateAccount) er func (t *EmptyTrie) ReviveTrie(proof MPTProofCache) error { return nil -} \ No newline at end of file +} diff --git a/trie/proof.go b/trie/proof.go index bac2a44f0e..9c238a7278 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -180,8 +180,8 @@ type MPTProofCache struct { // VerifyProof verify proof in MPT witness // 1. calculate hash // 2. decode trie node -// 3. verify partial merkle proof of the witness -// 4. split to partial witness +// 3. verify partial merkle proof of the witness, TODO match algo will check inner mem node scene, until meet hash node or value node or nil? +// 4. split to partial witness, TODO check if satisfy partial witness rules? // TODO later revive state could revive KV from fullNode[0-15] or fullNode[16] shortNode.Val, return KVs for cache & snap // another easy method is that revive direct to Trie/Trie cache, query later and set to KV cache func (m *MPTProofCache) VerifyProof() error { @@ -229,7 +229,7 @@ func (m *MPTProofCache) VerifyProof() error { m.cacheNubs = make([]*MPTProofNub, 0, len(m.Proof)) prefix := m.RootKeyHex for i := 0; i < len(m.cacheNodes); i++ { - if i - 1 >= 0 { + if i-1 >= 0 { prefix = append(prefix, m.cacheHexPath[i-1]...) } // prefix = append(prefix, m.cacheHexPath[i]...) @@ -244,6 +244,7 @@ func (m *MPTProofCache) VerifyProof() error { prefix = append(prefix, m.cacheHexPath[i-1]...) 
nub.n2 = m.cacheNodes[i] } + // TODO check short node must with child in same nub m.cacheNubs = append(m.cacheNubs, &nub) } diff --git a/trie/proof_test.go b/trie/proof_test.go index ac81b46ff0..f18ecf8232 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -1386,6 +1386,13 @@ func getPrefixKeysHex(t *Trie, key []byte) [][]byte { return prefixKeys } +// TODO add more UTs +// 1. whole trie expired without root node +// 2. whole trie expired with root node +// 3. witness with one short node, val is full node +// 3. witness with one short node, val is value node +// 3. witness with one short node, val is recursive full node + short node + value node +// 4. witness with one full node with val node, children is short node and recursive full node + short node + value node func TestMPTProofCache_VerifyProof_normalCase(t *testing.T) { cache := makeMPTProofCache(nil, []string{ "0xf90211a03697534056039e03300557bd69fe16e18ce4a6ccd5522db4dfa97dfe1fad3d3aa0b1bf1f230b98b9034738d599177ae817c08143b9395a47f300636b0dd2fb3c5ea0aa04a4966751d4c50063fe13a96a6c7924f665819733f556849b5eb9fa1d6839a0e162e080d1c12c59dc984fb2246d8ad61209264bee40d3fdd07c4ea4ff411b6aa0e5c3f2dde71bf303423f34674748567dcdf8379129653b8213f698468738d492a068a3e3059b6e7115a055a7874f81c5a1e84ddc1967527973f8c78cd86a1c9f8fa0d734bd63b7be8e8471091b792f5bbcbc7b0ce582f6d985b7a15a3c0155242c56a00143c06f57a65c8485dbae750aa51df5dff1bf7bdf28060129a20de9e51364eda07b416f79b3f4e39d0159efff351009d44002d9e83530fb5a5778eb55f5f4432ca036706b52196fa0b73feb2e7ff8f1379c7176d427dd44ad63c7b65e66693904a1a0fd6c8b815e2769ce379a20eaccdba1f145fb11f77c280553f15ee4f1ee135375a02f5233009f082177e5ed2bfa6e180bf1a7310e6bc3c079cb85a4ac6fee4ae379a03f07f1bb33fa26ebd772fa874914dc7a08581095e5159fdcf9221be6cbeb6648a097557eec1ac08c3bfe45ce8e34cd329164a33928ac83fef1009656536ef6907fa028196bfb31aa7f14a0a8000b00b0aa5d09450c32d537e45eebee70b14313ff1ca0126ce265ca7bbb0e0b01f068d1edef1544cbeb2f048c99829713c18d7abc049a80", @@ -1406,7 +1413,9 @@ func 
TestMPTProofCache_VerifyProof_normalCase(t *testing.T) { h.sha.Reset() h.sha.Write(key) h.sha.Read(hash) - assert.Equal(t, hash, hexToKeybytes(cache.cacheNubs[6].RootHexKey)) + ln := cache.cacheNubs[6].n1.(*shortNode) + hexKey := append(cache.cacheNubs[6].RootHexKey, ln.Key...) + assert.Equal(t, hash, hexToKeybytes(hexKey)) } func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { @@ -1418,7 +1427,7 @@ func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { return MPTProofCache{ MPTProof: types.MPTProof{ RootKeyHex: key, - Proof: proof, + Proof: proof, }, } } diff --git a/trie/trie.go b/trie/trie.go index f27f228ae0..edee42541e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -506,7 +506,7 @@ func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { t.root = hn } if err != nil { - return err + return err } return nil } @@ -666,10 +666,10 @@ func (t *Trie) Size() int { // ReviveTrie revives the trie from the proof cache func (t *Trie) ReviveTrie(proof MPTProofCache) error { - + var parent node var childIndex int // If parent is a fullNode, childIndex is the index of the child node - + cacheHashIndex := 0 // Keep track of the index of the cachedHash nubs := proof.cacheNubs @@ -717,6 +717,7 @@ loopNubs: // Attach n1 to the trie switch n := parent.(type) { case *shortNode: + // TODO should copy node and parent point to new node n.Val = nub.n1 parent = n.Val case *fullNode: @@ -727,13 +728,13 @@ loopNubs: // Attach n2 to the trie if exists if nub.n2 != nil { - switch n := parent.(type) { - case *shortNode: - n.Val = nub.n2 - default: - return fmt.Errorf("n2 should only be attached to a shortNode") - } - // TODO (asyukii): build shadow node if n2 is a fullNode + switch n := parent.(type) { + case *shortNode: + n.Val = nub.n2 + default: + return fmt.Errorf("n2 should only be attached to a shortNode") + } + // TODO (asyukii): build shadow node if n2 is a fullNode } } } @@ -749,4 +750,3 @@ loopNubs: return nil } - diff --git a/trie/trie_test.go 
b/trie/trie_test.go index cd792ffb34..72cd1dab59 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -39,8 +39,8 @@ import ( "github.com/ethereum/go-ethereum/ethdb/leveldb" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/rlp" - "golang.org/x/crypto/sha3" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/sha3" ) func init() { @@ -475,10 +475,9 @@ func TestRandom(t *testing.T) { } } - -// TestExpireByPrefix tests that the trie is not corrupted after +// TestExpireByPrefix tests that the trie is not corrupted after // expiring a key by prefix. -func TestExpireByPrefix(t *testing.T){ +func TestExpireByPrefix(t *testing.T) { data := map[string]string{ "abcd": "A", "abce": "B", @@ -494,7 +493,7 @@ func TestExpireByPrefix(t *testing.T){ rootHash := trie.Hash() for k := range data { - prefixKeys := getPrefixKeysHex(trie, []byte(k)) + prefixKeys := getPrefixKeysHex(trie, []byte(k)) for _, prefixKey := range prefixKeys { trie.ExpireByPrefix(prefixKey) currHash := trie.Hash() @@ -507,7 +506,7 @@ func TestExpireByPrefix(t *testing.T){ } } -func createCustomTrie(data map[string]string) *Trie{ +func createCustomTrie(data map[string]string) *Trie { trie := new(Trie) for k, v := range data { trie.Update([]byte(k), []byte(v)) @@ -531,13 +530,13 @@ func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { return MPTProofCache{ MPTProof: types.MPTProof{ RootKeyHex: rootKeyHex, - Proof: proof, + Proof: proof, }, } } // TestReviveTrie tests that a trie can be revived from a proof -func TestReviveTrie(t *testing.T){ +func TestReviveTrie(t *testing.T) { trie, vals := nonRandomTrie(500) @@ -578,10 +577,10 @@ func TestReviveTrie(t *testing.T){ // Reset trie trie, _ = nonRandomTrie(500) } - } + } } -// TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when +// TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when // the whole trie is expired. 
This test will fail because the parent node in // ReviveTrie is nil, set to RootNode when available // func TestReviveAtRoot(t *testing.T) { @@ -613,7 +612,6 @@ func TestReviveTrie(t *testing.T){ // err = trie.ReviveTrie(proofCache) // assert.NoError(t, err) - // // Verify value exists after revive // v := trie.Get(key) // assert.Equal(t, val, v) @@ -630,7 +628,7 @@ func TestReviveTrie(t *testing.T){ // TestReviveBadProof tests that a trie cannot be revived from a bad proof func TestReviveBadProof(t *testing.T) { - + dataA := map[string]string{ "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", "defg": "E", "defh": "F", "degh": "G", "degi": "H", @@ -698,7 +696,7 @@ func TestReviveOneElement(t *testing.T) { assert.Equal(t, val, v) } -// TestReviveBadProofAfterUpdate tests that after reviving a path and +// TestReviveBadProofAfterUpdate tests that after reviving a path and // then update the value, old proof should be invalid func TestReviveBadProofAfterUpdate(t *testing.T) { trie, vals := nonRandomTrie(500) @@ -753,7 +751,7 @@ func TestPartialReviveFullProof(t *testing.T) { assert.NoError(t, err) // Expire trie - err = trie.ExpireByPrefix([]byte{6,1}) + err = trie.ExpireByPrefix([]byte{6, 1}) assert.NoError(t, err) // Construct MPTProofCache @@ -777,16 +775,16 @@ func TestPartialReviveFullProof(t *testing.T) { // full node can be revived properly func TestReviveValueAtFullNode(t *testing.T) { hexKeys := [][]byte{ - {6,1,6,2,6,3,6,4,16}, - {6,1,6,2,6,3,6,5,16}, - {6,1,6,2,6,4,6,5,16}, - {6,1,6,2,6,4,6,6,16}, - {6,4,6,5,6,6,6,7,16}, - {6,4,6,5,6,6,6,8,16}, - {6,4,6,5,6,7,6,8,16}, - {6,4,6,5,6,7,6,9,16}, - {6,1,6,2,6,3,6,4,16}, - {6,1,6,2,16}, // This is the key that has a value at a full node + {6, 1, 6, 2, 6, 3, 6, 4, 16}, + {6, 1, 6, 2, 6, 3, 6, 5, 16}, + {6, 1, 6, 2, 6, 4, 6, 5, 16}, + {6, 1, 6, 2, 6, 4, 6, 6, 16}, + {6, 4, 6, 5, 6, 6, 6, 7, 16}, + {6, 4, 6, 5, 6, 6, 6, 8, 16}, + {6, 4, 6, 5, 6, 7, 6, 8, 16}, + {6, 4, 6, 5, 6, 7, 6, 9, 16}, + {6, 1, 6, 2, 6, 3, 6, 
4, 16}, + {6, 1, 6, 2, 16}, // This is the key that has a value at a full node } byteKeys := make([][]byte, len(hexKeys)) From 4d4a2a5a12ea76717cab7d9fcae73dc94154fd5c Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 18 Apr 2023 00:48:11 +0800 Subject: [PATCH 23/51] refactor(trie): update ReviveTrie to include cache mechanism --- core/state/database.go | 2 +- light/trie.go | 4 +- trie/dummy_trie.go | 4 +- trie/secure_trie.go | 6 +- trie/trie.go | 138 +++++++++++++++++++++-------------------- trie/trie_test.go | 16 ++--- 6 files changed, 86 insertions(+), 84 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 38e2f46ca8..037be8841a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -129,7 +129,7 @@ type Trie interface { ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error - ReviveTrie(trie.MPTProofCache) error + ReviveTrie(trie.MPTProofCache) } // NewDatabase creates a backing store for state. The returned database is safe for diff --git a/light/trie.go b/light/trie.go index 9c64ef40a0..23cc60919b 100644 --- a/light/trie.go +++ b/light/trie.go @@ -202,8 +202,8 @@ func (db *odrTrie) NoTries() bool { return false } -func (t *odrTrie) ReviveTrie(proof trie.MPTProofCache) error { - return t.trie.ReviveTrie(proof) +func (t *odrTrie) ReviveTrie(proof trie.MPTProofCache) { + t.trie.ReviveTrie(proof) } type nodeIterator struct { diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 0a27cd24f8..a5e17ff386 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -103,6 +103,4 @@ func (t *EmptyTrie) TryUpdateAccount(key []byte, account *types.StateAccount) er return nil } -func (t *EmptyTrie) ReviveTrie(proof MPTProofCache) error { - return nil -} +func (t *EmptyTrie) ReviveTrie(proof MPTProofCache) {} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 9b55108933..18e5548b5b 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -229,6 +229,6 @@ func (t *SecureTrie) 
getSecKeyCache() map[string][]byte { return t.secKeyCache } -func (t *SecureTrie) ReviveTrie(proof MPTProofCache) error { - return t.trie.ReviveTrie(proof) -} \ No newline at end of file +func (t *SecureTrie) ReviveTrie(proof MPTProofCache) { + t.trie.ReviveTrie(proof) +} diff --git a/trie/trie.go b/trie/trie.go index edee42541e..f3e8a6ac41 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -664,89 +664,93 @@ func (t *Trie) Size() int { return estimateSize(t.root) } -// ReviveTrie revives the trie from the proof cache -func (t *Trie) ReviveTrie(proof MPTProofCache) error { - - var parent node - var childIndex int // If parent is a fullNode, childIndex is the index of the child node - - cacheHashIndex := 0 // Keep track of the index of the cachedHash - nubs := proof.cacheNubs - -loopNubs: - for _, nub := range nubs { - key := nub.RootHexKey - startNode := t.root - parent = nil // TODO (asyukii): When RootNode is introduced, parent node will be the RootNode instead of nil - childIndex = -1 - // Traverse through the trie using RootHexKey - - // Loop through the key to find hash node - for len(key) > 0 { - switch n := startNode.(type) { - case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { - return fmt.Errorf("key %v not found", key) - } else { - parent = n - startNode = n.Val - key = key[len(n.Key):] - } - case *fullNode: - startNode = n.Children[key[0]] - parent = n - childIndex = int(key[0]) - key = key[1:] - case hashNode: - tn, err := t.resolveHash(n, nil) - if err == nil { - startNode = tn - } else { - continue loopNubs - } - default: - continue loopNubs - } +func (t *Trie) ReviveTrie(proof MPTProofCache) { + if err := t.TryRevive(proof); err != nil { + log.Error(fmt.Sprintf("Failed to revive trie: %v", err)) + } +} + +func (t *Trie) TryRevive(proof MPTProofCache) error { + + cacheHashIndex := 0 + + for _, nub := range proof.cacheNubs { + newNode, didResolve, newCacheHashIndex, err := t.tryRevive(t.root, nub.RootHexKey, *nub, 
proof.cacheHashes, cacheHashIndex) + if err != nil { + return err + } + if didResolve { + t.root = newNode } + cacheHashIndex = newCacheHashIndex + } - // TODO (asyukii): check if the node is expired - // Attach node to parent - if _, ok := startNode.(hashNode); ok { - cachedHash := proof.cacheHashes[cacheHashIndex] - if bytes.Equal(cachedHash, startNode.(hashNode)) { - // Attach n1 to the trie - switch n := parent.(type) { - case *shortNode: - // TODO should copy node and parent point to new node - n.Val = nub.n1 - parent = n.Val - case *fullNode: - n.Children[childIndex] = nub.n1 - parent = n.Children[childIndex] - // TODO (asyukii): build shadow node - } + return nil + +} + +func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, cacheHashes [][]byte, cacheHashIndex int) (node, bool, int, error) { + if len(key) == 0 { + + if hashNode, ok := n.(hashNode); ok { + cachedHash := cacheHashes[cacheHashIndex] + if bytes.Equal(cachedHash, hashNode) { - // Attach n2 to the trie if exists if nub.n2 != nil { - switch n := parent.(type) { + switch n1 := nub.n1.(type) { case *shortNode: - n.Val = nub.n2 + n1.Val = nub.n2 default: - return fmt.Errorf("n2 should only be attached to a shortNode") + return nil, false, cacheHashIndex, fmt.Errorf("invalid node type") } - // TODO (asyukii): build shadow node if n2 is a fullNode + cacheHashIndex++ } + + cacheHashIndex++ + + return nub.n1, true, cacheHashIndex, nil } } - // Increment cacheHashIndex if nub.n1 != nil { cacheHashIndex++ } if nub.n2 != nil { cacheHashIndex++ } - } - return nil + return nil, false, cacheHashIndex, nil + } + switch n := n.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { + return nil, false, cacheHashIndex, fmt.Errorf("key %v not found", key) + } + newNode, didResolve, newCacheHashIndex, err := t.tryRevive(n.Val, key[len(n.Key):], nub, cacheHashes, cacheHashIndex) + if didResolve && err == nil { + n = n.copy() + n.Val = newNode + } + return n, 
didResolve, newCacheHashIndex, err + case *fullNode: + childIndex := int(key[0]) + newNode, didResolve, newCacheHashIndex, err := t.tryRevive(n.Children[childIndex], key[1:], nub, cacheHashes, cacheHashIndex) + if didResolve && err == nil { + n = n.copy() + n.Children[childIndex] = newNode + } + return n, didResolve, newCacheHashIndex, err + case hashNode: + tn, err := t.resolveHash(n, nil, 0) // TODO (asyukii): Revisit epoch index + if err != nil { + return nil, false, cacheHashIndex, err + } + return t.tryRevive(tn, key, nub, cacheHashes, cacheHashIndex) + case valueNode: + return nil, false, cacheHashIndex, nil + case nil: + return nil, false, cacheHashIndex, nil + default: + panic(fmt.Sprintf("invalid node: %T", n)) + } } diff --git a/trie/trie_test.go b/trie/trie_test.go index 72cd1dab59..28fff11368 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -563,7 +563,7 @@ func TestReviveTrie(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) assert.NoError(t, err) // Verify value exists after revive @@ -658,8 +658,8 @@ func TestReviveBadProof(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trieA.ReviveTrie(proofCache) - assert.NoError(t, err) + err = trieA.TryRevive(proofCache) + assert.Error(t, err) // Verify value does exists after revive _, err = trieA.TryGet([]byte("abcd")) @@ -689,7 +689,7 @@ func TestReviveOneElement(t *testing.T) { err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) assert.NoError(t, err) v := trie.Get(key) @@ -718,13 +718,13 @@ func TestReviveBadProofAfterUpdate(t *testing.T) { err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) assert.NoError(t, err) trie.Update(key, []byte("new value")) // Revive again with old proof - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) 
assert.NoError(t, err) // Validate trie @@ -762,7 +762,7 @@ func TestPartialReviveFullProof(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) assert.NoError(t, err) // Validate trie @@ -824,7 +824,7 @@ func TestReviveValueAtFullNode(t *testing.T) { err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.ReviveTrie(proofCache) + err = trie.TryRevive(proofCache) assert.NoError(t, err) // Validate trie From b7a3629218af546e8275a3d76fc170ad77f687f5 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Tue, 18 Apr 2023 22:31:22 +0800 Subject: [PATCH 24/51] trie/shadow_node: add shadow node storage; fix: fix some compile error and ut failures; --- core/state/state_object.go | 8 ++- core/state_transition_test.go | 11 ++-- core/types/revive_state_tx.go | 3 +- core/types/revive_witness.go | 3 +- core/types/revive_witness_test.go | 5 +- core/types/state_epoch.go | 3 +- core/types/state_epoch_test.go | 5 +- core/types/transaction_test.go | 5 +- trie/dummy_trie.go | 1 + trie/errors.go | 2 +- trie/proof.go | 1 + trie/proof_test.go | 13 ++--- trie/shadow_node.go | 83 +++++++++++++++++++++++++++++++ trie/trie.go | 23 +++++---- trie/trie_test.go | 7 +-- 15 files changed, 135 insertions(+), 38 deletions(-) create mode 100644 trie/shadow_node.go diff --git a/core/state/state_object.go b/core/state/state_object.go index e4b44e538d..d3b37a2bed 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -19,12 +19,13 @@ package state import ( "bytes" "fmt" - "github.com/ethereum/go-ethereum/trie" "io" "math/big" "sync" "time" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" @@ -691,10 +692,7 @@ func (s *StateObject) Value() *big.Int { func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { dr := 
s.getDirtyReviveTrie(s.db.db) - if err := dr.ReviveTrie(proofCache); err != nil { - s.dirtyReviveTrie = nil - return err - } + dr.ReviveTrie(proofCache) s.db.journal.append(reviveStorageTrieNodeChange{ address: &s.address, }) diff --git a/core/state_transition_test.go b/core/state_transition_test.go index 6d399b84ef..b11d64bfad 100644 --- a/core/state_transition_test.go +++ b/core/state_transition_test.go @@ -2,11 +2,12 @@ package core import ( "bytes" + "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" - "testing" ) func keybytesToHex(str []byte) []byte { @@ -29,7 +30,7 @@ func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, p } proofList[i] = types.MPTProof{ RootKeyHex: keybytesToHex(bytes.Repeat([]byte{'k'}, keyLen)), - Proof: proof, + Proof: proof, } } wit := types.StorageTrieWitness{ @@ -81,7 +82,7 @@ func TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: true, isHomestead: true, isEIP2028: true, - gas: 55176, + gas: 56792, }, { data: common.Hex2Bytes("1234567890"), @@ -92,7 +93,7 @@ func TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: true, isHomestead: true, isEIP2028: true, - gas: 55252, + gas: 56868, }, { data: nil, @@ -104,7 +105,7 @@ func TestIntrinsicGas_WitnessList(t *testing.T) { isContractCreation: false, isHomestead: true, isEIP2028: true, - gas: 26412, + gas: 27804, }, } diff --git a/core/types/revive_state_tx.go b/core/types/revive_state_tx.go index c9b9f6af8d..4fac02a033 100644 --- a/core/types/revive_state_tx.go +++ b/core/types/revive_state_tx.go @@ -1,9 +1,10 @@ package types import ( + "math/big" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/params" - "math/big" ) type WitnessList []ReviveWitness diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go index 17672970d5..815df6f5fb 100644 --- 
a/core/types/revive_witness.go +++ b/core/types/revive_witness.go @@ -2,6 +2,7 @@ package types import ( "errors" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rlp" @@ -25,7 +26,7 @@ const ( // Attention: The proof could revive multi-vals, although it's a single trie path witness type MPTProof struct { RootKeyHex []byte // prefix key in nibbles format, max 65 bytes. TODO: optimize witness size - Proof [][]byte // list of RLP-encoded nodes + Proof [][]byte // list of RLP-encoded nodes } type StorageTrieWitness struct { diff --git a/core/types/revive_witness_test.go b/core/types/revive_witness_test.go index 255a73e5f7..025258b21f 100644 --- a/core/types/revive_witness_test.go +++ b/core/types/revive_witness_test.go @@ -2,10 +2,11 @@ package types import ( "bytes" + "testing" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" - "testing" ) func makeSimpleReviveWitness(witType byte, data []byte) ReviveWitness { @@ -35,7 +36,7 @@ func makeStorageTrieWitness(addr common.Address, proofCount int, proofLen ...int } proofList[i] = MPTProof{ RootKeyHex: nil, - Proof: proof, + Proof: proof, } } wit := StorageTrieWitness{ diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go index df19b4f300..310f26004c 100644 --- a/core/types/state_epoch.go +++ b/core/types/state_epoch.go @@ -1,9 +1,10 @@ package types import ( + "math/big" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/params" - "math/big" ) var ( diff --git a/core/types/state_epoch_test.go b/core/types/state_epoch_test.go index c5d988b0b6..0193240930 100644 --- a/core/types/state_epoch_test.go +++ b/core/types/state_epoch_test.go @@ -1,10 +1,11 @@ package types import ( - "github.com/ethereum/go-ethereum/params" - "github.com/stretchr/testify/assert" "math/big" "testing" + + "github.com/ethereum/go-ethereum/params" + 
"github.com/stretchr/testify/assert" ) func TestStateForkConfig(t *testing.T) { diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index cebdc56951..7949b4b87f 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -21,13 +21,14 @@ import ( "crypto/ecdsa" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" "math/big" "math/rand" "reflect" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" @@ -580,7 +581,7 @@ func TestReviveStateTxAndSigner(t *testing.T) { Address: addr, ProofList: []MPTProof{{ RootKeyHex: []byte{0x09, 0x5e, 0x7b, 0xae, 0xa6, 0xa6, 0xc7, 0xc4, 0xc2}, - Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, + Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, }}, } diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index a5e17ff386..46ae7666ad 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -18,6 +18,7 @@ package trie import ( "fmt" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" diff --git a/trie/errors.go b/trie/errors.go index 425f7154bc..57d281067d 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -41,5 +41,5 @@ type ExpiredNodeError struct { } func (err *ExpiredNodeError) Error() string { - return fmt.Sprintf("expired trie ndoe ") + return "expired trie node" } diff --git a/trie/proof.go b/trie/proof.go index 9c238a7278..55f66c883e 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -20,6 +20,7 @@ import ( "bytes" "errors" "fmt" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/common" diff --git a/trie/proof_test.go b/trie/proof_test.go index f18ecf8232..fcdcb2a89d 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -20,14 +20,15 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" - 
"github.com/ethereum/go-ethereum/core/types" - "github.com/stretchr/testify/assert" mrand "math/rand" "sort" "strings" "testing" "time" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" @@ -900,7 +901,7 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { - prefixKeys := getPrefixKeysHex(trie, []byte(kv.k)) + prefixKeys := getPrefixKeysHex(trie, kv.k) for _, prefixKey := range prefixKeys { proof := memorydb.New() key := kv.k @@ -912,7 +913,7 @@ func TestStorageProof(t *testing.T) { if err != nil { t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) } - if val != nil && !bytes.Equal(val, []byte(kv.v)) { + if val != nil && !bytes.Equal(val, kv.v) { t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) } } @@ -1002,10 +1003,10 @@ func TestBadStorageProof(t *testing.T) { trie, vals := randomTrie(500) for _, kv := range vals { - prefixKeys := getPrefixKeysHex(trie, []byte(kv.k)) + prefixKeys := getPrefixKeysHex(trie, kv.k) for _, prefixKey := range prefixKeys { proof := memorydb.New() - key := []byte(kv.k) + key := kv.k err := trie.ProveStorageWitness(key, prefixKey, proof) if err != nil { t.Fatalf("missing key %x while constructing proof", key) diff --git a/trie/shadow_node.go b/trie/shadow_node.go new file mode 100644 index 0000000000..66e37224ea --- /dev/null +++ b/trie/shadow_node.go @@ -0,0 +1,83 @@ +package trie + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" +) + +type ShadowNodeStorage interface { + // Get key is the shadow node prefix path + Get(key []byte) ([]byte, error) + Put(key []byte, val []byte) error 
+ Commit(root common.Hash) error +} + +type ShadowNodeManager struct { + diskdb ethdb.KeyValueStore + // TODO diff layers + // TODO history states +} + +// NewShadowNodeManager TODO need reload diff layers and rebuild history metadata +func NewShadowNodeManager(diskdb ethdb.KeyValueStore) *ShadowNodeManager { + return &ShadowNodeManager{ + diskdb: diskdb, + } +} + +//// OpenStorage parentRoot is block root? or contract root ? later save block history? +//func (s *ShadowNodeManager) OpenStorage(parentRoot, addrHash common.Hash) ShadowNodeStorage { +// // TODO allow RW append on diff layer, only read plainState +// return &shadowNodeStorageReaderWriterMock{ +// s: s, +// parentRoot: parentRoot, +// addrHash: addrHash, +// } +//} + +//func (s *ShadowNodeManager) OpenHistoryStorage(blockAt uint64, addrHash common.Hash) ShadowNodeStorage { +// // TODO only allow read when access history +//} + +type shadowNodeStorageReaderWriterMock struct { + mockEpoch uint16 + nodeMap map[string][]byte +} + +func newShadowNodeStorageMock(epoch uint16) ShadowNodeStorage { + return &shadowNodeStorageReaderWriterMock{ + mockEpoch: epoch, + nodeMap: make(map[string][]byte), + } +} + +func (s *shadowNodeStorageReaderWriterMock) Get(key []byte) ([]byte, error) { + var err error + tmp := string(key) + val, ok := s.nodeMap[tmp] + if !ok { + n := shadowBranchNode{ + ShadowHash: nil, + EpochMap: [16]uint16{}, + } + for i := range n.EpochMap { + n.EpochMap[i] = s.mockEpoch + } + val, err = rlp.EncodeToBytes(n) + if err != nil { + return nil, err + } + s.nodeMap[tmp] = val + } + return val, nil +} + +func (s *shadowNodeStorageReaderWriterMock) Put(key []byte, val []byte) error { + s.nodeMap[string(key)] = val + return nil +} + +func (s *shadowNodeStorageReaderWriterMock) Commit(root common.Hash) error { + return nil +} diff --git a/trie/trie.go b/trie/trie.go index f3e8a6ac41..55c64833ab 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -61,6 +61,7 @@ type LeafCallback func(paths [][]byte, hexpath 
[]byte, leaf []byte, parent commo // Trie is not safe for concurrent use. type Trie struct { db *Database + sndb ShadowNodeStorage // only storage trie using it, account trie needn't the shadow node root node // Keep track of the number leafs which have been inserted since the last // hashing operation. This number will not directly map to the number of @@ -84,15 +85,16 @@ func New(root common.Hash, db *Database) (*Trie, error) { panic("trie.New called without a database") } trie := &Trie{ - db: db, + db: db, + sndb: newShadowNodeStorageMock(0), } - epoch := uint16(0) + //epoch := uint16(0) if rootNode := db.RootNode(root); rootNode != nil { root = rootNode.TrieHash - epoch = rootNode.Epoch + //epoch = rootNode.Epoch } if root != (common.Hash{}) && root != emptyRoot { - rootnode, err := trie.resolveHash(root[:], nil, epoch) + rootnode, err := trie.resolveHash(root[:], nil) if err != nil { return nil, err } @@ -121,7 +123,11 @@ func (t *Trie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. 
func (t *Trie) TryGet(key []byte) ([]byte, error) { - value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0, t.root.getEpoch()) + var nextEpoch uint16 + if t.root != nil { + nextEpoch = t.root.getEpoch() + } + value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0, nextEpoch) if err == nil && didResolve { t.root = newroot } @@ -156,7 +162,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int, epoch uint16) (value [ } return value, n, didResolve, err case hashNode: - child, err := t.resolveHash(n, key[:pos], epoch) + child, err := t.resolveHash(n, key[:pos]) if err != nil { return nil, n, true, err } @@ -574,13 +580,12 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { return n, nil } -func (t *Trie) resolveHash(n hashNode, prefix []byte, epoch uint16) (node, error) { +func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { hash := common.BytesToHash(n) if t.db == nil { return nil, fmt.Errorf("empty trie database") } if node := t.db.node(hash); node != nil { - node.setEpoch(epoch) return node, nil } return nil, &MissingNodeError{NodeHash: hash, Path: prefix} @@ -741,7 +746,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, cacheHashes [][]by } return n, didResolve, newCacheHashIndex, err case hashNode: - tn, err := t.resolveHash(n, nil, 0) // TODO (asyukii): Revisit epoch index + tn, err := t.resolveHash(n, nil) // TODO (asyukii): Revisit epoch index if err != nil { return nil, false, cacheHashIndex, err } diff --git a/trie/trie_test.go b/trie/trie_test.go index 28fff11368..da658f1f81 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -19,6 +19,7 @@ package trie import ( "bytes" "encoding/binary" + // "encoding/hex" "errors" "fmt" @@ -543,8 +544,8 @@ func TestReviveTrie(t *testing.T) { oriRootHash := trie.Hash() for _, kv := range vals { - key := []byte(kv.k) - val := []byte(kv.v) + key := kv.k + val := kv.v prefixKeys := getPrefixKeysHex(trie, key) for _, prefixKey := 
range prefixKeys { // Generate proof @@ -701,7 +702,7 @@ func TestReviveOneElement(t *testing.T) { func TestReviveBadProofAfterUpdate(t *testing.T) { trie, vals := nonRandomTrie(500) for _, kv := range vals { - key := []byte(kv.k) + key := kv.k prefixKeys := getPrefixKeysHex(trie, key) for _, prefixKey := range prefixKeys { var proof proofList From 2b40f3f838f9bff3e656d81f4e5ec96dd9390648 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Wed, 19 Apr 2023 11:18:11 +0800 Subject: [PATCH 25/51] ci: add other develop ci trigger; --- .github/workflows/build-test.yml | 2 ++ .github/workflows/commit-lint.yml | 2 ++ .github/workflows/integration-test.yml | 2 ++ .github/workflows/lint.yml | 2 ++ .github/workflows/unit-test.yml | 2 ++ 5 files changed, 10 insertions(+) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 43d2de9356..6fa80373cf 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: unit-test: diff --git a/.github/workflows/commit-lint.yml b/.github/workflows/commit-lint.yml index 7df13ebcac..0680a52f79 100644 --- a/.github/workflows/commit-lint.yml +++ b/.github/workflows/commit-lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: commitlint: diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 234cbfda5b..fb66659f44 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: truffle-test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 
ebe646dfe6..476665a660 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: golang-lint: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index d692c2f0d8..809bc12f82 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: unit-test: From 1a64833392917c09836d373ee62953b55413a1d9 Mon Sep 17 00:00:00 2001 From: asyukii Date: Fri, 21 Apr 2023 11:32:46 +0800 Subject: [PATCH 26/51] refactor(trie): track successful nubs for ReviveTrie add TryRevive again refactor: add getter for MPTProofNub --- core/state/database.go | 2 +- core/state/state_object.go | 2 +- light/trie.go | 4 +-- trie/dummy_trie.go | 4 ++- trie/proof.go | 4 +++ trie/secure_trie.go | 4 +-- trie/trie.go | 64 ++++++++++++++++---------------------- trie/trie_test.go | 35 ++++++++++----------- 8 files changed, 56 insertions(+), 63 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 037be8841a..46da90c954 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -129,7 +129,7 @@ type Trie interface { ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error - ReviveTrie(trie.MPTProofCache) + ReviveTrie(proof []*trie.MPTProofNub) []*trie.MPTProofNub } // NewDatabase creates a backing store for state. 
The returned database is safe for diff --git a/core/state/state_object.go b/core/state/state_object.go index d3b37a2bed..d353d2e507 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -692,7 +692,7 @@ func (s *StateObject) Value() *big.Int { func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { dr := s.getDirtyReviveTrie(s.db.db) - dr.ReviveTrie(proofCache) + dr.ReviveTrie(proofCache.CacheNubs()) s.db.journal.append(reviveStorageTrieNodeChange{ address: &s.address, }) diff --git a/light/trie.go b/light/trie.go index 23cc60919b..f7962e6ad6 100644 --- a/light/trie.go +++ b/light/trie.go @@ -202,8 +202,8 @@ func (db *odrTrie) NoTries() bool { return false } -func (t *odrTrie) ReviveTrie(proof trie.MPTProofCache) { - t.trie.ReviveTrie(proof) +func (t *odrTrie) ReviveTrie(proof []*trie.MPTProofNub) []*trie.MPTProofNub { + return t.trie.ReviveTrie(proof) } type nodeIterator struct { diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 46ae7666ad..ac021ed643 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -104,4 +104,6 @@ func (t *EmptyTrie) TryUpdateAccount(key []byte, account *types.StateAccount) er return nil } -func (t *EmptyTrie) ReviveTrie(proof MPTProofCache) {} +func (t *EmptyTrie) ReviveTrie(proof []*MPTProofNub) []*MPTProofNub { + return nil +} diff --git a/trie/proof.go b/trie/proof.go index 55f66c883e..357f96a0a7 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -252,6 +252,10 @@ func (m *MPTProofCache) VerifyProof() error { return nil } +func (m *MPTProofCache) CacheNubs() []*MPTProofNub { + return m.cacheNubs +} + func needMergeNextNode(nodes []node, i int) bool { if i >= len(nodes) || i+1 >= len(nodes) { return false diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 18e5548b5b..fa74b8b15c 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -229,6 +229,6 @@ func (t *SecureTrie) getSecKeyCache() map[string][]byte { return t.secKeyCache } -func (t *SecureTrie) 
ReviveTrie(proof MPTProofCache) { - t.trie.ReviveTrie(proof) +func (t *SecureTrie) ReviveTrie(proof []*MPTProofNub) []*MPTProofNub { + return t.trie.ReviveTrie(proof) } diff --git a/trie/trie.go b/trie/trie.go index 55c64833ab..dfe2f0d660 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -669,36 +669,36 @@ func (t *Trie) Size() int { return estimateSize(t.root) } -func (t *Trie) ReviveTrie(proof MPTProofCache) { - if err := t.TryRevive(proof); err != nil { +func (t *Trie) ReviveTrie(proof []*MPTProofNub) (successNubs []*MPTProofNub) { + successNubs, err := t.TryRevive(proof) + if err != nil { log.Error(fmt.Sprintf("Failed to revive trie: %v", err)) } + return successNubs } -func (t *Trie) TryRevive(proof MPTProofCache) error { - - cacheHashIndex := 0 +func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err error) { - for _, nub := range proof.cacheNubs { - newNode, didResolve, newCacheHashIndex, err := t.tryRevive(t.root, nub.RootHexKey, *nub, proof.cacheHashes, cacheHashIndex) - if err != nil { - return err - } - if didResolve { + for _, nub := range proof { + newNode, didResolve, err := t.tryRevive(t.root, nub.RootHexKey, *nub) + if didResolve && err == nil { + successNubs = append(successNubs, nub) t.root = newNode } - cacheHashIndex = newCacheHashIndex } - return nil + if len(successNubs) == 0 && len(proof) != 0 { + return successNubs, fmt.Errorf("all nubs failed to revive trie") + } + return successNubs, nil } -func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, cacheHashes [][]byte, cacheHashIndex int) (node, bool, int, error) { +func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub) (node, bool, error) { if len(key) == 0 { if hashNode, ok := n.(hashNode); ok { - cachedHash := cacheHashes[cacheHashIndex] + cachedHash, _ := nub.n1.cache() if bytes.Equal(cachedHash, hashNode) { if nub.n2 != nil { @@ -706,55 +706,45 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, cacheHashes [][]by case *shortNode: n1.Val = 
nub.n2 default: - return nil, false, cacheHashIndex, fmt.Errorf("invalid node type") + return nil, false, fmt.Errorf("invalid node type") } - cacheHashIndex++ } - cacheHashIndex++ - - return nub.n1, true, cacheHashIndex, nil + return nub.n1, true, nil } } - if nub.n1 != nil { - cacheHashIndex++ - } - if nub.n2 != nil { - cacheHashIndex++ - } - - return nil, false, cacheHashIndex, nil + return nil, false, nil } switch n := n.(type) { case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { - return nil, false, cacheHashIndex, fmt.Errorf("key %v not found", key) + return nil, false, fmt.Errorf("key %v not found", key) } - newNode, didResolve, newCacheHashIndex, err := t.tryRevive(n.Val, key[len(n.Key):], nub, cacheHashes, cacheHashIndex) + newNode, didResolve, err := t.tryRevive(n.Val, key[len(n.Key):], nub) if didResolve && err == nil { n = n.copy() n.Val = newNode } - return n, didResolve, newCacheHashIndex, err + return n, didResolve, err case *fullNode: childIndex := int(key[0]) - newNode, didResolve, newCacheHashIndex, err := t.tryRevive(n.Children[childIndex], key[1:], nub, cacheHashes, cacheHashIndex) + newNode, didResolve, err := t.tryRevive(n.Children[childIndex], key[1:], nub) if didResolve && err == nil { n = n.copy() n.Children[childIndex] = newNode } - return n, didResolve, newCacheHashIndex, err + return n, didResolve, err case hashNode: tn, err := t.resolveHash(n, nil) // TODO (asyukii): Revisit epoch index if err != nil { - return nil, false, cacheHashIndex, err + return nil, false, err } - return t.tryRevive(tn, key, nub, cacheHashes, cacheHashIndex) + return t.tryRevive(tn, key, nub) case valueNode: - return nil, false, cacheHashIndex, nil + return nil, false, nil case nil: - return nil, false, cacheHashIndex, nil + return nil, false, nil default: panic(fmt.Sprintf("invalid node: %T", n)) } diff --git a/trie/trie_test.go b/trie/trie_test.go index da658f1f81..271f1b376c 100644 --- a/trie/trie_test.go +++ 
b/trie/trie_test.go @@ -536,8 +536,8 @@ func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { } } -// TestReviveTrie tests that a trie can be revived from a proof -func TestReviveTrie(t *testing.T) { +// TestTryRevive tests that a trie can be revived from a proof +func TestTryRevive(t *testing.T) { trie, vals := nonRandomTrie(500) @@ -564,7 +564,7 @@ func TestReviveTrie(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trie.TryRevive(proofCache) + _, err = trie.TryRevive(proofCache.cacheNubs) assert.NoError(t, err) // Verify value exists after revive @@ -583,7 +583,7 @@ func TestReviveTrie(t *testing.T) { // TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when // the whole trie is expired. This test will fail because the parent node in -// ReviveTrie is nil, set to RootNode when available +// TryRevive is nil, set to RootNode when available // func TestReviveAtRoot(t *testing.T) { // trie, vals := nonRandomTrie(500) @@ -610,7 +610,7 @@ func TestReviveTrie(t *testing.T) { // assert.NoError(t, err) // // Revive trie -// err = trie.ReviveTrie(proofCache) +// err = trie.TryRevive(proofCache) // assert.NoError(t, err) // // Verify value exists after revive @@ -659,7 +659,7 @@ func TestReviveBadProof(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trieA.TryRevive(proofCache) + _, err = trieA.TryRevive(proofCache.cacheNubs) assert.Error(t, err) // Verify value does exists after revive @@ -668,9 +668,9 @@ func TestReviveBadProof(t *testing.T) { } -// TestReviveOneElement tests that a trie with a single element -// can be revived from a proof -func TestReviveOneElement(t *testing.T) { +// TestReviveAlreadyExists tests that a path cannot be revived +// again if it already exists +func TestReviveAlreadyExists(t *testing.T) { trie := new(Trie) key := []byte("k") val := []byte("v") @@ -682,16 +682,13 @@ func TestReviveOneElement(t *testing.T) { err := trie.ProveStorageWitness(key, nil, &proof) 
assert.NoError(t, err) - err = trie.ExpireByPrefix(nil) - assert.NoError(t, err) - proofCache := makeRawMPTProofCache(nil, proof) err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.TryRevive(proofCache) - assert.NoError(t, err) + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.Error(t, err) v := trie.Get(key) assert.Equal(t, val, v) @@ -719,14 +716,14 @@ func TestReviveBadProofAfterUpdate(t *testing.T) { err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.TryRevive(proofCache) + // Revive first + _, err = trie.TryRevive(proofCache.cacheNubs) assert.NoError(t, err) trie.Update(key, []byte("new value")) // Revive again with old proof - err = trie.TryRevive(proofCache) - assert.NoError(t, err) + trie.TryRevive(proofCache.cacheNubs) // Validate trie resVal, err := trie.TryGet(key) @@ -763,7 +760,7 @@ func TestPartialReviveFullProof(t *testing.T) { assert.NoError(t, err) // Revive trie - err = trie.TryRevive(proofCache) + _, err = trie.TryRevive(proofCache.cacheNubs) assert.NoError(t, err) // Validate trie @@ -825,7 +822,7 @@ func TestReviveValueAtFullNode(t *testing.T) { err = proofCache.VerifyProof() assert.NoError(t, err) - err = trie.TryRevive(proofCache) + _, err = trie.TryRevive(proofCache.cacheNubs) assert.NoError(t, err) // Validate trie From c67dbe7f96ed929581cae9186602538dc2071939 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Thu, 20 Apr 2023 23:11:23 +0800 Subject: [PATCH 27/51] StateDB: opt init with state epoch, add getState, setState error handle; StateEpoch: add new epoch type, add expired utility function; StateObject: support complete state R&W, opt revive trie cache & state; Snapshot: support parse state epoch; EVM: opt sLoad & sStore error handle, add EVM error collection; --- accounts/abi/bind/backends/simulated.go | 14 +- cmd/evm/internal/t8ntool/execution.go | 10 +- cmd/evm/runner.go | 6 +- cmd/geth/chaincmd.go | 20 ++- cmd/geth/snapshot.go | 2 +- 
core/blockchain.go | 7 +- core/blockchain_reader.go | 7 +- core/blockchain_test.go | 23 ++- core/chain_makers.go | 3 +- core/genesis.go | 2 +- core/state/errors.go | 53 ++++++ core/state/snapshot/snapshot.go | 40 +++++ core/state/state_object.go | 209 ++++++++++++++++-------- core/state/state_test.go | 12 +- core/state/statedb.go | 35 ++-- core/state/statedb_test.go | 42 ++--- core/state_processor.go | 2 +- core/tx_pool.go | 4 +- core/tx_pool_test.go | 2 +- core/types/state_epoch.go | 18 +- core/types/state_epoch_test.go | 36 ++-- core/vm/errors.go | 22 +++ core/vm/evm.go | 12 ++ core/vm/gas_table.go | 24 ++- core/vm/instructions.go | 5 +- core/vm/interface.go | 6 +- core/vm/interpreter.go | 9 + core/vm/operations_acl.go | 16 +- eth/api.go | 6 +- eth/api_backend.go | 4 +- eth/catalyst/api_test.go | 4 +- eth/protocols/eth/handler_test.go | 6 +- eth/state_accessor.go | 10 +- eth/tracers/api_test.go | 4 +- eth/tracers/js/goja.go | 8 +- eth/tracers/logger/logger.go | 6 +- eth/tracers/logger/logger_test.go | 8 +- eth/tracers/native/prestate.go | 7 +- graphql/graphql.go | 2 +- internal/ethapi/api.go | 17 +- les/api_backend.go | 4 +- les/odr_test.go | 4 +- les/state_accessor.go | 2 +- light/odr_test.go | 4 +- light/trie.go | 6 +- light/txpool.go | 2 +- miner/miner.go | 2 +- miner/miner_test.go | 3 +- miner/worker.go | 3 +- trie/errors.go | 4 +- 50 files changed, 522 insertions(+), 235 deletions(-) create mode 100644 core/state/errors.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 434fac8d6a..a9a4981559 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -130,7 +130,8 @@ func (b *SimulatedBackend) rollback(parent *types.Block) { blocks, _ := core.GenerateChain(b.config, parent, ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {}) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.blockchain.StateCache(), nil) + blockNum := 
new(big.Int).Add(parent.Number(), common.Big1) + b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), b.blockchain.StateCache(), nil, types.GetStateEpoch(b.config, blockNum)) } // Fork creates a side-chain that can be used to simulate reorgs. @@ -169,7 +170,7 @@ func (b *SimulatedBackend) stateByBlockNumber(ctx context.Context, blockNumber * if err != nil { return nil, err } - return b.blockchain.StateAt(block.Root()) + return b.blockchain.StateAt(block.Root(), block.Number()) } // CodeAt returns the code associated with a certain account in the blockchain. @@ -221,7 +222,10 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres return nil, err } - val := stateDB.GetState(contract, key) + val, err := stateDB.GetState(contract, key) + if err != nil { + return nil, err + } return val[:], nil } @@ -672,7 +676,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil) + b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), stateDB.Database(), nil, types.GetStateEpoch(b.config, b.pendingBlock.Number())) b.pendingReceipts = receipts[0] return nil } @@ -788,7 +792,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil) + b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), stateDB.Database(), nil, types.GetStateEpoch(b.config, b.pendingBlock.Number())) return nil } diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 56d6a9b5ff..0e2fc0b255 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -116,7 +116,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, return h } var ( - 
statedb = MakePreState(rawdb.NewMemoryDatabase(), pre.Pre) + statedb = MakePreState(rawdb.NewMemoryDatabase(), pre, chainConfig) signer = types.MakeSigner(chainConfig, new(big.Int).SetUint64(pre.Env.Number)) gaspool = new(core.GasPool) blockHash = common.Hash{0x13, 0x37} @@ -270,10 +270,10 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, return statedb, execRs, nil } -func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB { +func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) *state.StateDB { sdb := state.NewDatabase(db) - statedb, _ := state.New(common.Hash{}, sdb, nil) - for addr, a := range accounts { + statedb, _ := state.NewWithEpoch(common.Hash{}, sdb, nil, types.GetStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1))) + for addr, a := range pre.Pre { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) statedb.SetBalance(addr, a.Balance) @@ -285,7 +285,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB statedb.Finalise(false) statedb.AccountsIntermediateRoot() root, _, _ := statedb.Commit(nil) - statedb, _ = state.New(root, sdb, nil) + statedb, _ = state.NewWithEpoch(root, sdb, nil, types.GetStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number))) return statedb } diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index d57602f8d5..4964e41a14 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -28,6 +28,8 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/cmd/evm/internal/compiler" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" @@ -138,10 +140,10 @@ func runCmd(ctx *cli.Context) error { genesisConfig = gen db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) - statedb, _ = state.New(genesis.Root(), state.NewDatabase(db), nil) chainConfig = gen.Config + statedb, _ = state.NewWithEpoch(genesis.Root(), 
state.NewDatabase(db), nil, types.GetStateEpoch(chainConfig, genesis.Number())) } else { - statedb, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) + statedb, _ = state.NewWithEpoch(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil, types.StateEpoch0) genesisConfig = new(core.Genesis) } if ctx.GlobalString(SenderFlag.Name) != "" { diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index d2b44a4565..be3a9fe750 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "math/big" "net" "os" "path" @@ -495,11 +496,11 @@ func exportPreimages(ctx *cli.Context) error { return nil } -func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, ethdb.Database, common.Hash, error) { +func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, ethdb.Database, common.Hash, *big.Int, error) { db := utils.MakeChainDatabase(ctx, stack, true, false) var header *types.Header if ctx.NArg() > 1 { - return nil, nil, common.Hash{}, fmt.Errorf("expected 1 argument (number or hash), got %d", ctx.NArg()) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("expected 1 argument (number or hash), got %d", ctx.NArg()) } if ctx.NArg() == 1 { arg := ctx.Args().First() @@ -508,17 +509,17 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth if number := rawdb.ReadHeaderNumber(db, hash); number != nil { header = rawdb.ReadHeader(db, hash, *number) } else { - return nil, nil, common.Hash{}, fmt.Errorf("block %x not found", hash) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("block %x not found", hash) } } else { number, err := strconv.Atoi(arg) if err != nil { - return nil, nil, common.Hash{}, err + return nil, nil, common.Hash{}, common.Big0, err } if hash := rawdb.ReadCanonicalHash(db, uint64(number)); hash != (common.Hash{}) { header = rawdb.ReadHeader(db, hash, uint64(number)) } else { - return nil, 
nil, common.Hash{}, fmt.Errorf("header for block %d not found", number) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("header for block %d not found", number) } } } else { @@ -526,7 +527,7 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth header = rawdb.ReadHeadHeader(db) } if header == nil { - return nil, nil, common.Hash{}, errors.New("no head block found") + return nil, nil, common.Hash{}, common.Big0, errors.New("no head block found") } startArg := common.FromHex(ctx.String(utils.StartKeyFlag.Name)) var start common.Hash @@ -538,7 +539,7 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth start = crypto.Keccak256Hash(startArg) log.Info("Converting start-address to hash", "address", common.BytesToAddress(startArg), "hash", start.Hex()) default: - return nil, nil, common.Hash{}, fmt.Errorf("invalid start argument: %x. 20 or 32 hex-encoded bytes required", startArg) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("invalid start argument: %x. 
20 or 32 hex-encoded bytes required", startArg) } var conf = &state.DumpConfig{ SkipCode: ctx.Bool(utils.ExcludeCodeFlag.Name), @@ -550,17 +551,18 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth log.Info("State dump configured", "block", header.Number, "hash", header.Hash().Hex(), "skipcode", conf.SkipCode, "skipstorage", conf.SkipStorage, "start", hexutil.Encode(conf.Start), "limit", conf.Max) - return conf, db, header.Root, nil + return conf, db, header.Root, header.Number, nil } func dump(ctx *cli.Context) error { stack, _ := makeConfigNode(ctx) defer stack.Close() - conf, db, root, err := parseDumpConfig(ctx, stack) + conf, db, root, _, err := parseDumpConfig(ctx, stack) if err != nil { return err } + // TODO unknown chainconfig, default using epoch0 state, err := state.New(root, state.NewDatabase(db), nil) if err != nil { return err diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 1c92da95f5..c8ce3bff50 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -728,7 +728,7 @@ func dumpState(ctx *cli.Context) error { stack, _ := makeConfigNode(ctx) defer stack.Close() - conf, db, root, err := parseDumpConfig(ctx, stack) + conf, db, root, _, err := parseDumpConfig(ctx, stack) if err != nil { return err } diff --git a/core/blockchain.go b/core/blockchain.go index b119270b99..9908c02c5b 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -364,7 +364,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par // Make sure the state associated with the block is available head := bc.CurrentBlock() - if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { + if _, err := state.NewWithEpoch(head.Root(), bc.stateCache, bc.snaps, types.GetStateEpoch(chainConfig, head.Number())); err != nil { // Head state is missing, before the state recovery, find out the // disk layer point of snapshot(if it's enabled). Make sure the // rewound point is lower than disk layer. 
@@ -715,7 +715,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo enoughBeyondCount = beyondCount > maxBeyondBlocks - if _, err := state.New(newHeadBlock.Root(), bc.stateCache, bc.snaps); err != nil { + if _, err := state.NewWithEpoch(newHeadBlock.Root(), bc.stateCache, bc.snaps, types.GetStateEpoch(bc.chainConfig, header.Number)); err != nil { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) @@ -1882,7 +1882,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) if parent == nil { parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1) } - statedb, err := state.NewWithSharedPool(parent.Root, bc.stateCache, bc.snaps) + statedb, err := state.NewWithEpoch(parent.Root, bc.stateCache, bc.snaps, + types.GetStateEpoch(bc.chainConfig, block.Number())) if err != nil { return it.index, err } diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index b28661da3e..3426bebd77 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -308,12 +308,13 @@ func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) { // State returns a new mutable state based on the current HEAD block. func (bc *BlockChain) State() (*state.StateDB, error) { - return bc.StateAt(bc.CurrentBlock().Root()) + block := bc.CurrentBlock() + return bc.StateAt(block.Root(), block.Number()) } // StateAt returns a new mutable state based on a particular point in time. 
-func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return state.New(root, bc.stateCache, bc.snaps) +func (bc *BlockChain) StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) { + return state.NewWithEpoch(root, bc.stateCache, bc.snaps, types.GetStateEpoch(bc.chainConfig, number)) } // Config retrieves the chain's fork configuration. diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 20615eef8c..55e39133a4 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -3178,21 +3178,26 @@ func TestDeleteRecreateSlots(t *testing.T) { statedb, _ := chain.State() // If all is correct, then slot 1 and 2 are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } // Also, 3 and 4 should be set - if got, exp := statedb.GetState(aa, common.HexToHash("03")), common.HexToHash("03"); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("03")), common.HexToHash("03"); got != exp { t.Fatalf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("04")), common.HexToHash("04"); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("04")), common.HexToHash("04"); got != exp { t.Fatalf("got %x exp %x", got, exp) } } +func getStateIgnoreErr(statedb *state.StateDB, addr common.Address, hash common.Hash) common.Hash { + val, _ := statedb.GetState(addr, hash) + return val +} + // TestDeleteRecreateAccount tests a state-transition that contains deletion of a // contract with storage, and a recreate of the same contract via a // regular 
value-transfer @@ -3258,10 +3263,10 @@ func TestDeleteRecreateAccount(t *testing.T) { statedb, _ := chain.State() // If all is correct, then both slots are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } } @@ -3435,10 +3440,10 @@ func TestDeleteRecreateSlotsAcrossManyBlocks(t *testing.T) { } statedb, _ := chain.State() // If all is correct, then slot 1 and 2 are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("block %d, got %x exp %x", blockNum, got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("block %d, got %x exp %x", blockNum, got, exp) } exp := expectations[i] @@ -3447,7 +3452,7 @@ func TestDeleteRecreateSlotsAcrossManyBlocks(t *testing.T) { t.Fatalf("block %d, expected %v to exist, it did not", blockNum, aa) } for slot, val := range exp.values { - if gotValue, expValue := statedb.GetState(aa, asHash(slot)), asHash(val); gotValue != expValue { + if gotValue, expValue := getStateIgnoreErr(statedb, aa, asHash(slot)), asHash(val); gotValue != expValue { t.Fatalf("block %d, slot %d, got %x exp %x", blockNum, slot, gotValue, expValue) } } diff --git a/core/chain_makers.go b/core/chain_makers.go index c53269ed6e..e1e9aeec7c 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -275,7 +275,8 @@ func 
GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse return nil, nil } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil) + number := new(big.Int).Add(parent.Number(), common.Big1) + statedb, err := state.NewWithEpoch(parent.Root(), state.NewDatabase(db), nil, types.GetStateEpoch(config, number)) if err != nil { panic(err) } diff --git a/core/genesis.go b/core/genesis.go index 75bd357cd4..bc29e68757 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -182,7 +182,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override // We have the genesis block in database(perhaps in ancient database) // but the corresponding state is missing. header := rawdb.ReadHeader(db, stored, 0) - if _, err := state.New(header.Root, state.NewDatabaseWithConfigAndCache(db, nil), nil); err != nil { + if _, err := state.NewWithEpoch(header.Root, state.NewDatabaseWithConfigAndCache(db, nil), nil, types.StateEpoch0); err != nil { if genesis == nil { genesis = DefaultGenesisBlock() } diff --git a/core/state/errors.go b/core/state/errors.go new file mode 100644 index 0000000000..66f0f1dbd9 --- /dev/null +++ b/core/state/errors.go @@ -0,0 +1,53 @@ +package state + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/trie" +) + +// ExpiredStateError Access State error, must revert the execution +type ExpiredStateError struct { + Addr common.Address + Key common.Hash + Path []byte + Epoch types.StateEpoch + isInsert bool // when true it through expired path, must recovery the expired path +} + +func NewPlainExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { + return &ExpiredStateError{ + Addr: addr, + Key: key, + Path: []byte{}, + Epoch: epoch, + isInsert: false, + } +} + +func NewExpiredStateError(addr common.Address, err *trie.ExpiredNodeError) *ExpiredStateError { + 
return &ExpiredStateError{ + Addr: addr, + Key: common.Hash{}, + Path: err.Path, + Epoch: err.Epoch, + isInsert: false, + } +} + +func NewInsertExpiredStateError(addr common.Address, err *trie.ExpiredNodeError) *ExpiredStateError { + return &ExpiredStateError{ + Addr: addr, + Key: common.Hash{}, + Path: err.Path, + Epoch: err.Epoch, + isInsert: true, + } +} + +func (e *ExpiredStateError) Error() string { + if e.isInsert { + return "Insert state through expired path" + } + return "Access expired state" +} diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 2f13631607..592564ce27 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -24,6 +24,8 @@ import ( "sync" "sync/atomic" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" @@ -182,6 +184,44 @@ type Tree struct { onFlatten func() // Hook invoked when the bottom most diff layers are flattened } +// SnapValue store snap with Epoch after BEP-216 +type SnapValue struct { + Epoch types.StateEpoch + Val common.Hash +} + +func ParseSnapValFromBytes(enc []byte) (*SnapValue, error) { + k, content, _, err := rlp.Split(enc) + if err != nil { + return nil, err + } + if k != rlp.List { + val := common.Hash{} + val.SetBytes(content) + return &SnapValue{ + Epoch: 0, + Val: val, + }, nil + } + var val SnapValue + if err = rlp.DecodeBytes(enc, &val); err != nil { + return nil, err + } + return &val, nil +} + +func NewSnapValBytes(epoch types.StateEpoch, val common.Hash) ([]byte, error) { + snapVal := SnapValue{ + Epoch: epoch, + Val: val, + } + enc, err := rlp.EncodeToBytes(&snapVal) + if err != nil { + return nil, err + } + return enc, nil +} + // New attempts to load an already existing snapshot from a persistent key-value // store (with a number of memory layers from a journal), ensuring that the head // of the snapshot matches the expected 
one. diff --git a/core/state/state_object.go b/core/state/state_object.go index d353d2e507..7305e705bc 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/core/state/snapshot" + "github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/common" @@ -84,6 +86,7 @@ type StateObject struct { trie Trie // storage trie, which becomes non-nil on first access, it's committed trie code Code // contract bytecode, which gets set when code is loaded + // TODO(0xbundler) Attention: it's a shared storage between stateDBs, will expired, just disable sharedOriginStorage now sharedOriginStorage *sync.Map // Point to the entry of the stateObject in sharedPool originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction @@ -95,12 +98,13 @@ type StateObject struct { pendingReviveTrie Trie // pendingReviveTrie it contains pending revive trie nodes, could update & commit later dirtyReviveTrie Trie // dirtyReviveTrie for tx - // TODO when R&W, access revive state first + // when R&W, access revive state first pendingReviveState Storage // pendingReviveState for block, it cannot flush to trie, just cache dirtyReviveState Storage // dirtyReviveState for tx, for cache dirtyReviveTrie + // accessed state, don't record revive state, don't record nonexist state or any err access pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed, it will update epoch index late - dirtyAccessedState map[common.Hash]int // dirtyAccessedState record which state is accessed, it will update epoch index later, attention: don't record revive state + dirtyAccessedState map[common.Hash]int // dirtyAccessedState record which state is accessed, it will update epoch index later // Cache flags. 
// When an object is marked suicided it will be delete from the trie @@ -111,6 +115,7 @@ type StateObject struct { //encode encodeData []byte + Epoch types.StateEpoch } // empty returns whether the account is considered empty. @@ -136,16 +141,19 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *St } return &StateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - sharedOriginStorage: storageMap, - originStorage: make(Storage), - pendingStorage: make(Storage), - dirtyStorage: make(Storage), - dirtyReviveState: make(Storage), - pendingReviveState: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + sharedOriginStorage: storageMap, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), + dirtyReviveState: make(Storage), + pendingReviveState: make(Storage), + dirtyAccessedState: make(map[common.Hash]int), + pendingAccessedState: make(map[common.Hash]int), + Epoch: db.Epoch, } } @@ -213,18 +221,27 @@ func (s *StateObject) getDirtyReviveTrie(db Database) Trie { } // GetState retrieves a value from the account storage trie. 
-func (s *StateObject) GetState(db Database, key common.Hash) common.Hash { +func (s *StateObject) GetState(db Database, key common.Hash) (common.Hash, error) { // If the fake storage is set, only lookup the state here(in the debugging mode) if s.fakeStorage != nil { - return s.fakeStorage[key] + return s.fakeStorage[key], nil } // If we have a dirty value for this state entry, return it - value, dirty := s.dirtyStorage[key] - if dirty { - return value + if value, dirty := s.dirtyStorage[key]; dirty { + s.accessState(key) + return value, nil + } + if revived, revive := s.dirtyReviveState[key]; revive { + s.accessState(key) + return revived, nil } + // Otherwise return the entry's original value - return s.GetCommittedState(db, key) + committed, err := s.GetCommittedState(db, key) + if err == nil && committed != (common.Hash{}) { + s.accessState(key) + } + return committed, err } func (s *StateObject) getOriginStorage(key common.Hash) (common.Hash, bool) { @@ -252,24 +269,23 @@ func (s *StateObject) setOriginStorage(key common.Hash, value common.Hash) { } // GetCommittedState retrieves a value from the committed account storage trie. 
-func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Hash { +func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Hash, error) { // If the fake storage is set, only lookup the state here(in the debugging mode) if s.fakeStorage != nil { - return s.fakeStorage[key] + return s.fakeStorage[key], nil } // If we have a pending write or clean cached, return that if value, pending := s.pendingStorage[key]; pending { - return value + return value, nil + } + if revived, revive := s.pendingReviveState[key]; revive { + return revived, nil } if value, cached := s.getOriginStorage(key); cached { - return value + return value, nil } // If no live objects are available, attempt to use snapshots - var ( - enc []byte - err error - ) if s.db.snap != nil { // If the object was destructed in *this* block (and potentially resurrected), // the storage has been cleared out, and we should *not* consult the previous @@ -278,29 +294,38 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Has // have been handles via pendingStorage above. 
// 2) we don't have new values, and can deliver empty response back if _, destructed := s.db.snapDestructs[s.address]; destructed { - return common.Hash{} + return common.Hash{}, nil } start := time.Now() - enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + enc, err := s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) if metrics.EnabledExpensive { s.db.SnapshotStorageReads += time.Since(start) } + if err == nil { + if snapVal, err := snapshot.ParseSnapValFromBytes(enc); err == nil { + if types.EpochExpired(snapVal.Epoch, s.Epoch) { + return common.Hash{}, NewPlainExpiredStateError(s.address, key, snapVal.Epoch) + } + return snapVal.Val, nil + } + } } // If snapshot unavailable or reading from it failed, load from the database - if s.db.snap == nil || err != nil { - start := time.Now() - // if metrics.EnabledExpensive { - // meter = &s.db.StorageReads - // } - enc, err = s.getTrie(db).TryGet(key.Bytes()) - if metrics.EnabledExpensive { - s.db.StorageReads += time.Since(start) - } - if err != nil { - s.setError(err) - return common.Hash{} + start := time.Now() + //if metrics.EnabledExpensive { + // meter = &s.db.StorageReads + //} + enc, err := s.getTrie(db).TryGet(key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return common.Hash{}, NewExpiredStateError(s.address, enErr) } + s.setError(err) + return common.Hash{}, nil } var value common.Hash if len(enc) > 0 { @@ -311,28 +336,49 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Has value.SetBytes(content) } s.setOriginStorage(key, value) - return value + return value, nil } // SetState updates a value in account storage. -func (s *StateObject) SetState(db Database, key, value common.Hash) { +func (s *StateObject) SetState(db Database, key, value common.Hash) error { // If the fake storage is set, put the temporary state update here. 
if s.fakeStorage != nil { s.fakeStorage[key] = value - return + return nil } // If the new value is the same as old, don't set - prev := s.GetState(db, key) + prev, err := s.GetState(db, key) + if exErr, ok := err.(*ExpiredStateError); ok { + exErr.isInsert = true + return exErr + } + if err != nil { + return err + } if prev == value { - return + s.accessState(key) + return nil + } + // when state insert, check if valid to insert new state + if prev != (common.Hash{}) { + _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if err != nil { + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return NewInsertExpiredStateError(s.address, enErr) + } + s.setError(err) + return nil + } } // New value is different, update and journal the change + s.accessState(key) s.db.journal.append(storageChange{ account: &s.address, key: key, prevalue: prev, }) s.setState(key, value) + return nil } // SetStorage replaces the entire state storage with the given one. @@ -372,8 +418,7 @@ func (s *StateObject) finalise(prefetch bool) { } for key, value := range s.dirtyAccessedState { count := s.pendingAccessedState[key] - count += value - s.pendingAccessedState[key] = count + s.pendingAccessedState[key] = count + value } prefetcher := s.db.prefetcher @@ -400,7 +445,7 @@ func (s *StateObject) finalise(prefetch bool) { func (s *StateObject) updateTrie(db Database) Trie { // Make sure all dirty slots are finalized into the pending storage area s.finalise(false) // Don't prefetch anymore, pull directly if need be - if len(s.pendingStorage) == 0 { + if len(s.pendingStorage) == 0 && len(s.pendingReviveState) == 0 { return s.trie } // Track the amount of time wasted on updating the storage trie @@ -415,42 +460,42 @@ func (s *StateObject) updateTrie(db Database) Trie { tr := s.getPendingReviveTrie(db) usedStorage := make([][]byte, 0, len(s.pendingStorage)) - dirtyStorage := make(map[common.Hash][]byte) + dirtyStorage := make(map[common.Hash]common.Hash) + accessStorage := 
make(map[common.Hash]struct{}) + for k := range s.pendingAccessedState { + accessStorage[k] = struct{}{} + } for key, value := range s.pendingStorage { // Skip noop changes, persist actual changes if value == s.originStorage[key] { continue } s.originStorage[key] = value - var v []byte - if value != (common.Hash{}) { - // Encoding []byte cannot fail, ok to ignore the error. - v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) - } - dirtyStorage[key] = v + dirtyStorage[key] = value + delete(accessStorage, key) } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for key, value := range dirtyStorage { - if len(value) == 0 { + var v []byte + if value != (common.Hash{}) { + // Encoding []byte cannot fail, ok to ignore the error. + v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + } + if len(v) == 0 { s.setError(tr.TryDelete(key[:])) } else { - s.setError(tr.TryUpdate(key[:], value)) + s.setError(tr.TryUpdate(key[:], v)) } usedStorage = append(usedStorage, common.CopyBytes(key[:])) } - }() - wg.Add(1) - go func() { - defer wg.Done() - for key := range s.pendingAccessedState { - if _, ok := dirtyStorage[key]; ok { - continue - } - // TODO update accessed state epoch index - } + // TODO(0xbundler): call TryUpdateEpoch later + //for key := range accessStorage { + // s.setError(tr.TryUpdateEpoch(key[:])) + // usedStorage = append(usedStorage, common.CopyBytes(key[:])) + //} }() if s.db.snap != nil { // If state snapshotting is active, cache the data til commit @@ -466,7 +511,11 @@ func (s *StateObject) updateTrie(db Database) Trie { } s.db.snapStorageMux.Unlock() for key, value := range dirtyStorage { - storage[string(key[:])] = value + enc, err := snapshot.NewSnapValBytes(s.Epoch, value) + if err != nil { + s.setError(err) + } + storage[string(key[:])] = enc } }() } @@ -480,9 +529,15 @@ func (s *StateObject) updateTrie(db Database) Trie { if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) } + if len(s.pendingReviveState) > 0 { + 
s.pendingReviveState = make(Storage) + } if len(s.pendingAccessedState) > 0 { s.pendingAccessedState = make(map[common.Hash]int) } + if s.pendingReviveTrie != nil { + s.pendingReviveTrie = nil + } // reset trie as pending trie, will commit later if tr != nil { @@ -598,6 +653,14 @@ func (s *StateObject) deepCopy(db *StateDB) *StateObject { } stateObject.dirtyReviveState = s.dirtyReviveState.Copy() stateObject.pendingReviveState = s.pendingReviveState.Copy() + stateObject.dirtyAccessedState = make(map[common.Hash]int, len(s.dirtyAccessedState)) + for k, v := range s.dirtyAccessedState { + stateObject.dirtyAccessedState[k] = v + } + stateObject.pendingAccessedState = make(map[common.Hash]int, len(s.pendingAccessedState)) + for k, v := range s.pendingAccessedState { + stateObject.pendingAccessedState[k] = v + } return stateObject } @@ -692,9 +755,19 @@ func (s *StateObject) Value() *big.Int { func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { dr := s.getDirtyReviveTrie(s.db.db) + // TODO(0xbundler): revive nub and cache revive state dr.ReviveTrie(proofCache.CacheNubs()) s.db.journal.append(reviveStorageTrieNodeChange{ address: &s.address, }) return nil } + +func (s *StateObject) accessState(key common.Hash) { + s.db.journal.append(accessedStorageStateChange{ + address: &s.address, + slot: &key, + }) + count := s.dirtyAccessedState[key] + s.dirtyAccessedState[key] = count + 1 +} diff --git a/core/state/state_test.go b/core/state/state_test.go index 4cc5c33a85..af9e48648b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -104,10 +104,10 @@ func TestNull(t *testing.T) { s.state.AccountsIntermediateRoot() s.state.Commit(nil) - if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { + if value, _ := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { t.Errorf("expected empty current value, got %x", value) } - if value := s.state.GetCommittedState(address, common.Hash{}); value != 
(common.Hash{}) { + if value, _ := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) { t.Errorf("expected empty committed value, got %x", value) } } @@ -130,19 +130,19 @@ func TestSnapshot(t *testing.T) { s.state.SetState(stateobjaddr, storageaddr, data2) s.state.RevertToSnapshot(snapshot) - if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { + if v, _ := s.state.GetState(stateobjaddr, storageaddr); v != data1 { t.Errorf("wrong storage value %v, want %v", v, data1) } - if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) } // revert up to the genesis state and ensure correct content s.state.RevertToSnapshot(genesis) - if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong storage value %v, want %v", v, common.Hash{}) } - if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) } } diff --git a/core/state/statedb.go b/core/state/statedb.go index acdcfce606..2db95dc5cb 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -128,6 +128,9 @@ type StateDB struct { validRevisions []revision nextRevisionId int + // state epoch + Epoch types.StateEpoch + // Measurements gathered during execution for debugging purposes MetricsMux sync.Mutex AccountReads time.Duration @@ -148,14 +151,21 @@ type StateDB struct { StorageDeleted int } -// New creates a new state from a given trie. +// NewWithEpoch creates a new state from a given trie. 
+func NewWithEpoch(root common.Hash, db Database, snaps *snapshot.Tree, epoch types.StateEpoch) (*StateDB, error) { + return newStateDB(root, db, snaps, epoch) +} + +// New creates a new state from a given trie, it inits at Epoch0 func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { - return newStateDB(root, db, snaps) + return newStateDB(root, db, snaps, types.StateEpoch0) } // NewWithSharedPool creates a new state with sharedStorge on layer 1.5 +// Deprecated: disable in state expiry, it inits at Epoch0 +// TODO(0xbundler) cannot use share pool in state revive now, need optimise later func NewWithSharedPool(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { - statedb, err := newStateDB(root, db, snaps) + statedb, err := newStateDB(root, db, snaps, types.StateEpoch0) if err != nil { return nil, err } @@ -163,7 +173,7 @@ func NewWithSharedPool(root common.Hash, db Database, snaps *snapshot.Tree) (*St return statedb, nil } -func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { +func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree, epoch types.StateEpoch) (*StateDB, error) { sdb := &StateDB{ db: db, originalRoot: root, @@ -175,6 +185,7 @@ func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, preimages: make(map[common.Hash][]byte), journal: newJournal(), hasher: crypto.NewKeccakState(), + Epoch: epoch, } if sdb.snaps != nil { @@ -473,13 +484,12 @@ func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { } // GetState retrieves a value from the given account's storage trie. 
-// TODO access in shadow node -func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { +func (s *StateDB) GetState(addr common.Address, hash common.Hash) (common.Hash, error) { stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.GetState(s.db, hash) } - return common.Hash{} + return common.Hash{}, nil } // GetProof returns the Merkle proof for a given account. @@ -520,12 +530,12 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, } // GetCommittedState retrieves a value from the given account's committed storage trie. -func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { +func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) (common.Hash, error) { stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.GetCommittedState(s.db, hash) } - return common.Hash{} + return common.Hash{}, nil } // Database retrieves the low level database supporting the lower level trie ops. 
@@ -594,12 +604,12 @@ func (s *StateDB) SetCode(addr common.Address, code []byte) { } } -// TODO access state and check insert duplicated -func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { +func (s *StateDB) SetState(addr common.Address, key, value common.Hash) error { stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(s.db, key, value) + return stateObject.SetState(s.db, key, value) } + return nil } // SetStorage replaces the entire storage for the specified account with given @@ -1458,7 +1468,6 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er tasks <- func() { // Write any storage changes in the state object to its storage trie if !s.noTrie { - // TODO commit revive state cache to Trie if _, err := obj.CommitTrie(s.db); err != nil { taskResults <- err return diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 4b3a91cde6..e364c509fe 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -452,10 +452,12 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { // Check storage. 
if obj := state.getStateObject(addr); obj != nil { state.ForEachStorage(addr, func(key, value common.Hash) bool { - return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + val, _ := checkstate.GetState(addr, key) + return checkeq("GetState("+key.Hex()+")", val, value) }) checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { - return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + val, _ := checkstate.GetState(addr, key) + return checkeq("GetState("+key.Hex()+")", val, value) }) } if err != nil { @@ -529,10 +531,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) } - if val := state.GetState(addr, skey); val != sval { + if val, _ := state.GetState(addr, skey); val != sval { t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := state.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } // Copy the non-committed state database and check pre/post commit balance @@ -543,10 +545,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy pre-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("first copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } @@ 
-559,10 +561,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != sval { + if val, _ := copyOne.GetCommittedState(addr, skey); val != sval { t.Fatalf("first copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } // Copy the copy and check the balance once more @@ -573,10 +575,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyTwo.GetCommittedState(addr, skey); val != sval { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != sval { t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } } @@ -603,10 +605,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) } - if val := state.GetState(addr, skey); val != sval { + if val, _ := state.GetState(addr, skey); val != sval { t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := state.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("initial committed storage slot mismatch: have %x, 
want %x", val, common.Hash{}) } // Copy the non-committed state database and check pre/post commit balance @@ -617,10 +619,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("first copy committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } // Copy the copy and check the balance once more @@ -631,10 +633,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy pre-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("second copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } @@ -647,10 +649,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := 
copyTwo.GetCommittedState(addr, skey); val != sval { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != sval { t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } // Copy the copy-copy and check the balance once more @@ -661,10 +663,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyThree.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("third copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyThree.GetState(addr, skey); val != sval { + if val, _ := copyThree.GetState(addr, skey); val != sval { t.Fatalf("third copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyThree.GetCommittedState(addr, skey); val != sval { + if val, _ := copyThree.GetCommittedState(addr, skey); val != sval { t.Fatalf("third copy committed storage slot mismatch: have %x, want %x", val, sval) } } diff --git a/core/state_processor.go b/core/state_processor.go index b42938adf9..7b926b742b 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -122,7 +122,7 @@ func (p *LightStateProcessor) Process(block *types.Block, statedb *state.StateDB // prepare new statedb statedb.StopPrefetcher() parent := p.bc.GetHeader(block.ParentHash(), block.NumberU64()-1) - statedb, err = state.New(parent.Root, p.bc.stateCache, p.bc.snaps) + statedb, err = state.NewWithEpoch(parent.Root, p.bc.stateCache, p.bc.snaps, types.GetStateEpoch(p.config, block.Number())) if err != nil { return statedb, nil, nil, 0, err } diff --git a/core/tx_pool.go b/core/tx_pool.go index 83e206c597..c6d7383ba6 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -157,7 +157,7 @@ const ( type blockChain interface { CurrentBlock() *types.Block GetBlock(hash common.Hash, number uint64) *types.Block - StateAt(root common.Hash) (*state.StateDB, error) + StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) 
event.Subscription } @@ -1428,7 +1428,7 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { if newHead == nil { newHead = pool.chain.CurrentBlock().Header() // Special case during testing } - statedb, err := pool.chain.StateAt(newHead.Root) + statedb, err := pool.chain.StateAt(newHead.Root, newHead.Number) if err != nil { log.Error("Failed to reset txpool state", "err", err) return diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index d3b40f1f4a..9b25781b2c 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -73,7 +73,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return bc.CurrentBlock() } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go index 310f26004c..157f41d7b7 100644 --- a/core/types/state_epoch.go +++ b/core/types/state_epoch.go @@ -10,23 +10,31 @@ import ( var ( // EpochPeriod indicates the state rotate epoch block length EpochPeriod = big.NewInt(7_008_000) + StateEpoch0 = StateEpoch(0) ) -// GetCurrentEpoch computes the current state epoch by hard fork and block number +type StateEpoch uint16 + +// GetStateEpoch computes the current state epoch by hard fork and block number // state epoch will indicate if the state is accessible or expiry. // Before ClaudeBlock indicates state epoch0. // ClaudeBlock indicates start state epoch1. // ElwoodBlock indicates start state epoch2 and start epoch rotate by EpochPeriod. // When N>=2 and epochN started, epoch(N-2)'s state will expire. 
-func GetCurrentEpoch(config *params.ChainConfig, blockNumber *big.Int) *big.Int { +func GetStateEpoch(config *params.ChainConfig, blockNumber *big.Int) StateEpoch { if config.IsElwood(blockNumber) { ret := new(big.Int).Sub(blockNumber, config.ElwoodBlock) ret.Div(ret, EpochPeriod) ret.Add(ret, common.Big2) - return ret + return StateEpoch(ret.Uint64()) } else if config.IsClaude(blockNumber) { - return common.Big1 + return 1 } else { - return common.Big0 + return 0 } } + +// EpochExpired reports whether the pre epoch has expired compared to the current epoch +func EpochExpired(pre StateEpoch, cur StateEpoch) bool { + return cur >= 2 && pre < cur-1 +} diff --git a/core/types/state_epoch_test.go b/core/types/state_epoch_test.go index 0193240930..8286d21c4a 100644 --- a/core/types/state_epoch_test.go +++ b/core/types/state_epoch_test.go @@ -60,13 +60,13 @@ func TestSimpleStateEpoch(t *testing.T) { } assert.NoError(t, temp.CheckConfigForkOrder()) - assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(0))) - assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(1000))) - assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(10000))) - assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(19999))) - assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(20000))) - assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(20000), EpochPeriod))) - assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch0, GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(1000))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(19999))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(20000))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, 
new(big.Int).Add(big.NewInt(20000), EpochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) } func TestNoZeroStateEpoch(t *testing.T) { @@ -76,12 +76,12 @@ func TestNoZeroStateEpoch(t *testing.T) { } assert.NoError(t, temp.CheckConfigForkOrder()) - assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(0))) - assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(1))) - assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(2))) - assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(10000))) - assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(2), EpochPeriod))) - assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(1))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(2))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), EpochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) } func TestNearestStateEpoch(t *testing.T) { @@ -91,9 +91,9 @@ func TestNearestStateEpoch(t *testing.T) { } assert.NoError(t, temp.CheckConfigForkOrder()) - assert.Equal(t, big.NewInt(0), GetCurrentEpoch(temp, big.NewInt(0))) - assert.Equal(t, big.NewInt(1), GetCurrentEpoch(temp, big.NewInt(10000))) - assert.Equal(t, big.NewInt(2), GetCurrentEpoch(temp, big.NewInt(10001))) - assert.Equal(t, big.NewInt(3), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(10001), EpochPeriod))) - assert.Equal(t, big.NewInt(102), GetCurrentEpoch(temp, new(big.Int).Add(big.NewInt(10001), 
new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10001))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), EpochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) } diff --git a/core/vm/errors.go b/core/vm/errors.go index 004f8ef1c8..3956c10a39 100644 --- a/core/vm/errors.go +++ b/core/vm/errors.go @@ -19,6 +19,8 @@ package vm import ( "errors" "fmt" + + "github.com/ethereum/go-ethereum/common" ) // List evm execution errors @@ -70,3 +72,23 @@ type ErrInvalidOpCode struct { } func (e *ErrInvalidOpCode) Error() string { return fmt.Sprintf("invalid opcode: %s", e.opcode) } + +type EVMError struct { + from common.Address + to common.Address + opcode OpCode + err error +} + +func NewEVMErr(contract *Contract, op OpCode, err error) *EVMError { + return &EVMError{ + from: contract.Caller(), + to: contract.Address(), + opcode: op, + err: err, + } +} + +func (e *EVMError) Error() string { + return fmt.Sprintf("EVM err, from %v to %v at %v, got: %v", e.from, e.to, e.opcode, e.err) +} diff --git a/core/vm/evm.go b/core/vm/evm.go index eb483fd6de..ef98952562 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -135,6 +135,9 @@ type EVM struct { // available gas is calculated in gasCall* according to the 63/64 rule and later // applied in opCall*. callGasTemp uint64 + + // errorCollection all op code and err list will collect in here + errorCollection []*EVMError } // NewEVM returns a new EVM. 
The returned EVM is not thread safe and should @@ -152,6 +155,7 @@ func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig evm.depth = 0 evm.interpreter = NewEVMInterpreter(evm, config) + evm.errorCollection = []*EVMError{} return evm } @@ -531,3 +535,11 @@ func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment * // ChainConfig returns the environment's chain configuration func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig } + +func (evm *EVM) AppendErr(err *EVMError) { + evm.errorCollection = append(evm.errorCollection, err) +} + +func (evm *EVM) Errors() []*EVMError { + return evm.errorCollection +} diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 95173cb274..083f8c7f7b 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -95,9 +95,12 @@ var ( func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { var ( - y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) ) + current, err := evm.StateDB.GetState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } // The legacy gas metering only takes into consideration the current state // Legacy rules should be applied if we are in Petersburg (removal of EIP-1283) // OR Constantinople is not active @@ -135,7 +138,10 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi if current == value { // noop (1) return params.NetSstoreNoopGas, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return params.NetSstoreInitGas, nil @@ -182,15 +188,21 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem 
*Memory, m } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) ) + current, err := evm.StateDB.GetState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } value := common.Hash(y.Bytes32()) if current == value { // noop (1) return params.SloadGasEIP2200, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return params.SstoreSetGasEIP2200, nil diff --git a/core/vm/instructions.go b/core/vm/instructions.go index eacc77587a..a05e2ace09 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -523,7 +523,10 @@ func opMstore8(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([] func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) - val := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) + val, err := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) + if err != nil { + return nil, err + } loc.SetBytes(val.Bytes()) return nil, nil } diff --git a/core/vm/interface.go b/core/vm/interface.go index 86dde81c80..8bc0891a60 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -43,9 +43,9 @@ type StateDB interface { SubRefund(uint64) GetRefund() uint64 - GetCommittedState(common.Address, common.Hash) common.Hash - GetState(common.Address, common.Hash) common.Hash - SetState(common.Address, common.Hash, common.Hash) + GetCommittedState(common.Address, common.Hash) (common.Hash, error) + GetState(common.Address, common.Hash) (common.Hash, error) + SetState(common.Address, common.Hash, 
common.Hash) error Suicide(common.Address) bool HasSuicided(common.Address) bool diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 030a063948..a78ba9666e 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -20,6 +20,8 @@ import ( "hash" "sync" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/log" @@ -172,6 +174,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( // so that it get's executed _after_: the capturestate needs the stacks before // they are returned to the pools defer func() { + if err != nil { + in.evm.AppendErr(NewEVMErr(contract, op, err)) + } returnStack(stack) }() contract.Input = input @@ -235,6 +240,10 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( dynamicCost, err = operation.dynamicGas(in.evm, contract, stack, mem, memorySize) cost += dynamicCost // for tracing if err != nil || !contract.UseGas(dynamicCost) { + // capture expired state error + if _, ok := err.(*state.ExpiredStateError); ok { + break + } return nil, ErrOutOfGas } if memorySize > 0 { diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go index 551e1f5f11..3907f8b7b3 100644 --- a/core/vm/operations_acl.go +++ b/core/vm/operations_acl.go @@ -32,11 +32,14 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.peek() - slot = common.Hash(x.Bytes32()) - current = evm.StateDB.GetState(contract.Address(), slot) - cost = uint64(0) + y, x = stack.Back(1), stack.peek() + slot = common.Hash(x.Bytes32()) + cost = uint64(0) ) + current, err := evm.StateDB.GetState(contract.Address(), slot) + if err != nil { + return 0, err + } // Check slot presence in the access list if addrPresent, slotPresent := 
evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { cost = params.ColdSloadCostEIP2929 @@ -56,7 +59,10 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { // return params.SloadGasEIP2200, nil return cost + params.WarmStorageReadCostEIP2929, nil // SLOAD_GAS } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return cost + params.SstoreSetGasEIP2200, nil diff --git a/eth/api.go b/eth/api.go index f81dfa922b..b1bea41d87 100644 --- a/eth/api.go +++ b/eth/api.go @@ -291,7 +291,7 @@ func (api *PublicDebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error if block == nil { return state.Dump{}, fmt.Errorf("block #%d not found", blockNr) } - stateDb, err := api.eth.BlockChain().StateAt(block.Root()) + stateDb, err := api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.Dump{}, err } @@ -379,7 +379,7 @@ func (api *PublicDebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, sta if block == nil { return state.IteratorDump{}, fmt.Errorf("block #%d not found", number) } - stateDb, err = api.eth.BlockChain().StateAt(block.Root()) + stateDb, err = api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.IteratorDump{}, err } @@ -389,7 +389,7 @@ func (api *PublicDebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, sta if block == nil { return state.IteratorDump{}, fmt.Errorf("block %s not found", hash.Hex()) } - stateDb, err = api.eth.BlockChain().StateAt(block.Root()) + stateDb, err = api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.IteratorDump{}, err } diff --git a/eth/api_backend.go b/eth/api_backend.go index 7c0f1f0482..71728f2864 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -152,7 
+152,7 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B if header == nil { return nil, nil, errors.New("header not found") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Number) return stateDb, header, err } @@ -171,7 +171,7 @@ func (b *EthAPIBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockN if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, nil, errors.New("hash is not currently canonical") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Number) return stateDb, header, err } return nil, nil, errors.New("invalid arguments; neither block nor hash specified") diff --git a/eth/catalyst/api_test.go b/eth/catalyst/api_test.go index 7e931960e5..978b37de65 100644 --- a/eth/catalyst/api_test.go +++ b/eth/catalyst/api_test.go @@ -257,7 +257,7 @@ func TestEth2NewBlock(t *testing.T) { ethservice.BlockChain().SubscribeRemovedLogsEvent(rmLogsCh) for i := 0; i < 10; i++ { - statedb, _ := ethservice.BlockChain().StateAt(parent.Root()) + statedb, _ := ethservice.BlockChain().StateAt(parent.Root(), new(big.Int).Add(parent.Number(), common.Big1)) nonce := statedb.GetNonce(testAddr) tx, _ := types.SignTx(types.NewContractCreation(nonce, new(big.Int), 1000000, big.NewInt(2*params.InitialBaseFee), logCode), types.LatestSigner(ethservice.BlockChain().Config()), testKey) ethservice.TxPool().AddLocal(tx) @@ -420,7 +420,7 @@ func TestFullAPI(t *testing.T) { logCode = common.Hex2Bytes("60606040525b7f24ec1d3ff24c2f6ff210738839dbc339cd45a5294d85c79361016243157aae7b60405180905060405180910390a15b600a8060416000396000f360606040526008565b00") ) for i := 0; i < 10; i++ { - statedb, _ := ethservice.BlockChain().StateAt(parent.Root()) + statedb, _ := ethservice.BlockChain().StateAt(parent.Root(), new(big.Int).Add(parent.Number(), 
common.Big1)) nonce := statedb.GetNonce(testAddr) tx, _ := types.SignTx(types.NewContractCreation(nonce, new(big.Int), 1000000, big.NewInt(2*params.InitialBaseFee), logCode), types.LatestSigner(ethservice.BlockChain().Config()), testKey) ethservice.TxPool().AddLocal(tx) diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index 55e612b801..8dff04b198 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -467,10 +467,10 @@ func testGetNodeData(t *testing.T, protocol uint) { // Sanity check whether all state matches. accounts := []common.Address{testAddr, acc1Addr, acc2Addr} for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ { - root := backend.chain.GetBlockByNumber(i).Root() - reconstructed, _ := state.New(root, state.NewDatabase(reconstructDB), nil) + block := backend.chain.GetBlockByNumber(i) + reconstructed, _ := state.New(block.Root(), state.NewDatabase(reconstructDB), nil) for j, acc := range accounts { - state, _ := backend.chain.StateAt(root) + state, _ := backend.chain.StateAt(block.Root(), block.Number()) bw := state.GetBalance(acc) bh := reconstructed.GetBalance(acc) diff --git a/eth/state_accessor.go b/eth/state_accessor.go index b0b9f38f64..d6ca529589 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -56,7 +56,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state ) // Check the live database first if we have the state fully available, use that. if checkLive { - statedb, err = eth.blockchain.StateAt(block.Root()) + statedb, err = eth.blockchain.StateAt(block.Root(), block.Number()) if err == nil { return statedb, nil } @@ -66,7 +66,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // Create an ephemeral trie.Database for isolating the live one. Otherwise // the internal junks created by tracing will be persisted into the disk. 
database = state.NewDatabaseWithConfig(eth.chainDb, &trie.Config{Cache: 16}) - if statedb, err = state.New(block.Root(), database, nil); err == nil { + if statedb, err = state.NewWithEpoch(block.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), block.Number())); err == nil { log.Info("Found disk backend for state trie", "root", block.Root(), "number", block.Number()) return statedb, nil } @@ -89,7 +89,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // we would rewind past a persisted block (specific corner case is chain // tracing from the genesis). if !checkLive { - statedb, err = state.New(current.Root(), database, nil) + statedb, err = state.NewWithEpoch(current.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) if err == nil { return statedb, nil } @@ -105,7 +105,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state } current = parent - statedb, err = state.New(current.Root(), database, nil) + statedb, err = state.NewWithEpoch(current.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) if err == nil { break } @@ -148,7 +148,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state return nil, fmt.Errorf("stateAtBlock commit failed, number %d root %v: %w", current.NumberU64(), current.Root().Hex(), err) } - statedb, err = state.New(root, database, nil) + statedb, err = state.NewWithEpoch(root, database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) if err != nil { return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) } diff --git a/eth/tracers/api_test.go b/eth/tracers/api_test.go index c71ebfa6bc..0e36527234 100644 --- a/eth/tracers/api_test.go +++ b/eth/tracers/api_test.go @@ -141,7 +141,7 @@ func (b *testBackend) ChainDb() ethdb.Database { } func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, 
reexec uint64, base *state.StateDB, checkLive bool, preferDisk bool) (*state.StateDB, error) { - statedb, err := b.chain.StateAt(block.Root()) + statedb, err := b.chain.StateAt(block.Root(), block.Number()) if err != nil { return nil, errStateNotFound } @@ -153,7 +153,7 @@ func (b *testBackend) StateAtTransaction(ctx context.Context, block *types.Block if parent == nil { return nil, vm.BlockContext{}, nil, errBlockNotFound } - statedb, err := b.chain.StateAt(parent.Root()) + statedb, err := b.chain.StateAt(parent.Root(), block.Number()) if err != nil { return nil, vm.BlockContext{}, nil, errStateNotFound } diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index e91e222a67..348e0dc6b6 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -702,8 +702,12 @@ func (do *dbObj) GetState(addrSlice goja.Value, hashSlice goja.Value) goja.Value return nil } hash := common.BytesToHash(h) - state := do.db.GetState(addr, hash).Bytes() - res, err := do.toBuf(do.vm, state) + val, err := do.db.GetState(addr, hash) + if err != nil { + do.vm.Interrupt(err) + return nil + } + res, err := do.toBuf(do.vm, val.Bytes()) if err != nil { do.vm.Interrupt(err) return nil diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index aea44801d8..87f62db44a 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -188,10 +188,8 @@ func (l *StructLogger) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, s } // capture SLOAD opcodes and record the read entry in the local storage if op == vm.SLOAD && stackLen >= 1 { - var ( - address = common.Hash(stackData[stackLen-1].Bytes32()) - value = l.env.StateDB.GetState(contract.Address(), address) - ) + address := common.Hash(stackData[stackLen-1].Bytes32()) + value, _ := l.env.StateDB.GetState(contract.Address(), address) l.storage[contract.Address()][address] = value storage = l.storage[contract.Address()].Copy() } else if op == vm.SSTORE && stackLen >= 2 { diff --git 
a/eth/tracers/logger/logger_test.go b/eth/tracers/logger/logger_test.go index 6b1e740814..8c3bcfe85d 100644 --- a/eth/tracers/logger/logger_test.go +++ b/eth/tracers/logger/logger_test.go @@ -48,9 +48,11 @@ type dummyStatedb struct { state.StateDB } -func (*dummyStatedb) GetRefund() uint64 { return 1337 } -func (*dummyStatedb) GetState(_ common.Address, _ common.Hash) common.Hash { return common.Hash{} } -func (*dummyStatedb) SetState(_ common.Address, _ common.Hash, _ common.Hash) {} +func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetState(_ common.Address, _ common.Hash) (common.Hash, error) { + return common.Hash{}, nil +} +func (*dummyStatedb) SetState(_ common.Address, _ common.Hash, _ common.Hash) error { return nil } func TestStoreCapture(t *testing.T) { var ( diff --git a/eth/tracers/native/prestate.go b/eth/tracers/native/prestate.go index b513f383b9..532c6de6dd 100644 --- a/eth/tracers/native/prestate.go +++ b/eth/tracers/native/prestate.go @@ -174,5 +174,10 @@ func (t *prestateTracer) lookupStorage(addr common.Address, key common.Hash) { if _, ok := t.prestate[addr].Storage[key]; ok { return } - t.prestate[addr].Storage[key] = t.env.StateDB.GetState(addr, key) + // TODO when trace meet state err, just pass + val, err := t.env.StateDB.GetState(addr, key) + if err != nil { + return + } + t.prestate[addr].Storage[key] = val } diff --git a/graphql/graphql.go b/graphql/graphql.go index cc3d36235d..511a18bdd4 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -130,7 +130,7 @@ func (a *Account) Storage(ctx context.Context, args struct{ Slot common.Hash }) if err != nil { return common.Hash{}, err } - return state.GetState(a.address, args.Slot), nil + return state.GetState(a.address, args.Slot) } // Log represents an individual log message. All arguments are mandatory. 
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 1c12fcdabd..8c63b94442 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -686,7 +686,11 @@ func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Addre if storageError != nil { return nil, storageError } - storageProof[i] = StorageResult{key, (*hexutil.Big)(state.GetState(address, common.HexToHash(key)).Big()), toHexSlice(proof)} + val, stateErr := state.GetState(address, common.HexToHash(key)) + if stateErr != nil { + return nil, stateErr + } + storageProof[i] = StorageResult{key, (*hexutil.Big)(val.Big()), toHexSlice(proof)} } else { storageProof[i] = StorageResult{key, &hexutil.Big{}, []string{}} } @@ -839,7 +843,10 @@ func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.A if state == nil || err != nil { return nil, err } - res := state.GetState(address, common.HexToHash(key)) + res, err := state.GetState(address, common.HexToHash(key)) + if err != nil { + return nil, err + } return res[:], state.Error() } @@ -1197,11 +1204,11 @@ func (s *PublicBlockChainAPI) needToReplay(ctx context.Context, block *types.Blo if err != nil { return false, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - parentState, err := s.b.Chain().StateAt(parent.Root()) + parentState, err := s.b.Chain().StateAt(parent.Root(), parent.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64()-1, err) } - currentState, err := s.b.Chain().StateAt(block.Root()) + currentState, err := s.b.Chain().StateAt(block.Root(), block.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64(), err) } @@ -1227,7 +1234,7 @@ func (s *PublicBlockChainAPI) replay(ctx context.Context, block *types.Block, ac if err != nil { return nil, nil, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - statedb, err 
:= s.b.Chain().StateAt(parent.Root()) + statedb, err := s.b.Chain().StateAt(parent.Root(), block.Number()) if err != nil { return nil, nil, fmt.Errorf("state not found for block number (%d): %v", block.NumberU64()-1, err) } diff --git a/les/api_backend.go b/les/api_backend.go index 0e02a03050..1ce9140bac 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -141,7 +141,7 @@ func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B if header == nil { return nil, nil, errors.New("header not found") } - return light.NewState(ctx, header, b.eth.odr), header, nil + return light.NewState(ctx, b.ChainConfig(), header, b.eth.odr), header, nil } func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { @@ -156,7 +156,7 @@ func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockN if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, nil, errors.New("hash is not currently canonical") } - return light.NewState(ctx, header, b.eth.odr), header, nil + return light.NewState(ctx, b.ChainConfig(), header, b.eth.odr), header, nil } return nil, nil, errors.New("invalid arguments; neither block nor hash specified") } diff --git a/les/odr_test.go b/les/odr_test.go index b77bb30d9b..18624d8280 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -100,7 +100,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon st, err = state.New(header.Root, state.NewDatabase(db), nil) } else { header := lc.GetHeaderByHash(bhash) - st = light.NewState(ctx, header, lc.Odr()) + st = light.NewState(ctx, config, header, lc.Odr()) } if err == nil { bal := st.GetBalance(addr) @@ -148,7 +148,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai } } else { header := lc.GetHeaderByHash(bhash) - state := light.NewState(ctx, header, lc.Odr()) + 
state := light.NewState(ctx, config, header, lc.Odr()) state.SetBalance(bankAddr, math.MaxBig256) msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} context := core.NewEVMBlockContext(header, lc, nil) diff --git a/les/state_accessor.go b/les/state_accessor.go index 112e6fd44d..b28ba33add 100644 --- a/les/state_accessor.go +++ b/les/state_accessor.go @@ -30,7 +30,7 @@ import ( // stateAtBlock retrieves the state database associated with a certain block. func (leth *LightEthereum) stateAtBlock(ctx context.Context, block *types.Block, reexec uint64) (*state.StateDB, error) { - return light.NewState(ctx, block.Header(), leth.odr), nil + return light.NewState(ctx, leth.chainConfig, block.Header(), leth.odr), nil } // stateAtTransaction returns the execution environment of a certain transaction. diff --git a/light/odr_test.go b/light/odr_test.go index 1a91662484..0a69cb642a 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -146,7 +146,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc var st *state.StateDB if bc == nil { header := lc.GetHeaderByHash(bhash) - st = NewState(ctx, header, lc.Odr()) + st = NewState(ctx, lc.Config(), header, lc.Odr()) } else { header := bc.GetHeaderByHash(bhash) st, _ = state.New(header.Root, state.NewDatabase(db), nil) @@ -185,7 +185,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if bc == nil { chain = lc header = lc.GetHeaderByHash(bhash) - st = NewState(ctx, header, lc.Odr()) + st = NewState(ctx, lc.Config(), header, lc.Odr()) } else { chain = bc header = bc.GetHeaderByHash(bhash) diff --git a/light/trie.go b/light/trie.go index f7962e6ad6..c9379beb1b 100644 --- a/light/trie.go +++ b/light/trie.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" + "github.com/ethereum/go-ethereum/params" + 
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" @@ -35,8 +37,8 @@ var ( sha3Nil = crypto.Keccak256Hash(nil) ) -func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { - state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) +func NewState(ctx context.Context, config *params.ChainConfig, head *types.Header, odr OdrBackend) *state.StateDB { + state, _ := state.NewWithEpoch(head.Root, NewStateDatabase(ctx, head, odr), nil, types.GetStateEpoch(config, head.Number)) return state } diff --git a/light/txpool.go b/light/txpool.go index 1b819794c1..b37870abd1 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -117,7 +117,7 @@ func NewTxPool(config *params.ChainConfig, chain *LightChain, relay TxRelayBacke // currentState returns the light state of the current head header func (pool *TxPool) currentState(ctx context.Context) *state.StateDB { - return NewState(ctx, pool.chain.CurrentHeader(), pool.odr) + return NewState(ctx, pool.config, pool.chain.CurrentHeader(), pool.odr) } // GetNonce returns the "pending" nonce of a given address. 
It always queries diff --git a/miner/miner.go b/miner/miner.go index 7f0f1583e8..9bd7a50c0b 100644 --- a/miner/miner.go +++ b/miner/miner.go @@ -201,7 +201,7 @@ func (miner *Miner) Pending() (*types.Block, *state.StateDB) { if block == nil { return nil, nil } - stateDb, err := miner.worker.chain.StateAt(block.Root()) + stateDb, err := miner.worker.chain.StateAt(block.Root(), block.Number()) if err != nil { return nil, nil } diff --git a/miner/miner_test.go b/miner/miner_test.go index cf619845dd..028b345831 100644 --- a/miner/miner_test.go +++ b/miner/miner_test.go @@ -19,6 +19,7 @@ package miner import ( "errors" + "math/big" "testing" "time" @@ -75,7 +76,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return bc.CurrentBlock() } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/miner/worker.go b/miner/worker.go index 9f4cce2988..e8a67514e6 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -698,7 +698,8 @@ func (w *worker) makeEnv(parent *types.Block, header *types.Header, coinbase com prevEnv *environment) (*environment, error) { // Retrieve the parent state to execute on top and start a prefetcher for // the miner to speed block sealing up a bit - state, err := w.chain.StateAtWithSharedPool(parent.Root()) + //state, err := w.chain.StateAtWithSharedPool(parent.Root()) + state, err := w.chain.StateAt(parent.Root(), header.Number) if err != nil { return nil, err } diff --git a/trie/errors.go b/trie/errors.go index 57d281067d..c0121dee6a 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -19,6 +19,8 @@ package trie import ( "fmt" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" ) @@ -37,7 +39,7 @@ func (err *MissingNodeError) Error() string { type ExpiredNodeError struct { ExpiredNode node // node of the expired node Path []byte // 
hex-encoded path to the expired node - Epoch uint16 + Epoch types.StateEpoch } func (err *ExpiredNodeError) Error() string { From 1cf57da77cd052f391b820aa1ae8a78c1051130c Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Sun, 23 Apr 2023 11:12:44 +0800 Subject: [PATCH 28/51] trie: add proofNub resolve kv method, opt split partial witness rule; stateObject: opt revive state query; --- core/state/database.go | 3 + core/state/errors.go | 8 +-- core/state/journal.go | 2 +- core/state/state_object.go | 57 ++++++++++++---- core/vm/instructions.go | 6 +- eth/backend.go | 4 +- light/trie.go | 4 ++ trie/dummy_trie.go | 4 ++ trie/proof.go | 136 ++++++++++++++++++++++++++++++------- trie/proof_test.go | 41 ++++++++++- trie/secure_trie.go | 14 ++-- trie/trie.go | 2 +- 12 files changed, 225 insertions(+), 56 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 46da90c954..7df0071670 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -110,6 +110,9 @@ type Trie interface { // can be used even if the trie doesn't have one. Hash() common.Hash + // HashKey return trie key hash result + HashKey(key []byte) []byte + // Commit writes all nodes to the trie's memory database, tracking the internal // and external (for account tries) references. 
Commit(onleaf trie.LeafCallback) (common.Hash, int, error) diff --git a/core/state/errors.go b/core/state/errors.go index 66f0f1dbd9..ac7649c5c3 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -25,20 +25,20 @@ func NewPlainExpiredStateError(addr common.Address, key common.Hash, epoch types } } -func NewExpiredStateError(addr common.Address, err *trie.ExpiredNodeError) *ExpiredStateError { +func NewExpiredStateError(addr common.Address, key common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { return &ExpiredStateError{ Addr: addr, - Key: common.Hash{}, + Key: key, Path: err.Path, Epoch: err.Epoch, isInsert: false, } } -func NewInsertExpiredStateError(addr common.Address, err *trie.ExpiredNodeError) *ExpiredStateError { +func NewInsertExpiredStateError(addr common.Address, key common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { return &ExpiredStateError{ Addr: addr, - Key: common.Hash{}, + Key: key, Path: err.Path, Epoch: err.Epoch, isInsert: true, diff --git a/core/state/journal.go b/core/state/journal.go index 10b00a9228..bf594585b2 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -281,7 +281,7 @@ func (ch accessListAddSlotChange) dirtied() *common.Address { } func (ch reviveStorageTrieNodeChange) revert(s *StateDB) { - s.getStateObject(*ch.address).dirtyReviveState = make(Storage) + s.getStateObject(*ch.address).dirtyReviveState = make(map[string]common.Hash) s.getStateObject(*ch.address).dirtyReviveTrie = nil } diff --git a/core/state/state_object.go b/core/state/state_object.go index 7305e705bc..26762b31f8 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -99,8 +99,8 @@ type StateObject struct { dirtyReviveTrie Trie // dirtyReviveTrie for tx // when R&W, access revive state first - pendingReviveState Storage // pendingReviveState for block, it cannot flush to trie, just cache - dirtyReviveState Storage // dirtyReviveState for tx, for cache dirtyReviveTrie + pendingReviveState 
map[string]common.Hash // pendingReviveState for block, it cannot flush to trie, just cache + dirtyReviveState map[string]common.Hash // dirtyReviveState for tx, for cache dirtyReviveTrie // accessed state, don't record revive state, don't record nonexist state or any err access pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed, it will update epoch index late @@ -149,8 +149,8 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *St originStorage: make(Storage), pendingStorage: make(Storage), dirtyStorage: make(Storage), - dirtyReviveState: make(Storage), - pendingReviveState: make(Storage), + dirtyReviveState: make(map[string]common.Hash), + pendingReviveState: make(map[string]common.Hash), dirtyAccessedState: make(map[common.Hash]int), pendingAccessedState: make(map[common.Hash]int), Epoch: db.Epoch, @@ -231,7 +231,7 @@ func (s *StateObject) GetState(db Database, key common.Hash) (common.Hash, error s.accessState(key) return value, nil } - if revived, revive := s.dirtyReviveState[key]; revive { + if revived, revive := s.queryFromReviveState(db, s.dirtyReviveState, key); revive { s.accessState(key) return revived, nil } @@ -278,7 +278,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha if value, pending := s.pendingStorage[key]; pending { return value, nil } - if revived, revive := s.pendingReviveState[key]; revive { + if revived, revive := s.queryFromReviveState(db, s.pendingReviveState, key); revive { return revived, nil } @@ -322,7 +322,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha } if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { - return common.Hash{}, NewExpiredStateError(s.address, enErr) + return common.Hash{}, NewExpiredStateError(s.address, key, enErr) } s.setError(err) return common.Hash{}, nil @@ -364,7 +364,7 @@ func (s *StateObject) SetState(db Database, key, value common.Hash) error { 
_, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { - return NewInsertExpiredStateError(s.address, enErr) + return NewInsertExpiredStateError(s.address, key, enErr) } s.setError(err) return nil @@ -429,7 +429,7 @@ func (s *StateObject) finalise(prefetch bool) { s.dirtyStorage = make(Storage) } if len(s.dirtyReviveState) > 0 { - s.dirtyReviveState = make(Storage) + s.dirtyReviveState = make(map[string]common.Hash) } if len(s.dirtyAccessedState) > 0 { s.dirtyAccessedState = make(map[common.Hash]int) @@ -530,7 +530,7 @@ func (s *StateObject) updateTrie(db Database) Trie { s.pendingStorage = make(Storage) } if len(s.pendingReviveState) > 0 { - s.pendingReviveState = make(Storage) + s.pendingReviveState = make(map[string]common.Hash) } if len(s.pendingAccessedState) > 0 { s.pendingAccessedState = make(map[common.Hash]int) @@ -651,8 +651,14 @@ func (s *StateObject) deepCopy(db *StateDB) *StateObject { if s.pendingReviveTrie != nil { stateObject.pendingReviveTrie = db.db.CopyTrie(s.pendingReviveTrie) } - stateObject.dirtyReviveState = s.dirtyReviveState.Copy() - stateObject.pendingReviveState = s.pendingReviveState.Copy() + stateObject.dirtyReviveState = make(map[string]common.Hash, len(s.dirtyReviveState)) + for k, v := range s.dirtyReviveState { + stateObject.dirtyReviveState[k] = v + } + stateObject.pendingReviveState = make(map[string]common.Hash, len(s.pendingReviveState)) + for k, v := range s.pendingReviveState { + stateObject.pendingReviveState[k] = v + } stateObject.dirtyAccessedState = make(map[common.Hash]int, len(s.dirtyAccessedState)) for k, v := range s.dirtyAccessedState { stateObject.dirtyAccessedState[k] = v @@ -755,11 +761,27 @@ func (s *StateObject) Value() *big.Int { func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { dr := s.getDirtyReviveTrie(s.db.db) - // TODO(0xbundler): revive nub and cache revive state - dr.ReviveTrie(proofCache.CacheNubs()) 
s.db.journal.append(reviveStorageTrieNodeChange{ address: &s.address, }) + // revive nub and cache revive state + for _, nub := range dr.ReviveTrie(proofCache.CacheNubs()) { + kv, err := nub.ResolveKV() + if err != nil { + return err + } + for k, enc := range kv { + var value common.Hash + if len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + return err + } + value.SetBytes(content) + } + s.dirtyReviveState[k] = value + } + } return nil } @@ -771,3 +793,10 @@ func (s *StateObject) accessState(key common.Hash) { count := s.dirtyAccessedState[key] s.dirtyAccessedState[key] = count + 1 } + +// TODO(0xbundler): add hash key cache later +func (s *StateObject) queryFromReviveState(db Database, reviveState map[string]common.Hash, key common.Hash) (common.Hash, bool) { + hashKey := string(s.getTrie(db).HashKey(key.Bytes())) + val, ok := reviveState[hashKey] + return val, ok +} diff --git a/core/vm/instructions.go b/core/vm/instructions.go index a05e2ace09..e2bf1f0825 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -537,8 +537,10 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b } loc := scope.Stack.pop() val := scope.Stack.pop() - interpreter.evm.StateDB.SetState(scope.Contract.Address(), - loc.Bytes32(), val.Bytes32()) + if err := interpreter.evm.StateDB.SetState(scope.Contract.Address(), + loc.Bytes32(), val.Bytes32()); err != nil { + return nil, err + } return nil, nil } diff --git a/eth/backend.go b/eth/backend.go index fc0ca6534c..f7fe65e285 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -216,7 +216,9 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { bcOps = append(bcOps, core.EnableLightProcessor) } if config.PipeCommit { - bcOps = append(bcOps, core.EnablePipelineCommit) + // TODO(0xbundler): may support pipeCommit in state expiry later + log.Info("temporary not support pipeCommit in state expiry") + //bcOps = append(bcOps, core.EnablePipelineCommit) } if 
config.PersistDiff { bcOps = append(bcOps, core.EnablePersistDiff(config.DiffBlock)) diff --git a/light/trie.go b/light/trie.go index c9379beb1b..56f0451a9b 100644 --- a/light/trie.go +++ b/light/trie.go @@ -175,6 +175,10 @@ func (t *odrTrie) GetKey(sha []byte) []byte { return nil } +func (t *odrTrie) HashKey(key []byte) []byte { + return nil +} + func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { return errors.New("not implemented, needs client/server interface split") } diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index ac021ed643..9b8a29f888 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -74,6 +74,10 @@ func (t *EmptyTrie) GetKey(shaKey []byte) []byte { return nil } +func (t *EmptyTrie) HashKey(key []byte) []byte { + return nil +} + func (t *EmptyTrie) Commit(onleaf LeafCallback) (root common.Hash, committed int, err error) { return common.Hash{}, 0, nil } diff --git a/trie/proof.go b/trie/proof.go index 357f96a0a7..cec1564ab9 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -164,9 +164,44 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) // MPTProofNub include fullNode shortNode, revive n1 first if exist, // revive n2 later if exist, include node hash type MPTProofNub struct { - RootHexKey []byte // root hex key, max 64bytes - n1 node - n2 node + n1PrefixKey []byte // n1's prefix hex key, max 64bytes + n1 node + n2PrefixKey []byte // n2's prefix hex key, max 64bytes + n2 node +} + +// ResolveKV revive state could revive KV from fullNode[0-15] or fullNode[16] or shortNode.Val, return KVs for cache & snap +func (m *MPTProofNub) ResolveKV() (map[string][]byte, error) { + kvMap := make(map[string][]byte) + if err := resolveKV(m.n1, m.n1PrefixKey, kvMap); err != nil { + return nil, err + } + if err := resolveKV(m.n2, m.n2PrefixKey, kvMap); err != nil { + return nil, err + } + + return kvMap, nil +} + +func resolveKV(origin node, prefixKey []byte, kvWriter 
map[string][]byte) error { + switch n := origin.(type) { + case nil, hashNode: + return nil + case valueNode: + kvWriter[string(hexToKeybytes(prefixKey))] = n + return nil + case *shortNode: + return resolveKV(n.Val, append(prefixKey, n.Key...), kvWriter) + case *fullNode: + for i := 0; i < BranchNodeLength-1; i++ { + if err := resolveKV(n.Children[i], append(prefixKey, byte(i)), kvWriter); err != nil { + return err + } + } + return resolveKV(n.Children[BranchNodeLength-1], prefixKey, kvWriter) + default: + panic(fmt.Sprintf("invalid node: %v", origin)) + } } type MPTProofCache struct { @@ -181,10 +216,8 @@ type MPTProofCache struct { // VerifyProof verify proof in MPT witness // 1. calculate hash // 2. decode trie node -// 3. verify partial merkle proof of the witness, TODO match algo will check inner mem node scene, until meet hash node or value node or nil? -// 4. split to partial witness, TODO check if satisfy partial witness rules? -// TODO later revive state could revive KV from fullNode[0-15] or fullNode[16] shortNode.Val, return KVs for cache & snap -// another easy method is that revive direct to Trie/Trie cache, query later and set to KV cache +// 3. verify partial merkle proof of the witness +// 4. split to partial witness func (m *MPTProofCache) VerifyProof() error { m.cacheHashes = make([][]byte, len(m.Proof)) m.cacheNodes = make([]node, len(m.Proof)) @@ -231,49 +264,102 @@ func (m *MPTProofCache) VerifyProof() error { prefix := m.RootKeyHex for i := 0; i < len(m.cacheNodes); i++ { if i-1 >= 0 { - prefix = append(prefix, m.cacheHexPath[i-1]...) + prefix = copyNewSlice(prefix, m.cacheHexPath[i-1]) } // prefix = append(prefix, m.cacheHexPath[i]...) 
n1 := m.cacheNodes[i] nub := MPTProofNub{ - RootHexKey: prefix, - n1: n1, - n2: nil, + n1PrefixKey: prefix, + n1: n1, + n2: nil, + n2PrefixKey: nil, } - if needMergeNextNode(m.cacheNodes, i) { + + // check if satisfy partial witness rules, + // that short node must with its child, may full node or valueNode + merge, err := mergeNextNode(m.cacheNodes, i) + if err != nil { + return err + } + if merge { i++ - prefix = append(prefix, m.cacheHexPath[i-1]...) + prefix = copyNewSlice(prefix, m.cacheHexPath[i-1]) nub.n2 = m.cacheNodes[i] + nub.n2PrefixKey = prefix } - // TODO check short node must with child in same nub m.cacheNubs = append(m.cacheNubs, &nub) } return nil } +func copyNewSlice(s1, s2 []byte) []byte { + ret := make([]byte, len(s1)+len(s2)) + copy(ret, s1) + copy(ret[len(s1):], s2) + return ret +} + func (m *MPTProofCache) CacheNubs() []*MPTProofNub { return m.cacheNubs } -func needMergeNextNode(nodes []node, i int) bool { - if i >= len(nodes) || i+1 >= len(nodes) { - return false +// mergeNextNode check short node must with child in same nub +func mergeNextNode(nodes []node, i int) (bool, error) { + if i >= len(nodes) { + return false, errors.New("mergeNextNode input outbound index") } n1 := nodes[i] - n2 := nodes[i+1] - - if n2.nodeType() == valueNodeType { - return true + switch n := n1.(type) { + case *shortNode: + need, err := needNextProofNode(n, n.Val) + if err != nil { + return false, err + } + if need && i+1 >= len(nodes) { + return false, errors.New("mergeNextNode short node must with its child") + } + return need, nil + case valueNode: + return false, errors.New("mergeNextNode value node need merge with prev node") } - // check extended node - if n1.nodeType() == shortNodeType { - return true + if i+1 >= len(nodes) { + return false, nil } + return nodes[i+1].nodeType() == valueNodeType, nil +} - return false +// needNextProofNode check if node need merge next node into a proofNub, because TrieExtendNode must with its child to revive together +func 
needNextProofNode(parent, origin node) (bool, error) { + switch n := origin.(type) { + case *fullNode: + for i := 0; i < BranchNodeLength-1; i++ { + need, err := needNextProofNode(n, n.Children[i]) + if err != nil { + return false, err + } + if need { + return true, nil + } + } + return false, nil + case *shortNode: + if parent.nodeType() == shortNodeType { + return false, errors.New("needNextProofNode cannot short node's child is short node") + } + return needNextProofNode(n, n.Val) + case valueNode: + return false, nil + case hashNode: + if parent.nodeType() == fullNodeType { + return false, nil + } + return true, nil + default: + return false, errors.New("needNextProofNode unsupported node") + } } func matchHashNodeInFullNode(child []byte, n *fullNode) (int, error) { diff --git a/trie/proof_test.go b/trie/proof_test.go index fcdcb2a89d..10eeb5cb32 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -20,12 +20,15 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "encoding/hex" mrand "math/rand" "sort" "strings" "testing" "time" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/core/types" "github.com/stretchr/testify/assert" @@ -1415,10 +1418,46 @@ func TestMPTProofCache_VerifyProof_normalCase(t *testing.T) { h.sha.Write(key) h.sha.Read(hash) ln := cache.cacheNubs[6].n1.(*shortNode) - hexKey := append(cache.cacheNubs[6].RootHexKey, ln.Key...) + hexKey := append(cache.cacheNubs[6].n1PrefixKey, ln.Key...) 
assert.Equal(t, hash, hexToKeybytes(hexKey)) } +func TestMPTProofCache_ResolveKV_normalCase(t *testing.T) { + cache := makeMPTProofCache(nil, []string{ + "0xf90211a03697534056039e03300557bd69fe16e18ce4a6ccd5522db4dfa97dfe1fad3d3aa0b1bf1f230b98b9034738d599177ae817c08143b9395a47f300636b0dd2fb3c5ea0aa04a4966751d4c50063fe13a96a6c7924f665819733f556849b5eb9fa1d6839a0e162e080d1c12c59dc984fb2246d8ad61209264bee40d3fdd07c4ea4ff411b6aa0e5c3f2dde71bf303423f34674748567dcdf8379129653b8213f698468738d492a068a3e3059b6e7115a055a7874f81c5a1e84ddc1967527973f8c78cd86a1c9f8fa0d734bd63b7be8e8471091b792f5bbcbc7b0ce582f6d985b7a15a3c0155242c56a00143c06f57a65c8485dbae750aa51df5dff1bf7bdf28060129a20de9e51364eda07b416f79b3f4e39d0159efff351009d44002d9e83530fb5a5778eb55f5f4432ca036706b52196fa0b73feb2e7ff8f1379c7176d427dd44ad63c7b65e66693904a1a0fd6c8b815e2769ce379a20eaccdba1f145fb11f77c280553f15ee4f1ee135375a02f5233009f082177e5ed2bfa6e180bf1a7310e6bc3c079cb85a4ac6fee4ae379a03f07f1bb33fa26ebd772fa874914dc7a08581095e5159fdcf9221be6cbeb6648a097557eec1ac08c3bfe45ce8e34cd329164a33928ac83fef1009656536ef6907fa028196bfb31aa7f14a0a8000b00b0aa5d09450c32d537e45eebee70b14313ff1ca0126ce265ca7bbb0e0b01f068d1edef1544cbeb2f048c99829713c18d7abc049a80", + 
"0xf90211a01f7c019858f447dbff8ed8e4329e88600bab8f17fec9594c664e25acc95da3dba00916866833439bc250b5e08edc2b7c041634ca3b33a013f79eb299a3e33a056fa0a8fae921061f3bc81154b5d2c149d3860a3cdef00d02173ee3b5837de6b26f56a097583621a54e74994619a97cf82f823005a35ad1bf4795047726619eccd11ad4a0be39789c4abdb2a185cce40f2c77575eee2a1eab38d3168395560da15307614aa0c46a7fcea5501656d70508178f0731460edcce1c01e32ed08c1468f2593db277a0460f030038b09b8461d834ded79d77fd3ed25c4e248775752bf1c830a530ef2ba03d887abec623c6b4d93be75d1608dcacb291bf30406fbbc944a94aa4203fb1eba0475eeb07d471044af74313093cfbb0201e17a405d7af13a7ae6b245e18915515a0ea794035230f90e14f4f601d0e9f04217ed02710df363417bc3dd7dda5a84608a08aa9f0c44e5a9e359b65d4f0e40937662f07af10642a626e19ca7bc56c329706a05677c9b342c1cbcfd491cd491800d44423d1bdc08ab00eb59376a7118f7c4e23a0b761f22d67328e6c90caf8be65affb701b045f4f1f581472fc5f724e0c61328da0b6727240643009e59bba78249918a83ca8faeaf1d5b47fe0b41c356e8ae300dda0d7cbeb12faf439126bdb86f94b6262c3a70e88e4d8a47b49735f0b9d632a5df8a0428fe930556ce5ea94bcc4d092fe0f05a2a9b360175857711729ab3a7092a81780", + 
"0xf90211a0a81ff74945250ba9925753e379aa8815a3fe77926aad2d02db4d78a3da7ecb48a0e795c64d2b738b34ebb77a3307df800c9f1fe324b442cd71ffa5fd268ec12ffca020a5f968a1c8292d08135cb451daed999a44eb0cdd04526a89c38e753398c50ca0cc7165f6b984f2f2569d101d70f72af94eb79b18c126895da2d3cf557b5aca51a00a10f4dc5851e71f195cca5bd7a2a62a6c3ca03d95e91c7dc55e55f4cb726903a09ecaaf877b18a55ca4001fe46cc10389c94104abc2994cdfe5843f28814b119ca099f59b2b52ee9d9b44ab7123a86775dcfe6c50301dd7cf9c6fc6ea968c1d2a01a0fccda5c1489dd3268fcf2471f6a372fc3afe5eebecc0ba9fc3c023d85da26aeda0d730ea7aabdad2e5451e826726e53c86e6e805e46220bc7edd80bf3f7e467f96a074f3d767a84557aa9559b08b2d1f93d8205827c042d1ac616b8cf37e22de2beca03ca1fe479ea1e64c3bfeae9c603ef4f55c270de0bb4c79d045f67d3481ca2852a0bdcbc3d154db40b5faa1f6875f46d485b96bdfffb4513da631f9b302768faae7a096923f4f559c7b7f912587292cc865aba5934e9e75c9aa88f20de73e928c6c51a0f52ca9710977327dd407ba9d9f00b0075d7e312ee19686f1b53579585ff50132a076adb4cf98f9af261d1b2a147b9b2cbace1647eab537ca1a55e8d40d35775b74a0723fa2f72d6a983939bc9bf9ee22e05fc142437726450e5707c60155bc34ec0580", + 
"0xf90211a039a64fd9b31f3ea3bf9a991ca828eadb67cf9ae0ebbcdd195d297454699580e8a0c95d85a63beb9a02b56d032116d86fdf9a64065dd5d44e33acef674ae3ceb6f9a0d83ef07a99302abccdce1289d353a20494713f45d8edbd3c2e87f08788a878d9a063a19aee41fec40a98edcf4d60c759c509b512bfc5d9feae7de50c8c2d00eee7a029ce5c8c1ac9939cbf481274b8e6f5f24a430136dc5aab7deb489a9ce7db5a95a03b45c53d2e4f54e49a53eb298aee828f4803581ee60ca52940533b77c3e3fadfa082b8084dfc0337a49fd5d53ce107fa0bdcdf25ac7bbac017d7ac250b77c8764da0d4dab90004fd3bb36b2fa5f6e914a7c90820d119e5ba3ba8439270c6cfbd0a18a0bd9286c9248ae8a953d00aa9906f06b6f0364d0bb9fb615de04c9f8d5b5fd346a00955eccaa41a17fafb0ed66272b5183b3a30a973af7a43e0f25f0b640fd5df0ca0a7ba773602ace05991211770fc5555dd9c55e2518bcb005d1db022584a89132da0450b066588b44701992ab4a926d5dd185dd465abf1e51a3f60ab4a9f924cf85aa05eda0687636512339d0db99eade61c0d44358bb8027185296ad59b4cedd2935aa0cd18604dee296e5443b598e470a04efde04fbb0213c0fda8cc3af20dae2a1c34a01c0474c1bf15e4732d1c22a1303573c7ec8643f711354e853ab26038b2eebe25a081a0b2d329a9519375f71384a7130faafc3a1379db59bfb7db7135c9d27d9cc080", + 
"0xf901f1a0e14d9caa85464966f0371f427250e7bf2d86f6d41535b08ea79044391d6f0fe7a08bea233c92e4c1d05bd03d62cd93974cc29f6df54e6fa2574bf8dfb3af936a85a0dd4f4af9ab72d1bbbf59c286a731614f48bf3cdcef572cd819426fc2a30ae5d9a058465ea8e97b2f3a873646f9504aff111ab967abaea8ebb2fe91bf058ea162f3a06466c41eb770bae5c07c26cd692540e7f9af70fee2bc164e12f29083ebd2cca1a0fe0925da033ffb967ca1e1d6505c2ca740553916c3b9a205131d719b7921fb00a0973ecee958d1305f1b6f8159a9732e47f5fbd4121f60254ea44a8f028308150780a0f81717c7f8702ed39d89a34fc86827aab31db26b22d22ee399667e4a44081c95a0e276ec24ee74ee61bdbe92e8afa57813d59f26e0ad34f24d71f0627719f8a11ea00a8ec2f6922480890f9e67c62c1654dcccf1710a01a378f614e02a3785d42d7ea08c43cc0c6c690dd87cf42691e9ead799c5ddbcfc9458ebd054a46acecedbe9f7a0d3727ed077c60ed6ff1d9d5d20fdf2912a513942d0efee9ec2d97d33ab9b7f56a03f293b0c0f25b9aa2735c4e42b132008d8af2ecfaaefdc940dc94cf312f5a1dda04860bc5f19829bb277554927c5b6dbc0af93025aca49aed3be020a3080996a26a0e47475bf4f62364d8a0dd87f2c21d0b64ef4d9aaf5fc46d9b1e3af5b83f56d4580", + "0xf89180a0dd3524693059f48ce47c5dd82d0462953a5141c7abc5a63981637b71e0100bcf80a033ca98826ea65c1c4be16e1df12e78142d992bf1fc0189401632ceff3fcbbe7880a0e5daf803b0890d164e287fa43b51332286cddeb9b62280426037fc02dc2b4d5f8080a0bfa907c2e30720a07347d23cfa573dec8c9b3afa51a4a0e0783a481da0fcb1b38080808080808080", + "0xe79e208246cec5810061f4ff7efe1dcd6cb407d59abc3478830df04484584c868786d647b234389e", + }) + + err := cache.VerifyProof() + assert.NoError(t, err) + key := common.Hex2Bytes("95eea00c49d14a895954837cd876ffa8cfad96cbaacc40fc31d6df2c902528a8") + hash := make([]byte, common.HashLength) + h := newHasher(false) + h.sha.Reset() + h.sha.Write(key) + h.sha.Read(hash) + t.Log("expect:", hex.EncodeToString(hash)) + for i, nub := range cache.cacheNubs { + kvMap, err := nub.ResolveKV() + assert.NoError(t, err) + if i != 6 { + assert.Equal(t, 0, len(kvMap)) + continue + } + for k, v := range kvMap { + t.Log("k:", hex.EncodeToString([]byte(k)), "v:", hex.EncodeToString(v)) + } + enc := 
kvMap[string(hash)] + _, content, _, _ := rlp.Split(enc) + assert.Equal(t, common.Hex2Bytes("d647b234389e"), content) + } +} + func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { proof := make([][]byte, len(proofs)) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index fa74b8b15c..d6fc697805 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -77,7 +77,7 @@ func (t *SecureTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(t.hashKey(key)) + return t.trie.TryGet(t.HashKey(key)) } // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not @@ -89,7 +89,7 @@ func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) { // TryUpdate account will abstract the write of an account to the // secure trie. func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { - hk := t.hashKey(key) + hk := t.HashKey(key) data, err := rlp.EncodeToBytes(acc) if err != nil { return err @@ -122,7 +122,7 @@ func (t *SecureTrie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryUpdate(key, value []byte) error { - hk := t.hashKey(key) + hk := t.HashKey(key) err := t.trie.TryUpdate(hk, value) if err != nil { return err @@ -141,7 +141,7 @@ func (t *SecureTrie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. 
func (t *SecureTrie) TryDelete(key []byte) error { - hk := t.hashKey(key) + hk := t.HashKey(key) delete(t.getSecKeyCache(), string(hk)) return t.trie.TryDelete(hk) } @@ -205,10 +205,10 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { return t.trie.NodeIterator(start) } -// hashKey returns the hash of key as an ephemeral buffer. +// HashKey returns the hash of key as an ephemeral buffer. // The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -func (t *SecureTrie) hashKey(key []byte) []byte { +// invalid on the next call to HashKey or secKey. +func (t *SecureTrie) HashKey(key []byte) []byte { hash := make([]byte, common.HashLength) h := newHasher(false) h.sha.Reset() diff --git a/trie/trie.go b/trie/trie.go index dfe2f0d660..11ab683e0b 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -680,7 +680,7 @@ func (t *Trie) ReviveTrie(proof []*MPTProofNub) (successNubs []*MPTProofNub) { func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err error) { for _, nub := range proof { - newNode, didResolve, err := t.tryRevive(t.root, nub.RootHexKey, *nub) + newNode, didResolve, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub) if didResolve && err == nil { successNubs = append(successNubs, nub) t.root = newNode From 9c2297d357311747a4301eda0db7573ac2838739 Mon Sep 17 00:00:00 2001 From: cryyl <1226241521@qq.com> Date: Tue, 18 Apr 2023 16:58:00 +0800 Subject: [PATCH 29/51] state expriy: implement of MPT read and write Signed-off-by: cryyl <1226241521@qq.com> --- cmd/geth/snapshot.go | 8 +- core/blockchain.go | 2 +- core/state/database.go | 4 +- core/state/pruner/pruner.go | 4 +- core/state/snapshot/generate_test.go | 28 ++-- eth/api.go | 4 +- eth/protocols/snap/handler.go | 4 +- eth/protocols/snap/sync_test.go | 2 +- les/downloader/downloader_test.go | 2 +- trie/database.go | 21 ++- trie/iterator_test.go | 2 +- trie/node.go | 57 ++++--- trie/secure_trie.go | 18 ++- 
trie/secure_trie_test.go | 4 +- trie/shadow_node.go | 9 +- trie/shadownodes.go | 15 +- trie/sync_test.go | 6 +- trie/trie.go | 216 +++++++++++++++++++++++---- 18 files changed, 303 insertions(+), 103 deletions(-) diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index c8ce3bff50..32898a2045 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -542,7 +542,7 @@ func traverseState(ctx *cli.Context) error { log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } triedb := trie.NewDatabase(chaindb) - t, err := trie.NewSecure(root, triedb) + t, err := trie.NewSecure(root, triedb, false) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) return err @@ -563,7 +563,7 @@ func traverseState(ctx *cli.Context) error { return err } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, triedb) + storageTrie, err := trie.NewSecure(acc.Root, triedb, true) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err @@ -631,7 +631,7 @@ func traverseRawState(ctx *cli.Context) error { log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } triedb := trie.NewDatabase(chaindb) - t, err := trie.NewSecure(root, triedb) + t, err := trie.NewSecure(root, triedb, false) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) return err @@ -667,7 +667,7 @@ func traverseRawState(ctx *cli.Context) error { return errors.New("invalid account") } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, triedb) + storageTrie, err := trie.NewSecure(acc.Root, triedb, true) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return errors.New("missing storage trie") diff --git a/core/blockchain.go b/core/blockchain.go index 9908c02c5b..09a7922065 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -839,7 +839,7 @@ func (bc *BlockChain) SnapSyncCommitHead(hash 
common.Hash) error { if block == nil { return fmt.Errorf("non existent block [%x..]", hash[:4]) } - if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB()); err != nil { + if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), false); err != nil { return err } diff --git a/core/state/database.go b/core/state/database.go index 7df0071670..a0bc01dedd 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -216,7 +216,7 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { return tr.(Trie).(*trie.SecureTrie).Copy(), nil } } - tr, err := trie.NewSecure(root, db.db) + tr, err := trie.NewSecure(root, db.db, false) if err != nil { return nil, err } @@ -239,7 +239,7 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { } } - tr, err := trie.NewSecure(root, db.db) + tr, err := trie.NewSecure(root, db.db, true) if err != nil { return nil, err } diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 83c56d5785..7b70b2b310 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -736,7 +736,7 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { if genesis == nil { return errors.New("missing genesis block") } - t, err := trie.NewSecure(genesis.Root(), trie.NewDatabase(db)) + t, err := trie.NewSecure(genesis.Root(), trie.NewDatabase(db), false) if err != nil { return err } @@ -756,7 +756,7 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { return err } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, trie.NewDatabase(db)) + storageTrie, err := trie.NewSecure(acc.Root, trie.NewDatabase(db), true) if err != nil { return err } diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 9eb812f83b..b95bd49ed4 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -42,13 +42,13 @@ func TestGeneration(t *testing.T) { diskdb = 
memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -99,13 +99,13 @@ func TestGenerateExistentState(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 
0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -179,7 +179,7 @@ type testHelper struct { func newHelper() *testHelper { diskdb := memorydb.New() triedb := trie.NewDatabase(diskdb) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) return &testHelper{ diskdb: diskdb, triedb: triedb, @@ -211,7 +211,7 @@ func (t *testHelper) addSnapStorage(accKey string, keys []string, vals []string) } func (t *testHelper) makeStorageTrie(keys []string, vals []string) []byte { - stTrie, _ := trie.NewSecure(common.Hash{}, t.triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, t.triedb, true) for i, k := range keys { stTrie.Update([]byte(k), []byte(vals[i])) } @@ -384,7 +384,7 @@ func TestGenerateCorruptAccountTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - tr, _ := trie.NewSecure(common.Hash{}, triedb) + tr, _ := trie.NewSecure(common.Hash{}, triedb, false) acc := &Account{Balance: big.NewInt(1), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) tr.Update([]byte("acc-1"), val) // 0xc7a30f39aff471c95d8a837497ad0e49b65be475cc0953540f80cfcdbdcd9074 @@ -428,13 +428,13 @@ func TestGenerateMissingStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, 
triedb, false) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -487,13 +487,13 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -537,7 +537,7 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { } func getStorageTrie(n int, triedb *trie.Database) *trie.SecureTrie { - stTrie, _ := trie.NewSecure(common.Hash{}, triedb) + stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) for i := 0; i < n; i++ { k := fmt.Sprintf("key-%d", i) v := fmt.Sprintf("val-%d", i) @@ -554,7 +554,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) { triedb = trie.NewDatabase(diskdb) stTrie = getStorageTrie(5, triedb) ) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) { // Account one in the trie acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), 
CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) @@ -618,7 +618,7 @@ func TestGenerateWithManyExtraAccounts(t *testing.T) { triedb = trie.NewDatabase(diskdb) stTrie = getStorageTrie(3, triedb) ) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) { // Account one in the trie acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) diff --git a/eth/api.go b/eth/api.go index b1bea41d87..9c2589553c 100644 --- a/eth/api.go +++ b/eth/api.go @@ -525,11 +525,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc } triedb := api.eth.BlockChain().StateCache().TrieDB() - oldTrie, err := trie.NewSecure(startBlock.Root(), triedb) + oldTrie, err := trie.NewSecure(startBlock.Root(), triedb, false) if err != nil { return nil, err } - newTrie, err := trie.NewSecure(endBlock.Root(), triedb) + newTrie, err := trie.NewSecure(endBlock.Root(), triedb, false) if err != nil { return nil, err } diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 314776dffe..e5d210cf00 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -487,7 +487,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s // Make sure we have the state associated with the request triedb := chain.StateCache().TrieDB() - accTrie, err := trie.NewSecure(req.Root, triedb) + accTrie, err := trie.NewSecure(req.Root, triedb, false) if err != nil { // We don't have the requested state available, bail out return nil, nil @@ -529,7 +529,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s if err != nil || account == nil { break } - stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb) + stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb, true) loads++ // always account database reads, even for 
failures if err != nil { break diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 1dfba03c86..e2865afd0d 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -1604,7 +1604,7 @@ func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { } accounts++ if acc.Root != emptyRoot { - storeTrie, err := trie.NewSecure(acc.Root, triedb) + storeTrie, err := trie.NewSecure(acc.Root, triedb, true) if err != nil { t.Fatal(err) } diff --git a/les/downloader/downloader_test.go b/les/downloader/downloader_test.go index 69bdb90ed2..e270cc0567 100644 --- a/les/downloader/downloader_test.go +++ b/les/downloader/downloader_test.go @@ -229,7 +229,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block { func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { // For now only check that the state trie is correct if block := dl.GetBlockByHash(hash); block != nil { - _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb)) + _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), false) return err } return fmt.Errorf("non existent block: %x", hash[:4]) diff --git a/trie/database.go b/trie/database.go index 002af87498..a4838425c5 100644 --- a/trie/database.go +++ b/trie/database.go @@ -28,6 +28,7 @@ import ( "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" @@ -101,10 +102,10 @@ type Database struct { // in the same cache fields). 
type rawNode []byte -func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } -func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawNode) setEpoch(epcoh uint16) { panic("this should never end up in a live trie") } -func (n rawNode) getEpoch() uint16 { panic("this should never end up in a live trie") } +func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawNode) setEpoch(epcoh types.StateEpoch) { panic("this should never end up in a live trie") } +func (n rawNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } func (n rawNode) EncodeRLP(w io.Writer) error { _, err := w.Write(n) @@ -122,8 +123,10 @@ type rawFullNode [17]node func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawFullNode) setEpoch(epcoh uint16) { panic("this should never end up in a live trie") } -func (n rawFullNode) getEpoch() uint16 { panic("this should never end up in a live trie") } +func (n rawFullNode) setEpoch(epcoh types.StateEpoch) { + panic("this should never end up in a live trie") +} +func (n rawFullNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } func (n rawFullNode) nodeType() int { return rawFullNodeType @@ -145,8 +148,10 @@ type rawShortNode struct { func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawShortNode) setEpoch(epoch uint16) { panic("this should never end up in a live trie") } -func (n rawShortNode) getEpoch() uint16 { panic("this should never end up 
in a live trie") } +func (n rawShortNode) setEpoch(epoch types.StateEpoch) { + panic("this should never end up in a live trie") +} +func (n rawShortNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } func (n rawShortNode) nodeType() int { return rawShortNodeType diff --git a/trie/iterator_test.go b/trie/iterator_test.go index 871a5f3fca..f2e6c7d586 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -508,7 +508,7 @@ func makeLargeTestTrie() (*Database, *SecureTrie, *loggingDb) { // Create an empty trie logDb := &loggingDb{0, memorydb.New()} triedb := NewDatabase(logDb) - trie, _ := NewSecure(common.Hash{}, triedb) + trie, _ := NewSecure(common.Hash{}, triedb, false) // Fill it with some arbitrary data for i := 0; i < 10000; i++ { diff --git a/trie/node.go b/trie/node.go index d4d84360f1..1db3d4310f 100644 --- a/trie/node.go +++ b/trie/node.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" ) @@ -46,30 +47,30 @@ type node interface { encode(w rlp.EncoderBuffer) fstring(string) string nodeType() int - setEpoch(epoch uint16) - getEpoch() uint16 + setEpoch(epoch types.StateEpoch) + getEpoch() types.StateEpoch } type ( fullNode struct { Children [BranchNodeLength]node // Actual trie node data to encode/decode (needs custom encoder) flags nodeFlag - epoch uint16 `rlp:"-" json:"-"` - shadowNode *shadowBranchNode `rlp:"-" json:"-"` + epoch types.StateEpoch `rlp:"-" json:"-"` + shadowNode shadowBranchNode `rlp:"-" json:"-"` } shortNode struct { Key []byte Val node flags nodeFlag - epoch uint16 `rlp:"-" json:"-"` - shadowNode *shadowExtensionNode `rlp:"-" json:"-"` + epoch types.StateEpoch `rlp:"-" json:"-"` + shadowNode shadowExtensionNode `rlp:"-" json:"-"` } hashNode []byte valueNode []byte ) type RootNode struct { - Epoch uint16 + Epoch types.StateEpoch TrieHash common.Hash ShadowHash common.Hash } @@ -86,19 
+87,33 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { } func (n *fullNode) GetShadowNode() *shadowBranchNode { + // TODO:get shadow node from cache or disk if shadow node is nil return &shadowBranchNode{} } -func (n *fullNode) IsChildExpired(pos int) (bool, error) { - return false, nil +func (n *fullNode) GetChildEpoch(index int) types.StateEpoch { + if index < 16 { + return n.GetShadowNode().EpochMap[index] + } + return n.epoch } -func (n *fullNode) GetChildEpoch(pos int) uint16 { - return n.GetShadowNode().EpochMap[pos] +func (n *fullNode) UpdateChildEpoch(index int, epoch types.StateEpoch) { + if index < 16 { + n.GetShadowNode().EpochMap[index] = epoch + } } -func (n *fullNode) UpdateChildEpoch(pos int, epoch uint16) { - n.GetShadowNode().EpochMap[pos] = epoch +func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.StateEpoch) (bool, error) { + childEpoch := n.GetChildEpoch(index) + if currentEpoch-childEpoch >= 2 { + return true, &ExpiredNodeError{ + ExpiredNode: n.Children[index], + Path: prefix, + Epoch: childEpoch, + } + } + return false, nil } func (n *shortNode) GetShadowNode() *shadowExtensionNode { @@ -125,15 +140,15 @@ func (n *shortNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } -func (n *fullNode) setEpoch(epoch uint16) { n.epoch = epoch } -func (n *shortNode) setEpoch(epoch uint16) { n.epoch = epoch } -func (n hashNode) setEpoch(epoch uint16) {} -func (n valueNode) setEpoch(epoch uint16) {} +func (n *fullNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } +func (n *shortNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } +func (n hashNode) setEpoch(epoch types.StateEpoch) {} +func (n valueNode) setEpoch(epoch types.StateEpoch) {} -func (n *fullNode) getEpoch() uint16 { return n.epoch } -func (n *shortNode) getEpoch() uint16 { return n.epoch } -func (n hashNode) getEpoch() uint16 { return 0 } 
-func (n valueNode) getEpoch() uint16 { return 0 } +func (n *fullNode) getEpoch() types.StateEpoch { return n.epoch } +func (n *shortNode) getEpoch() types.StateEpoch { return n.epoch } +func (n hashNode) getEpoch() types.StateEpoch { return 0 } +func (n valueNode) getEpoch() types.StateEpoch { return 0 } func (n *fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index d6fc697805..a2acc0c17a 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -52,14 +52,30 @@ type SecureTrie struct { // Loaded nodes are kept around until their 'cache generation' expires. // A new cache generation is created by each call to Commit. // cachelimit sets the number of past cache generations to keep. -func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { +func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, error) { if db == nil { panic("trie.NewSecure called without a database") } + epoch := types.StateEpoch(0) + shadowHash := common.Hash{} + if isStorageTrie { + if rootNode := db.RootNode(root); rootNode != nil { + root = rootNode.TrieHash + epoch = rootNode.Epoch + shadowHash = rootNode.ShadowHash + } + } trie, err := New(root, db) if err != nil { return nil, err } + if isStorageTrie { + if trie.root != nil { + trie.root.setEpoch(epoch) + } + trie.shadowHash = shadowHash + } + trie.isStorageTrie = isStorageTrie return &SecureTrie{trie: *trie}, nil } diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index 7bbdf29ef0..f101514675 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -29,7 +29,7 @@ import ( ) func newEmptySecure() *SecureTrie { - trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New())) + trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New()), false) return trie } @@ -37,7 +37,7 @@ func newEmptySecure() *SecureTrie { func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // 
Create an empty trie triedb := NewDatabase(memorydb.New()) - trie, _ := NewSecure(common.Hash{}, triedb) + trie, _ := NewSecure(common.Hash{}, triedb, false) // Fill it with some arbitrary data content := make(map[string][]byte) diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 66e37224ea..6e2f07fcb7 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -2,6 +2,7 @@ package trie import ( "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" ) @@ -41,11 +42,11 @@ func NewShadowNodeManager(diskdb ethdb.KeyValueStore) *ShadowNodeManager { //} type shadowNodeStorageReaderWriterMock struct { - mockEpoch uint16 + mockEpoch types.StateEpoch nodeMap map[string][]byte } -func newShadowNodeStorageMock(epoch uint16) ShadowNodeStorage { +func newShadowNodeStorageMock(epoch types.StateEpoch) ShadowNodeStorage { return &shadowNodeStorageReaderWriterMock{ mockEpoch: epoch, nodeMap: make(map[string][]byte), @@ -58,8 +59,8 @@ func (s *shadowNodeStorageReaderWriterMock) Get(key []byte) ([]byte, error) { val, ok := s.nodeMap[tmp] if !ok { n := shadowBranchNode{ - ShadowHash: nil, - EpochMap: [16]uint16{}, + ShadowHash: common.Hash{}, + EpochMap: [16]types.StateEpoch{}, } for i := range n.EpochMap { n.EpochMap[i] = s.mockEpoch diff --git a/trie/shadownodes.go b/trie/shadownodes.go index 819b405f6a..240f241794 100644 --- a/trie/shadownodes.go +++ b/trie/shadownodes.go @@ -2,6 +2,7 @@ package trie import ( "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" ) //type shadowNode interface { @@ -9,10 +10,18 @@ import ( //} type shadowExtensionNode struct { - ShadowHash *common.Hash + ShadowHash common.Hash +} + +func NewShadowExtensionNode(hash common.Hash) shadowExtensionNode { + return shadowExtensionNode{hash} } type shadowBranchNode struct { - ShadowHash *common.Hash - EpochMap [16]uint16 + ShadowHash common.Hash + 
EpochMap [16]types.StateEpoch +} + +func NewShadowBranchNode(hash common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { + return shadowBranchNode{hash, epochMap} } diff --git a/trie/sync_test.go b/trie/sync_test.go index 970730b671..05015c5367 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -29,7 +29,7 @@ import ( func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie triedb := NewDatabase(memorydb.New()) - trie, _ := NewSecure(common.Hash{}, triedb) + trie, _ := NewSecure(common.Hash{}, triedb, false) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -60,7 +60,7 @@ func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) { // content map. func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) { // Check root availability and trie contents - trie, err := NewSecure(common.BytesToHash(root), db) + trie, err := NewSecure(common.BytesToHash(root), db, false) if err != nil { t.Fatalf("failed to create trie at %x: %v", root, err) } @@ -77,7 +77,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri // checkTrieConsistency checks that all nodes in a trie are indeed present. func checkTrieConsistency(db *Database, root common.Hash) error { // Create and iterate a trie rooted in a subnode - trie, err := NewSecure(root, db) + trie, err := NewSecure(root, db, false) if err != nil { return nil // Consider a non existent state consistent } diff --git a/trie/trie.go b/trie/trie.go index 11ab683e0b..0717e9d5ff 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -66,7 +66,10 @@ type Trie struct { // Keep track of the number leafs which have been inserted since the last // hashing operation. 
This number will not directly map to the number of // actually unhashed nodes - unhashed int + unhashed int + currentEpoch types.StateEpoch + isStorageTrie bool + shadowHash common.Hash } // newFlag returns the cache flag value for a newly created node. @@ -88,11 +91,6 @@ func New(root common.Hash, db *Database) (*Trie, error) { db: db, sndb: newShadowNodeStorageMock(0), } - //epoch := uint16(0) - if rootNode := db.RootNode(root); rootNode != nil { - root = rootNode.TrieHash - //epoch = rootNode.Epoch - } if root != (common.Hash{}) && root != emptyRoot { rootnode, err := trie.resolveHash(root[:], nil) if err != nil { @@ -122,19 +120,39 @@ func (t *Trie) Get(key []byte) []byte { // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. -func (t *Trie) TryGet(key []byte) ([]byte, error) { - var nextEpoch uint16 +func (t *Trie) TryGet(key []byte) (value []byte, err error) { + var newroot node + var didResolve bool + if t.isStorageTrie { + var nextEpoch types.StateEpoch + if t.root != nil { + nextEpoch = t.root.getEpoch() + } + value, newroot, didResolve, err = t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, nextEpoch, false) + } else { + value, newroot, didResolve, err = t.tryGet(t.root, keybytesToHex(key), 0) + } + + if err == nil && didResolve { + t.root = newroot + } + return value, err +} + +func (t *Trie) TryGetAndUpdateEpoch(key []byte) ([]byte, error) { + var nextEpoch types.StateEpoch if t.root != nil { nextEpoch = t.root.getEpoch() } - value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0, nextEpoch) + value, newroot, didResolve, err := t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, nextEpoch, true) + if err == nil && didResolve { t.root = newroot } return value, err } -func (t *Trie) tryGet(origNode node, key []byte, pos int, epoch uint16) (value []byte, newnode node, didResolve bool, err error) { +func (t 
*Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -145,17 +163,75 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int, epoch uint16) (value [ // key not found in trie return nil, n, false, nil } - value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key), n.epoch) + value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { n = n.copy() n.Val = newnode } return value, n, didResolve, err case *fullNode: - if expired, err := n.IsChildExpired(pos); expired { + value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) + if err == nil && didResolve { + n = n.copy() + n.Children[key[pos]] = newnode + } + return value, n, didResolve, err + case hashNode: + child, err := t.resolveHash(n, key[:pos]) + if err != nil { + return nil, n, true, err + } + value, newnode, _, err := t.tryGet(child, key, pos) + return value, newnode, true, err + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + +func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.StateEpoch, updateEpoch bool) (value []byte, newnode node, didResolve bool, err error) { + switch n := (origNode).(type) { + case nil: + return nil, nil, false, nil + case valueNode: + return n, n, false, nil + case *shortNode: + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + // key not found in trie + return nil, n, false, nil + } + // node is expired + if expired, err := t.nodeExpired(n, key[:pos]); expired { return nil, n, false, err } - value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(pos)) + + if updateEpoch { + n.setEpoch(t.currentEpoch) + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), t.currentEpoch, true) + } else { + value, newnode, didResolve, err = 
t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), epoch, false) + } + if err == nil && didResolve { + n = n.copy() + n.Val = newnode + } + return value, n, didResolve, err + case *fullNode: + // full node is expired + if expired, err := t.nodeExpired(n, key[:pos]); expired { + return nil, n, false, err + } + // child node is expired + if expired, err := n.ChildExpired(key[:pos+1], int(key[pos]), t.currentEpoch); expired { + return nil, n, false, err + } + + if updateEpoch { + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, t.currentEpoch, true) + } else { + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])), false) + } if err == nil && didResolve { n = n.copy() n.Children[key[pos]] = newnode @@ -166,7 +242,8 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int, epoch uint16) (value [ if err != nil { return nil, n, true, err } - value, newnode, _, err := t.tryGet(child, key, pos, epoch) + child.setEpoch(epoch) + value, newnode, _, err = t.tryGetWithEpoch(child, key, pos, epoch, updateEpoch) return value, newnode, true, err default: panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) @@ -281,14 +358,18 @@ func (t *Trie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { func (t *Trie) TryUpdate(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) + var nextEpoch types.StateEpoch + if t.root != nil { + nextEpoch = t.root.getEpoch() + } if len(value) != 0 { - _, n, err := t.insert(t.root, nil, k, valueNode(value)) + _, n, err := t.insert(t.root, nil, k, valueNode(value), nextEpoch) if err != nil { return err } t.root = n } else { - _, n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k, nextEpoch) if err != nil { return err } @@ -297,7 +378,7 @@ func (t *Trie) TryUpdate(key, value []byte) error { return nil } -func (t *Trie) 
insert(n node, prefix, key []byte, value node) (bool, node, error) { +func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateEpoch) (bool, node, error) { if len(key) == 0 { if v, ok := n.(valueNode); ok { return !bytes.Equal(v, value.(valueNode)), value, nil @@ -306,36 +387,68 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error } switch n := n.(type) { case *shortNode: + if t.isStorageTrie && t.currentEpoch >= 2 { + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err + } + } matchlen := prefixLen(key, n.Key) // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + n.setEpoch(t.currentEpoch) + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value, n.epoch) if !dirty || err != nil { return false, n, err } - return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag(), epoch: t.currentEpoch}, nil } // Otherwise branch out at the index where they differ. 
branch := &fullNode{flags: t.newFlag()} var err error - _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val, t.currentEpoch) if err != nil { return false, nil, err } - _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + if t.isStorageTrie { + branch.setEpoch(t.currentEpoch) + branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) + } + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value, t.currentEpoch) if err != nil { return false, nil, err } + if t.isStorageTrie { + branch.setEpoch(t.currentEpoch) + branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) + } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag()}, nil + return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil case *fullNode: - dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) + if t.isStorageTrie { + if t.currentEpoch >= 2 { + // this full node is expired, return err + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err + } + } + // else, set its epoch to current epoch. 
+ n.setEpoch(t.currentEpoch) + if t.currentEpoch >= 2 { + // if child is expired, return err + if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { + return false, n.Children[key[0]], err + } + } + // else, set child node's epoch to current epoch + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value, n.GetChildEpoch(int(key[0]))) if !dirty || err != nil { return false, n, err } @@ -355,7 +468,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if err != nil { return false, nil, err } - dirty, nn, err := t.insert(rn, prefix, key, value) + rn.setEpoch(epoch) + dirty, nn, err := t.insert(rn, prefix, key, value, epoch) if !dirty || err != nil { return false, rn, err } @@ -378,7 +492,11 @@ func (t *Trie) Delete(key []byte) { func (t *Trie) TryDelete(key []byte) error { t.unhashed++ k := keybytesToHex(key) - _, n, err := t.delete(t.root, nil, k) + var nextEpoch types.StateEpoch + if t.root != nil { + nextEpoch = t.root.getEpoch() + } + _, n, err := t.delete(t.root, nil, k, nextEpoch) if err != nil { return err } @@ -389,9 +507,15 @@ func (t *Trie) TryDelete(key []byte) error { // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. 
-func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { +func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { switch n := n.(type) { case *shortNode: + if t.isStorageTrie && t.currentEpoch >= 2 { + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err + } + n.setEpoch(t.currentEpoch) + } matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { return false, n, nil // don't replace n on mismatch @@ -403,7 +527,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):], n.epoch) if !dirty || err != nil { return false, n, err } @@ -415,13 +539,31 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag()}, nil + return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil default: - return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag(), epoch: t.currentEpoch}, nil } case *fullNode: - dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) + if t.isStorageTrie { + if t.currentEpoch >= 2 { + // this full node is expired, return err + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err + } + } + // else, set its epoch to current epoch. 
+ n.setEpoch(t.currentEpoch) + if t.currentEpoch >= 2 { + // if child is expired, return err + if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { + return false, n.Children[key[0]], err + } + } + // else, set child node's epoch to current epoch + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:], n.GetChildEpoch(int(key[0]))) if !dirty || err != nil { return false, n, err } @@ -495,7 +637,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { if err != nil { return false, nil, err } - dirty, nn, err := t.delete(rn, prefix, key) + rn.setEpoch(epoch) + dirty, nn, err := t.delete(rn, prefix, key, epoch) if !dirty || err != nil { return false, rn, err } @@ -591,6 +734,17 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } +func (t *Trie) nodeExpired(n node, prefix []byte) (bool, error) { + if t.currentEpoch-n.getEpoch() >= 2 { + return true, &ExpiredNodeError{ + ExpiredNode: n, + Path: prefix, + Epoch: n.getEpoch(), + } + } + return false, nil +} + // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. 
func (t *Trie) Hash() common.Hash { From 4dbdd679f57c46a6d5f1c98118b3180f2c9b8803 Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 25 Apr 2023 16:26:11 +0800 Subject: [PATCH 30/51] feat(simulation): add EstimateGasReviveState fix logic error fix logic error fix comments Squashed commit of the following: commit 27110c96e99b8343161bb8ffa867a109573bedd8 Merge: 1cf57da77 9c2297d35 Author: cryyl <90364156+cryyl@users.noreply.github.com> Date: Wed Apr 26 10:34:14 2023 +0800 Merge pull request #91 from cryyl/state_expire_mpt State expire: implementation of MPT R&W commit 9c2297d357311747a4301eda0db7573ac2838739 Author: cryyl <1226241521@qq.com> Date: Tue Apr 18 16:58:00 2023 +0800 state expriy: implement of MPT read and write Signed-off-by: cryyl <1226241521@qq.com> fix --- accounts/abi/bind/backends/simulated.go | 225 +++++++++++++++++++ accounts/abi/bind/backends/simulated_test.go | 94 ++++++++ core/state/statedb.go | 2 +- core/vm/errors.go | 6 +- core/vm/evm.go | 10 +- 5 files changed, 328 insertions(+), 9 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index a9a4981559..be09d6aed9 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -42,6 +42,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" ) @@ -589,6 +590,172 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMs return hi, nil } +// EstimateGasAndReviveState executes the requested code against the currently pending block/state and +// returns the used amount of gas. If the execution fails due to expired state, it will revive the state +// and try again. 
+func (b *SimulatedBackend) EstimateGasAndReviveState(ctx context.Context, call ethereum.CallMsg) (uint64, []types.ReviveWitness, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Initialize witnessList + var witnessList []types.ReviveWitness + if call.WitnessList != nil { + witnessList = call.WitnessList + } + witLen := len(witnessList) + + // Determine the lowest and highest possible gas limits to binary search in between + var ( + lo uint64 = params.TxGas - 1 + hi uint64 + cap uint64 + ) + if call.Gas >= params.TxGas { + hi = call.Gas + } else { + hi = b.pendingBlock.GasLimit() + } + // Normalize the max fee per gas the call is willing to spend. + var feeCap *big.Int + if call.GasPrice != nil && (call.GasFeeCap != nil || call.GasTipCap != nil) { + return 0, witnessList, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } else if call.GasPrice != nil { + feeCap = call.GasPrice + } else if call.GasFeeCap != nil { + feeCap = call.GasFeeCap + } else { + feeCap = common.Big0 + } + // Recap the highest gas allowance with account's balance. 
+ if feeCap.BitLen() != 0 { + balance := b.pendingState.GetBalance(call.From) // from can't be nil + available := new(big.Int).Set(balance) + if call.Value != nil { + if call.Value.Cmp(available) >= 0 { + return 0, witnessList, errors.New("insufficient funds for transfer") + } + available.Sub(available, call.Value) + } + allowance := new(big.Int).Div(available, feeCap) + if allowance.IsUint64() && hi > allowance.Uint64() { + transfer := call.Value + if transfer == nil { + transfer = new(big.Int) + } + log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance, + "sent", transfer, "feecap", feeCap, "fundable", allowance) + hi = allowance.Uint64() + } + } + cap = hi + + // Create a helper to check if a gas allowance results in an executable transaction + executable := func(gas uint64, witnessList []types.ReviveWitness) (bool, *core.ExecutionResult, bool, error) { + call.Gas = gas + + snapshot := b.pendingState.Snapshot() + res, evmErrors, err := b.callContractExpired(ctx, call, b.pendingBlock, b.pendingState) + + // Create MPTProof + isExpiredError := false + if len(evmErrors) > 0 { + addressToProofMap := make(map[common.Address][]types.MPTProof) + for _, evmErr := range evmErrors { + if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { + isExpiredError = true + proof, err := b.pendingState.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) // TODO (asyukii): Is this key correct? 
+ if err != nil { + return true, nil, isExpiredError, err + } + addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], types.MPTProof{ + RootKeyHex: stateErr.Path, + Proof: proof, + }) + } + } + + // TODO (asyukii): existing witnessList might already has ReviveWitness with the same address + // Might want to consider decoding it and merge them together + + // Create a ReviveWitness object for each address and add it to witnessList + for addr, proofs := range addressToProofMap { + // Build a storageTrieWitness + storageTrieWitness := types.StorageTrieWitness{ + Address: addr, + ProofList: proofs, + } + // Encode StorageTrieWitness + enc, err := rlp.EncodeToBytes(storageTrieWitness) + if err != nil { + return true, nil, isExpiredError, err + } + // Create a ReviveWitness + reviveWitness := types.ReviveWitness{ + WitnessType: types.StorageTrieWitnessType, + Data: enc, + } + // Append to witness list + witnessList = append(witnessList, reviveWitness) + } + } + + b.pendingState.RevertToSnapshot(snapshot) + + if err != nil { + if _, ok := err.(*state.ExpiredStateError); ok || errors.Is(err, core.ErrIntrinsicGas) { + return true, nil, isExpiredError, nil // Special case, raise gas limit + } + return true, nil, isExpiredError, err // Bail out + } + return res.Failed(), res, isExpiredError, nil + } + // Execute the binary search and hone in on an executable gas limit + for lo+1 < hi { + mid := (hi + lo) / 2 + failed, _, isExpiredError, err := executable(mid, witnessList) + + if isExpiredError { + if witLen == len(witnessList) { + // If witnessList is not updated, it means that the proofs are not + // sufficient to revive the state. + return 0, witnessList, fmt.Errorf("cannot generate enough proofs to revive the state") + } + witLen = len(witnessList) + continue + } + + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. 
Return the error directly, don't struggle any more + if err != nil { + return 0, witnessList, err + } + if failed { + lo = mid + } else { + hi = mid + } + } + // Reject the transaction as invalid if it still fails at the highest allowance + if hi == cap { + failed, result, _, err := executable(hi, witnessList) + if err != nil { + return 0, witnessList, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + if len(result.Revert()) > 0 { + return 0, witnessList, newRevertError(result) + } + return 0, witnessList, result.Err + } + // Otherwise, the specified gas cap is too low + return 0, witnessList, fmt.Errorf("gas required exceeds allowance (%d)", cap) + } + } + return hi, witnessList, nil +} + // callContract implements common code between normal and pending contract calls. // state is modified during execution, make sure to copy it if necessary. func (b *SimulatedBackend) callContract(ctx context.Context, call ethereum.CallMsg, block *types.Block, stateDB *state.StateDB) (*core.ExecutionResult, error) { @@ -646,6 +813,64 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call ethereum.CallM return core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb() } +// callContractExpired implements common code between normal and pending contract calls. +// state is modified during execution, make sure to copy it if necessary. +// It will also return the EVM errors if any. 
+func (b *SimulatedBackend) callContractExpired(ctx context.Context, call ethereum.CallMsg, block *types.Block, stateDB *state.StateDB) (*core.ExecutionResult, []*vm.EVMError, error) { + // Gas prices post 1559 need to be initialized + if call.GasPrice != nil && (call.GasFeeCap != nil || call.GasTipCap != nil) { + return nil, nil, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } + head := b.blockchain.CurrentHeader() + if !b.blockchain.Config().IsLondon(head.Number) { + // If there's no basefee, then it must be a non-1559 execution + if call.GasPrice == nil { + call.GasPrice = new(big.Int) + } + call.GasFeeCap, call.GasTipCap = call.GasPrice, call.GasPrice + } else { + // A basefee is provided, necessitating 1559-type execution + if call.GasPrice != nil { + // User specified the legacy gas field, convert to 1559 gas typing + call.GasFeeCap, call.GasTipCap = call.GasPrice, call.GasPrice + } else { + // User specified 1559 gas feilds (or none), use those + if call.GasFeeCap == nil { + call.GasFeeCap = new(big.Int) + } + if call.GasTipCap == nil { + call.GasTipCap = new(big.Int) + } + // Backfill the legacy gasPrice for EVM execution, unless we're all zeroes + call.GasPrice = new(big.Int) + if call.GasFeeCap.BitLen() > 0 || call.GasTipCap.BitLen() > 0 { + call.GasPrice = math.BigMin(new(big.Int).Add(call.GasTipCap, head.BaseFee), call.GasFeeCap) + } + } + } + // Ensure message is initialized properly. + if call.Gas == 0 { + call.Gas = 50000000 + } + if call.Value == nil { + call.Value = new(big.Int) + } + // Set infinite balance to the fake caller account. + from := stateDB.GetOrNewStateObject(call.From) + from.SetBalance(math.MaxBig256) + // Execute the call. + msg := callMsg{call} + + txContext := core.NewEVMTxContext(msg) + evmContext := core.NewEVMBlockContext(block.Header(), b.blockchain, nil) + // Create a new environment which holds all relevant information + // about the transaction and calling mechanisms. 
+ vmEnv := vm.NewEVM(evmContext, txContext, stateDB, b.config, vm.Config{NoBaseFee: true}) + gasPool := new(core.GasPool).AddGas(math.MaxUint64) + res, err := core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb() + return res, vmEnv.ErrorCollection, err +} + // SendTransaction updates the pending block to include the given transaction. func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transaction) error { b.mu.Lock() diff --git a/accounts/abi/bind/backends/simulated_test.go b/accounts/abi/bind/backends/simulated_test.go index f857c399f7..97ed6e8061 100644 --- a/accounts/abi/bind/backends/simulated_test.go +++ b/accounts/abi/bind/backends/simulated_test.go @@ -531,6 +531,100 @@ func TestEstimateGas(t *testing.T) { } } +// TODO (asyukii) +// Need some smart way to modify state such that they can be expired +// Add more tests: +// [] - Test for gas estimation of fully expired contract +// [] - Test for gas estimation of contract with partially expired storage +func TestEstimateGasAndReviveState(t *testing.T) { + /* + pragma solidity ^0.8.0; + contract HardcodedStorage { + uint256 private _value1 = 10; + uint256 private _value2 = 20; + uint256 private _value3 = 30; + uint256 private _value5 = 50; + uint256 private _value6 = 60; + uint256 private _value7 = 70; + uint256 private _value8 = 80; + + function getValue1() public view returns (uint256) { + return _value1; + } + } + */ + const contractAbi = "[{\"inputs\":[],\"name\":\"getValue1\",\"outputs\":[{\"internalType\":\"uint256\",\"name\":\"\",\"type\":\"uint256\"}],\"stateMutability\":\"view\",\"type\":\"function\"}]" + const contractBin = 
"0x6080604052600a6000556014600155601e6002556032600355603c6004556046600555605060065534801561003357600080fd5b5060b6806100426000396000f3fe6080604052348015600f57600080fd5b506004361060285760003560e01c806360d586f814602d575b600080fd5b60336047565b604051603e91906067565b60405180910390f35b60008054905090565b6000819050919050565b6061816050565b82525050565b6000602082019050607a6000830184605a565b9291505056fea26469706673582212208ea04f86f5d169c7e3fde0277a16120a9cc6db59035dee70356b3b1e615fe6a664736f6c63430008120033" + + key, _ := crypto.GenerateKey() + addr := crypto.PubkeyToAddress(key.PublicKey) + opts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) + + sim := NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(params.Ether)}}, 10000000) + defer sim.Close() + + parsed, _ := abi.JSON(strings.NewReader(contractAbi)) + contractAddr, _, _, _ := bind.DeployContract(opts, parsed, common.FromHex(contractBin), sim) + sim.Commit() + + var cases = []struct { + name string + message ethereum.CallMsg + expect uint64 + expectError error + expectData interface{} + }{ + {"plain transfer(valid)", ethereum.CallMsg{ + From: addr, + To: &addr, + Gas: 0, + GasPrice: big.NewInt(0), + Value: big.NewInt(1), + Data: nil, + }, params.TxGas, nil, nil}, + + {"plain transfer(invalid)", ethereum.CallMsg{ + From: addr, + To: &contractAddr, + Gas: 0, + GasPrice: big.NewInt(0), + Value: big.NewInt(1), + Data: nil, + }, 0, errors.New("execution reverted"), nil}, + {"call getValue1(valid)", ethereum.CallMsg{ + From: addr, + To: &contractAddr, + Gas: 0, + GasPrice: big.NewInt(100000000000), + Value: big.NewInt(0), + Data: common.Hex2Bytes("60d586f8"), + }, 23479, nil, nil}, + } + + for _, c := range cases { + got, _, err := sim.EstimateGasAndReviveState(context.Background(), c.message) + if c.expectError != nil { + if err == nil { + t.Fatalf("Expect error, got nil") + } + if c.expectError.Error() != err.Error() { + t.Fatalf("Expect error, want %v, got %v", c.expectError, err) + } + if 
c.expectData != nil { + if err, ok := err.(*revertError); !ok { + t.Fatalf("Expect revert error, got %T", err) + } else if !reflect.DeepEqual(err.ErrorData(), c.expectData) { + t.Fatalf("Error data mismatch, want %v, got %v", c.expectData, err.ErrorData()) + } + } + continue + } + if got != c.expect { + t.Fatalf("Gas estimation mismatch, want %d, got %d", c.expect, got) + } + } +} + func TestEstimateGasWithPrice(t *testing.T) { key, _ := crypto.GenerateKey() addr := crypto.PubkeyToAddress(key.PublicKey) diff --git a/core/state/statedb.go b/core/state/statedb.go index 2db95dc5cb..b9722d0c27 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -514,7 +514,7 @@ func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, key c if trie == nil { return proof, errors.New("storage trie for requested address does not exist") } - err := trie.ProveStorageWitness(crypto.Keccak256(key.Bytes()), prefixKeyHex, &proof) + err := trie.ProveStorageWitness(crypto.Keccak256(key.Bytes()), prefixKeyHex, &proof) // TODO (asyukii): Might not need the Keccak256 hash, revisit this return proof, err } diff --git a/core/vm/errors.go b/core/vm/errors.go index 3956c10a39..03ecf2092a 100644 --- a/core/vm/errors.go +++ b/core/vm/errors.go @@ -77,7 +77,7 @@ type EVMError struct { from common.Address to common.Address opcode OpCode - err error + Err error } func NewEVMErr(contract *Contract, op OpCode, err error) *EVMError { @@ -85,10 +85,10 @@ func NewEVMErr(contract *Contract, op OpCode, err error) *EVMError { from: contract.Caller(), to: contract.Address(), opcode: op, - err: err, + Err: err, } } func (e *EVMError) Error() string { - return fmt.Sprintf("EVM err, from %v to %v at %v, got: %v", e.from, e.to, e.opcode, e.err) + return fmt.Sprintf("EVM err, from %v to %v at %v, got: %v", e.from, e.to, e.opcode, e.Err) } diff --git a/core/vm/evm.go b/core/vm/evm.go index ef98952562..3408ba40ab 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -136,8 +136,8 @@ type EVM 
struct { // applied in opCall*. callGasTemp uint64 - // errorCollection all op code and err list will collect in here - errorCollection []*EVMError + // ErrorCollection all op code and err list will collect in here + ErrorCollection []*EVMError } // NewEVM returns a new EVM. The returned EVM is not thread safe and should @@ -155,7 +155,7 @@ func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig evm.depth = 0 evm.interpreter = NewEVMInterpreter(evm, config) - evm.errorCollection = []*EVMError{} + evm.ErrorCollection = []*EVMError{} return evm } @@ -537,9 +537,9 @@ func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment * func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig } func (evm *EVM) AppendErr(err *EVMError) { - evm.errorCollection = append(evm.errorCollection, err) + evm.ErrorCollection = append(evm.ErrorCollection, err) } func (evm *EVM) Errors() []*EVMError { - return evm.errorCollection + return evm.ErrorCollection } From 833a24597ddc25dd86a78b5a17dd62818ac20606 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Thu, 27 Apr 2023 22:17:27 +0800 Subject: [PATCH 31/51] shadownode/difflayer: support shadow node diff layer & reorg; --- core/rawdb/accessors_snapshot.go | 17 ++ core/rawdb/schema.go | 3 + trie/shadow_node_difflayer.go | 350 +++++++++++++++++++++++++++++++ 3 files changed, 370 insertions(+) create mode 100644 trie/shadow_node_difflayer.go diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go index 1c828662c1..220b41ca6f 100644 --- a/core/rawdb/accessors_snapshot.go +++ b/core/rawdb/accessors_snapshot.go @@ -141,6 +141,23 @@ func DeleteSnapshotJournal(db ethdb.KeyValueWriter) { } } +func ReadShadowNodeSnapshotJournal(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(shadowNodeSnapshotJournalKey) + return data +} + +func WriteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) { + if err := 
db.Put(shadowNodeSnapshotJournalKey, journal); err != nil { + log.Crit("Failed to store snapshot journal", "err", err) + } +} + +func DeleteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter) { + if err := db.Delete(shadowNodeSnapshotJournalKey); err != nil { + log.Crit("Failed to remove snapshot journal", "err", err) + } +} + // ReadSnapshotGenerator retrieves the serialized snapshot generator saved at // the last shutdown. func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte { diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index e04f94e7d4..d1e3aa76d5 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -63,6 +63,9 @@ var ( // snapshotSyncStatusKey tracks the snapshot sync status across restarts. snapshotSyncStatusKey = []byte("SnapshotSyncStatus") + // shadowNodeSnapshotJournalKey tracks the in-memory diff layers across restarts. + shadowNodeSnapshotJournalKey = []byte("ShadowNodeSnapshotJournalKey") + // txIndexTailKey tracks the oldest block whose transactions have been indexed. 
txIndexTailKey = []byte("TransactionIndexTail") diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go new file mode 100644 index 0000000000..5271b1ef58 --- /dev/null +++ b/trie/shadow_node_difflayer.go @@ -0,0 +1,350 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" + "io" + "math/big" + "sync" +) + +const ( + // MaxShadowNodeDiffDepth default is 128 layers + MaxShadowNodeDiffDepth = 128 + journalVersion = 1 +) + +// shadowNodeSnapshot record diff layer and disk layer of shadow nodes, support mini reorg +type shadowNodeSnapshot interface { + // Root block state root + Root() common.Hash + + // ShadowNode query shadow node from db, got RLP format + ShadowNode(addrHash common.Hash, path string) ([]byte, error) + + // Parent parent snap + Parent() shadowNodeSnapshot + + // Update create a new diff layer from here + Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) + + // Journal commit self as a journal to buffer + Journal(buffer *bytes.Buffer) (common.Hash, error) +} + +// shadowNodeSnapTree maintain all diff layers support reorg, will flush to db when MaxShadowNodeDiffDepth reach +// every layer response to a block state change set, there no flatten layers operation. 
+type shadowNodeSnapTree struct { + diskdb ethdb.KeyValueStore + + // diff layers + layers map[common.Hash]shadowNodeSnapshot + children map[common.Hash][]common.Hash + + // disk layer + // TODO(0xbundler): add disk layer for history & changeSet + diskCache map[common.Hash]map[string][]byte + + lock sync.RWMutex +} + +func newShadowNodeDiffTree(diskdb ethdb.KeyValueStore) (*shadowNodeSnapTree, error) { + diskLayer, err := loadDiskLayer(diskdb) + if err != nil { + return nil, err + } + layers, children, err := loadDiffLayers(diskdb, diskLayer) + if err != nil { + return nil, err + } + layers[diskLayer.blockRoot] = diskLayer + // check if continuously after disk layer + if len(children[diskLayer.blockRoot]) == 0 { + return nil, errors.New("cannot found any diff layers link to disk layer") + } + return &shadowNodeSnapTree{ + diskdb: diskdb, + layers: layers, + children: children, + diskCache: make(map[common.Hash]map[string][]byte), + }, nil +} + +// Cap TODO(0xbundler): keep tree depth not greater MaxShadowNodeDiffDepth, all forks parent to disk layer will delete +// TODO(0xbundler): store disk layer meta(blockNumber, blockRoot) too +func (s *shadowNodeSnapTree) Cap(blockRoot common.Hash) error { + return nil +} + +func (s *shadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error { + if blockRoot == parentRoot { + return errors.New("snapshot cycle") + } + // Generate a new snapshot on top of the parent + parent := s.Snapshot(parentRoot) + if parent == nil { + return fmt.Errorf("parent snapshot missing") + } + snap, err := parent.Update(blockNumber, blockRoot, nodeSet) + if err != nil { + return err + } + + s.lock.Lock() + defer s.lock.Unlock() + + s.layers[blockRoot] = snap + s.children[parentRoot] = append(s.children[parentRoot], blockRoot) + return nil +} + +func (s *shadowNodeSnapTree) Snapshot(blockRoot common.Hash) shadowNodeSnapshot { + s.lock.RLock() + defer 
s.lock.RUnlock() + return s.layers[blockRoot] +} + +func (s *shadowNodeSnapTree) Journal() error { + s.lock.Lock() + defer s.lock.Unlock() + + // Firstly write out the metadata of journal + journal := new(bytes.Buffer) + if err := rlp.Encode(journal, journalVersion); err != nil { + return err + } + for _, snap := range s.layers { + if _, err := snap.Journal(journal); err != nil { + return err + } + } + rawdb.WriteShadowNodeSnapshotJournal(s.diskdb, journal.Bytes()) + return nil +} + +func loadDiskLayer(db ethdb.KeyValueStore) (*shadowNodeDiskLayer, error) { + // TODO(0xbundler): load disk layer meta(blockNumber, blockRoot) + return newShadowNodeDiskLayer(db, common.Big0, common.Hash{}) +} + +func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map[common.Hash]shadowNodeSnapshot, map[common.Hash][]common.Hash, error) { + journal := rawdb.ReadShadowNodeSnapshotJournal(db) + if len(journal) == 0 { + return nil, nil, errors.New("journal is empty") + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + // Firstly, resolve the first element as the journal version + version, err := r.Uint64() + if err != nil { + return nil, nil, errors.New("failed to resolve journal version") + } + if version != journalVersion { + return nil, nil, errors.New("wrong journal version") + } + + layers := make(map[common.Hash]shadowNodeSnapshot) + parents := make(map[common.Hash]common.Hash) + for { + var ( + parent common.Hash + number big.Int + root common.Hash + js []journalShadowNode + ) + // Read the next diff journal entry + if err := r.Decode(&number); err != nil { + // The first read may fail with EOF, marking the end of the journal + if errors.Is(err, io.EOF) { + break + } + return nil, nil, fmt.Errorf("load diff parent: %v", err) + } + if err := r.Decode(&parent); err != nil { + return nil, nil, fmt.Errorf("load diff parent: %v", err) + } + // Read the next diff journal entry + if err := r.Decode(&root); err != nil { + return nil, nil, fmt.Errorf("load diff 
root: %v", err) + } + if err := r.Decode(&js); err != nil { + return nil, nil, fmt.Errorf("load diff storage: %v", err) + } + + nodeSet := make(map[common.Hash]map[string][]byte) + for _, entry := range js { + nodes := make(map[string][]byte) + for i, key := range entry.Keys { + if len(entry.Vals[i]) > 0 { // RLP loses nil-ness, but `[]byte{}` is not a valid item, so reinterpret that + nodes[key] = entry.Vals[i] + } else { + nodes[key] = nil + } + } + nodeSet[entry.Hash] = nodes + } + + parents[root] = parent + layers[root] = newShadowNodeDiffLayer(&number, root, nil, nodeSet) + } + + children := make(map[common.Hash][]common.Hash) + for t, s := range layers { + parent := parents[t] + children[parent] = append(children[parent], t) + if p, ok := layers[parent]; ok { + s.(*shadowNodeDiffLayer).parent = p + } else if parent == diskLayer.Root() { + s.(*shadowNodeDiffLayer).parent = diskLayer + } else { + return nil, nil, errors.New("cannot found the snap's parent") + } + } + return layers, children, nil +} + +type shadowNodeDiffLayer struct { + blockNumber *big.Int + blockRoot common.Hash + parent shadowNodeSnapshot + nodeSet map[common.Hash]map[string][]byte + + lock sync.RWMutex +} + +func newShadowNodeDiffLayer(blockNumber *big.Int, blockRoot common.Hash, parent shadowNodeSnapshot, nodeSet map[common.Hash]map[string][]byte) *shadowNodeDiffLayer { + return &shadowNodeDiffLayer{ + blockNumber: blockNumber, + blockRoot: blockRoot, + parent: parent, + nodeSet: nodeSet, + } +} + +func (s *shadowNodeDiffLayer) Root() common.Hash { + s.lock.RLock() + defer s.lock.RUnlock() + return s.blockRoot +} + +func (s *shadowNodeDiffLayer) ShadowNode(addrHash common.Hash, prefix string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + cm, exist := s.nodeSet[addrHash] + if !exist { + return nil, errors.New("cannot find the address") + } + + ret, exist := cm[prefix] + if exist { + return ret, nil + } + return s.parent.ShadowNode(addrHash, prefix) +} + +func (s 
*shadowNodeDiffLayer) Parent() shadowNodeSnapshot { + s.lock.RLock() + defer s.lock.RUnlock() + return s.parent +} + +func (s *shadowNodeDiffLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { + s.lock.RLock() + if s.blockNumber.Cmp(blockNumber) < 0 { + return nil, errors.New("update a unordered diff layer") + } + s.lock.RUnlock() + return newShadowNodeDiffLayer(blockNumber, blockRoot, s, nodeSet), nil +} + +func (s *shadowNodeDiffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + storage := make([]journalShadowNode, 0, len(s.nodeSet)) + for hash, nodes := range s.nodeSet { + keys := make([]string, 0, len(nodes)) + vals := make([][]byte, 0, len(nodes)) + for key, val := range nodes { + keys = append(keys, key) + vals = append(vals, val) + } + storage = append(storage, journalShadowNode{Hash: hash, Keys: keys, Vals: vals}) + } + if err := rlp.Encode(buffer, storage); err != nil { + return common.Hash{}, err + } + return s.blockRoot, nil +} + +type journalShadowNode struct { + Hash common.Hash + Keys []string + Vals [][]byte +} + +type shadowNodeDiskLayer struct { + // TODO(0xbundler): add history & changeSet later + diskdb ethdb.KeyValueReader + blockNumber *big.Int + blockRoot common.Hash + cache map[common.Hash]map[string][]byte + + lock sync.RWMutex +} + +func newShadowNodeDiskLayer(diskdb ethdb.KeyValueReader, blockNumber *big.Int, blockRoot common.Hash) (*shadowNodeDiskLayer, error) { + return &shadowNodeDiskLayer{ + diskdb: diskdb, + blockNumber: blockNumber, + blockRoot: blockRoot, + cache: make(map[common.Hash]map[string][]byte), + }, nil +} + +func (s *shadowNodeDiskLayer) Root() common.Hash { + return s.blockRoot +} + +func (s *shadowNodeDiskLayer) ShadowNode(addrHash common.Hash, path string) ([]byte, error) { + nodeSet, exist := s.cache[addrHash] + if exist { + if enc, ok := nodeSet[path]; ok { + return enc, nil + } + } + 
//TODO(0xbundler): return history & changeSet later + s.cache[addrHash] = make(map[string][]byte) + n := shadowBranchNode{ + ShadowHash: common.Hash{}, + EpochMap: [16]types.StateEpoch{}, + } + enc, err := rlp.EncodeToBytes(n) + if err != nil { + return nil, err + } + s.cache[addrHash][path] = enc + return enc, nil +} + +func (s *shadowNodeDiskLayer) Parent() shadowNodeSnapshot { + return nil +} + +func (s *shadowNodeDiskLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { + s.lock.RLock() + if s.blockNumber.Cmp(blockNumber) < 0 { + return nil, errors.New("update a unordered diff layer") + } + s.lock.RUnlock() + return newShadowNodeDiffLayer(blockNumber, blockRoot, s, nodeSet), nil +} + +func (s *shadowNodeDiskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + return common.Hash{}, nil +} From 1b2cc78c2320a467f5e36fa4080aae066903d81f Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 28 Apr 2023 10:40:16 +0800 Subject: [PATCH 32/51] shadownode/diff: opt lock, add more UTs; blockchain: handle shadow node tree; statedb: using shadow node database, commit later; trie: support resolve shadow node, and r&w, add root node parse; --- accounts/abi/bind/backends/simulated.go | 7 +- cmd/evm/internal/t8ntool/execution.go | 5 +- cmd/evm/runner.go | 6 +- core/blockchain.go | 19 +- core/blockchain_reader.go | 9 +- core/chain_makers.go | 8 +- core/genesis.go | 2 +- core/state/database.go | 31 +++ core/state/iterator.go | 1 + core/state/state_object.go | 24 +-- core/state/statedb.go | 41 +++- core/state/trie_prefetcher.go | 2 + core/state_processor.go | 3 +- core/tx_pool.go | 4 +- eth/state_accessor.go | 8 +- les/server_requests.go | 1 + light/trie.go | 15 +- trie/database.go | 2 +- trie/dummy_trie.go | 8 + trie/node.go | 44 +++- trie/node_test.go | 34 +++ trie/secure_trie.go | 32 +++ trie/shadow_node.go | 182 +++++++++++----- 
trie/shadow_node_difflayer.go | 246 +++++++++++++++++----- trie/shadow_node_difflayer_test.go | 263 ++++++++++++++++++++++++ trie/shadow_node_test.go | 81 ++++++++ trie/shadownodes.go | 27 --- trie/trie.go | 112 +++++++++- trie/trie_test.go | 37 ++++ 29 files changed, 1076 insertions(+), 178 deletions(-) create mode 100644 trie/shadow_node_difflayer_test.go create mode 100644 trie/shadow_node_test.go delete mode 100644 trie/shadownodes.go diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index be09d6aed9..91a40206ff 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -132,7 +132,7 @@ func (b *SimulatedBackend) rollback(parent *types.Block) { b.pendingBlock = blocks[0] blockNum := new(big.Int).Add(parent.Number(), common.Big1) - b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), b.blockchain.StateCache(), nil, types.GetStateEpoch(b.config, blockNum)) + b.pendingState, _ = state.NewWithEpoch(b.config, blockNum, b.pendingBlock.Root(), b.blockchain.StateCache(), nil, b.blockchain.ShadowNodeTree()) } // Fork creates a side-chain that can be used to simulate reorgs. 
@@ -901,7 +901,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), stateDB.Database(), nil, types.GetStateEpoch(b.config, b.pendingBlock.Number())) + b.pendingState, _ = state.NewWithEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) b.pendingReceipts = receipts[0] return nil } @@ -1017,8 +1017,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.NewWithEpoch(b.pendingBlock.Root(), stateDB.Database(), nil, types.GetStateEpoch(b.config, b.pendingBlock.Number())) - + b.pendingState, _ = state.NewWithEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) return nil } diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 0e2fc0b255..4b7f2fbe6f 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -272,7 +272,8 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) *state.StateDB { sdb := state.NewDatabase(db) - statedb, _ := state.NewWithEpoch(common.Hash{}, sdb, nil, types.GetStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1))) + tree, _ := trie.NewShadowNodeSnapTree(db) + statedb, _ := state.NewWithEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1), common.Hash{}, sdb, nil, tree) for addr, a := range pre.Pre { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) @@ -285,7 +286,7 @@ func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) statedb.Finalise(false) statedb.AccountsIntermediateRoot() root, _, _ := statedb.Commit(nil) - 
statedb, _ = state.NewWithEpoch(root, sdb, nil, types.GetStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number))) + statedb, _ = state.NewWithEpoch(config, new(big.Int).SetUint64(pre.Env.Number), root, sdb, nil, tree) return statedb } diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index 4964e41a14..1baffb5db9 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -28,8 +28,6 @@ import ( "testing" "time" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/cmd/evm/internal/compiler" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" @@ -141,9 +139,9 @@ func runCmd(ctx *cli.Context) error { db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) chainConfig = gen.Config - statedb, _ = state.NewWithEpoch(genesis.Root(), state.NewDatabase(db), nil, types.GetStateEpoch(chainConfig, genesis.Number())) + statedb, _ = state.New(genesis.Root(), state.NewDatabase(db), nil) } else { - statedb, _ = state.NewWithEpoch(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil, types.StateEpoch0) + statedb, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) genesisConfig = new(core.Genesis) } if ctx.GlobalString(SenderFlag.Name) != "" { diff --git a/core/blockchain.go b/core/blockchain.go index 09a7922065..223ff4a4cc 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -192,6 +192,8 @@ type BlockChain struct { gcproc time.Duration // Accumulates canonical block processing for trie dumping commitLock sync.Mutex // CommitLock is used to protect above field from being modified concurrently + shadowNodeTree *trie.ShadowNodeSnapTree + // txLookupLimit is the maximum number of blocks from head whose tx indices // are reserved: // * 0: means no limit and regenerate any missing indexes @@ -362,9 +364,14 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par return nil, err } + // load shadow node tree to R&W + if bc.shadowNodeTree, err = 
trie.NewShadowNodeSnapTree(db); err != nil { + return nil, err + } + // Make sure the state associated with the block is available head := bc.CurrentBlock() - if _, err := state.NewWithEpoch(head.Root(), bc.stateCache, bc.snaps, types.GetStateEpoch(chainConfig, head.Number())); err != nil { + if _, err := state.NewWithEpoch(chainConfig, head.Number(), head.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { // Head state is missing, before the state recovery, find out the // disk layer point of snapshot(if it's enabled). Make sure the // rewound point is lower than disk layer. @@ -715,7 +722,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo enoughBeyondCount = beyondCount > maxBeyondBlocks - if _, err := state.NewWithEpoch(newHeadBlock.Root(), bc.stateCache, bc.snaps, types.GetStateEpoch(bc.chainConfig, header.Number)); err != nil { + if _, err := state.NewWithEpoch(bc.chainConfig, newHeadBlock.Number(), newHeadBlock.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) @@ -1074,6 +1081,11 @@ func (bc *BlockChain) Stop() { log.Error("Failed to journal state snapshot", "err", err) } } + if bc.shadowNodeTree != nil { + if err := bc.shadowNodeTree.Journal(); err != nil { + log.Error("Failed to journal shadow node snapshot", "err", err) + } + } // Ensure the state of a recent block is also stored to disk before exiting. 
// We're writing three different states to catch different restart scenarios: @@ -1882,8 +1894,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) if parent == nil { parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1) } - statedb, err := state.NewWithEpoch(parent.Root, bc.stateCache, bc.snaps, - types.GetStateEpoch(bc.chainConfig, block.Number())) + statedb, err := state.NewWithEpoch(bc.chainConfig, block.Number(), parent.Root, bc.stateCache, bc.snaps, bc.ShadowNodeTree()) if err != nil { return it.index, err } diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index 3426bebd77..70a5138e12 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -19,6 +19,8 @@ package core import ( "math/big" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core/rawdb" @@ -314,7 +316,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) { // StateAt returns a new mutable state based on a particular point in time. func (bc *BlockChain) StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) { - return state.NewWithEpoch(root, bc.stateCache, bc.snaps, types.GetStateEpoch(bc.chainConfig, number)) + return state.NewWithEpoch(bc.chainConfig, number, root, bc.stateCache, bc.snaps, bc.shadowNodeTree) } // Config retrieves the chain's fork configuration. @@ -328,6 +330,11 @@ func (bc *BlockChain) Snapshots() *snapshot.Tree { return bc.snaps } +// ShadowNodeTree returns the blockchain shadow node tree. +func (bc *BlockChain) ShadowNodeTree() *trie.ShadowNodeSnapTree { + return bc.shadowNodeTree +} + // Validator returns the current validator. 
func (bc *BlockChain) Validator() Validator { return bc.validator diff --git a/core/chain_makers.go b/core/chain_makers.go index e1e9aeec7c..24b103eb6e 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -20,6 +20,8 @@ import ( "fmt" "math/big" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/misc" @@ -276,7 +278,11 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse } for i := 0; i < n; i++ { number := new(big.Int).Add(parent.Number(), common.Big1) - statedb, err := state.NewWithEpoch(parent.Root(), state.NewDatabase(db), nil, types.GetStateEpoch(config, number)) + tree, err := trie.NewShadowNodeSnapTree(db) + if err != nil { + panic(err) + } + statedb, err := state.NewWithEpoch(config, number, parent.Root(), state.NewDatabase(db), nil, tree) if err != nil { panic(err) } diff --git a/core/genesis.go b/core/genesis.go index bc29e68757..75bd357cd4 100644 --- a/core/genesis.go +++ b/core/genesis.go @@ -182,7 +182,7 @@ func SetupGenesisBlockWithOverride(db ethdb.Database, genesis *Genesis, override // We have the genesis block in database(perhaps in ancient database) // but the corresponding state is missing. header := rawdb.ReadHeader(db, stored, 0) - if _, err := state.NewWithEpoch(header.Root, state.NewDatabaseWithConfigAndCache(db, nil), nil, types.StateEpoch0); err != nil { + if _, err := state.New(header.Root, state.NewDatabaseWithConfigAndCache(db, nil), nil); err != nil { if genesis == nil { genesis = DefaultGenesisBlock() } diff --git a/core/state/database.go b/core/state/database.go index a0bc01dedd..a90d03ff8f 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -55,6 +55,9 @@ type Database interface { // OpenStorageTrie opens the storage trie of an account. 
OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + // OpenStorageTrieWithShadowNode opens the storage trie of an account and allow rw shadow nodes. + OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (Trie, error) + // CopyTrie returns an independent copy of the given trie. CopyTrie(Trie) Trie @@ -93,6 +96,9 @@ type Trie interface { // trie.MissingNodeError is returned. TryGet(key []byte) ([]byte, error) + // TryUpdateEpoch just update key's epoch, only using in storage trie + TryUpdateEpoch(key []byte) error + // TryUpdateAccount abstract an account write in the trie. TryUpdateAccount(key []byte, account *types.StateAccount) error @@ -133,6 +139,8 @@ type Trie interface { ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error ReviveTrie(proof []*trie.MPTProofNub) []*trie.MPTProofNub + + Epoch() types.StateEpoch } // NewDatabase creates a backing store for state. The returned database is safe for @@ -246,6 +254,29 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { return tr, nil } +// OpenStorageTrieWithShadowNode opens the storage trie of an account. 
+func (db *cachingDB) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (Trie, error) { + if db.noTries { + return trie.NewEmptyTrie(), nil + } + if db.storageTrieCache != nil { + if tries, exist := db.storageTrieCache.Get(addrHash); exist { + triesPairs := tries.([3]*triePair) + for _, triePair := range triesPairs { + if triePair != nil && triePair.root == root && triePair.trie.Epoch() == curEpoch { + return triePair.trie.(*trie.SecureTrie).Copy(), nil + } + } + } + } + + tr, err := trie.NewStorageSecure(curEpoch, root, db.db, sndb) + if err != nil { + return nil, err + } + return tr, nil +} + func (db *cachingDB) CacheAccount(root common.Hash, t Trie) { if db.accountTrieCache == nil { return diff --git a/core/state/iterator.go b/core/state/iterator.go index 611df52431..bbb4329dab 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -109,6 +109,7 @@ func (it *NodeIterator) step() error { if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } + // TODO(0xbundler): fix iterator later dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root) if err != nil { return err diff --git a/core/state/state_object.go b/core/state/state_object.go index 26762b31f8..7cda8c0d3e 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -114,8 +114,8 @@ type StateObject struct { deleted bool //encode - encodeData []byte - Epoch types.StateEpoch + encodeData []byte + targetEpoch types.StateEpoch } // empty returns whether the account is considered empty. 
@@ -153,7 +153,7 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *St pendingReviveState: make(map[string]common.Hash), dirtyAccessedState: make(map[common.Hash]int), pendingAccessedState: make(map[common.Hash]int), - Epoch: db.Epoch, + targetEpoch: db.targetEpoch, } } @@ -196,9 +196,9 @@ func (s *StateObject) getTrie(db Database) Trie { } if s.trie == nil { var err error - s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + s.trie, err = db.OpenStorageTrieWithShadowNode(s.addrHash, s.data.Root, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) if err != nil { - s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.trie, _ = db.OpenStorageTrieWithShadowNode(s.addrHash, common.Hash{}, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } @@ -303,7 +303,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha } if err == nil { if snapVal, err := snapshot.ParseSnapValFromBytes(enc); err == nil { - if types.EpochExpired(snapVal.Epoch, s.Epoch) { + if types.EpochExpired(snapVal.Epoch, s.targetEpoch) { return common.Hash{}, NewPlainExpiredStateError(s.address, key, snapVal.Epoch) } return snapVal.Val, nil @@ -491,11 +491,11 @@ func (s *StateObject) updateTrie(db Database) Trie { } usedStorage = append(usedStorage, common.CopyBytes(key[:])) } - // TODO(0xbundler): call TryUpdateEpoch later - //for key := range accessStorage { - // s.setError(tr.TryUpdateEpoch(key[:])) - // usedStorage = append(usedStorage, common.CopyBytes(key[:])) - //} + // refresh accessed slots' epoch + for key := range accessStorage { + s.setError(tr.TryUpdateEpoch(key[:])) + usedStorage = append(usedStorage, common.CopyBytes(key[:])) + } }() if s.db.snap != nil { // If state snapshotting is active, cache the data til commit @@ -511,7 +511,7 @@ func (s *StateObject) updateTrie(db Database) Trie { } s.db.snapStorageMux.Unlock() for key, value := range dirtyStorage 
{ - enc, err := snapshot.NewSnapValBytes(s.Epoch, value) + enc, err := snapshot.NewSnapValBytes(s.targetEpoch, value) if err != nil { s.setError(err) } diff --git a/core/state/statedb.go b/core/state/statedb.go index b9722d0c27..cf1c81921e 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -26,6 +26,8 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/gopool" "github.com/ethereum/go-ethereum/core/rawdb" @@ -87,6 +89,7 @@ type StateDB struct { fullProcessed bool pipeCommit bool + shadowNodeRW *trie.ShadowNodeStorageRW snaps *snapshot.Tree snap snapshot.Snapshot snapAccountMux sync.Mutex // Mutex for snap account access @@ -128,8 +131,8 @@ type StateDB struct { validRevisions []revision nextRevisionId int - // state epoch - Epoch types.StateEpoch + targetEpoch types.StateEpoch + targetBlk *big.Int // Measurements gathered during execution for debugging purposes MetricsMux sync.Mutex @@ -152,8 +155,20 @@ type StateDB struct { } // NewWithEpoch creates a new state from a given trie. 
-func NewWithEpoch(root common.Hash, db Database, snaps *snapshot.Tree, epoch types.StateEpoch) (*StateDB, error) { - return newStateDB(root, db, snaps, epoch) +func NewWithEpoch(config *params.ChainConfig, targetBlock *big.Int, root common.Hash, db Database, snaps *snapshot.Tree, sntree *trie.ShadowNodeSnapTree) (*StateDB, error) { + targetEpoch := types.GetStateEpoch(config, targetBlock) + stateDB, err := newStateDB(root, db, snaps, targetEpoch) + if err != nil { + return nil, err + } + + // init target block and shadowNodeRW + stateDB.targetBlk = targetBlock + stateDB.shadowNodeRW, err = trie.NewShadowNodeStorageRW(sntree, root) + if err != nil { + return nil, err + } + return stateDB, nil } // New creates a new state from a given trie, it inits at Epoch0 @@ -173,7 +188,7 @@ func NewWithSharedPool(root common.Hash, db Database, snaps *snapshot.Tree) (*St return statedb, nil } -func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree, epoch types.StateEpoch) (*StateDB, error) { +func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree, targetEpoch types.StateEpoch) (*StateDB, error) { sdb := &StateDB{ db: db, originalRoot: root, @@ -185,7 +200,7 @@ func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree, epoch types preimages: make(map[common.Hash][]byte), journal: newJournal(), hasher: crypto.NewKeccakState(), - Epoch: epoch, + targetEpoch: targetEpoch, } if sdb.snaps != nil { @@ -1388,6 +1403,10 @@ func (s *StateDB) LightCommit() (common.Hash, *types.DiffLayer, error) { // Commit writes the state to the underlying in-memory trie database. 
func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() error) (common.Hash, *types.DiffLayer, error) { + if s.targetEpoch > 0 && s.shadowNodeRW == nil { + return common.Hash{}, nil, errors.New("cannot commit shadow node") + } + if s.dbErr != nil { s.StopPrefetcher() return common.Hash{}, nil, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) @@ -1625,6 +1644,12 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } + if s.shadowNodeRW != nil && s.originalRoot != root { + if err := s.shadowNodeRW.Commit(s.targetBlk, root); err != nil { + return common.Hash{}, nil, err + } + } + return root, diffLayer, nil } @@ -1806,3 +1831,7 @@ func (s *StateDB) ReviveStorageTrie(witnessList types.WitnessList) error { return nil } + +func (s *StateDB) openShadowStorage(addr common.Hash) trie.ShadowNodeStorage { + return trie.NewShadowNodeStorage4Trie(addr, s.shadowNodeRW) +} diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index cd51820e9e..6350b6b54d 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -473,6 +473,7 @@ func (sf *subfetcher) loop() { trie, err = sf.db.OpenTrie(sf.root) } else { // address is useless + // TODO(0xbundler): fix fetcher later trie, err = sf.db.OpenStorageTrie(sf.accountHash, sf.root) } if err != nil { @@ -491,6 +492,7 @@ func (sf *subfetcher) loop() { sf.trie, err = sf.db.OpenTrie(sf.root) } else { // address is useless + // TODO(0xbundler): fix fetcher later sf.trie, err = sf.db.OpenStorageTrie(sf.accountHash, sf.root) } if err != nil { diff --git a/core/state_processor.go b/core/state_processor.go index 7b926b742b..869f617bcc 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -122,7 +122,7 @@ func (p *LightStateProcessor) Process(block *types.Block, statedb *state.StateDB // prepare new statedb statedb.StopPrefetcher() parent := p.bc.GetHeader(block.ParentHash(), block.NumberU64()-1) - 
statedb, err = state.NewWithEpoch(parent.Root, p.bc.stateCache, p.bc.snaps, types.GetStateEpoch(p.config, block.Number())) + statedb, err = state.NewWithEpoch(p.config, block.Number(), parent.Root, p.bc.stateCache, p.bc.snaps, p.bc.shadowNodeTree) if err != nil { return statedb, nil, nil, 0, err } @@ -261,6 +261,7 @@ func (p *LightStateProcessor) LightProcess(diffLayer *types.DiffLayer, block *ty //update storage latestRoot := common.BytesToHash(latestAccount.Root) if latestRoot != previousAccount.Root { + // TODO(0xbundler): fix light process later accountTrie, err := statedb.Database().OpenStorageTrie(addrHash, previousAccount.Root) if err != nil { errChan <- err diff --git a/core/tx_pool.go b/core/tx_pool.go index c6d7383ba6..896eb9ad41 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -1428,7 +1428,8 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { if newHead == nil { newHead = pool.chain.CurrentBlock().Header() // Special case during testing } - statedb, err := pool.chain.StateAt(newHead.Root, newHead.Number) + next := new(big.Int).Add(newHead.Number, big.NewInt(1)) + statedb, err := pool.chain.StateAt(newHead.Root, next) if err != nil { log.Error("Failed to reset txpool state", "err", err) return @@ -1443,7 +1444,6 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { pool.addTxsLocked(reinject, false) // Update all fork indicator by next pending block number. - next := new(big.Int).Add(newHead.Number, big.NewInt(1)) pool.istanbul = pool.chainconfig.IsIstanbul(next) pool.eip2718 = pool.chainconfig.IsBerlin(next) pool.eip1559 = pool.chainconfig.IsLondon(next) diff --git a/eth/state_accessor.go b/eth/state_accessor.go index d6ca529589..41459f2eeb 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -66,7 +66,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // Create an ephemeral trie.Database for isolating the live one. 
Otherwise // the internal junks created by tracing will be persisted into the disk. database = state.NewDatabaseWithConfig(eth.chainDb, &trie.Config{Cache: 16}) - if statedb, err = state.NewWithEpoch(block.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), block.Number())); err == nil { + if statedb, err = state.NewWithEpoch(eth.blockchain.Config(), block.Number(), block.Root(), database, nil, eth.blockchain.ShadowNodeTree()); err == nil { log.Info("Found disk backend for state trie", "root", block.Root(), "number", block.Number()) return statedb, nil } @@ -89,7 +89,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // we would rewind past a persisted block (specific corner case is chain // tracing from the genesis). if !checkLive { - statedb, err = state.NewWithEpoch(current.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) + statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { return statedb, nil } @@ -105,7 +105,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state } current = parent - statedb, err = state.NewWithEpoch(current.Root(), database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) + statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { break } @@ -148,7 +148,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state return nil, fmt.Errorf("stateAtBlock commit failed, number %d root %v: %w", current.NumberU64(), current.Root().Hex(), err) } - statedb, err = state.NewWithEpoch(root, database, nil, types.GetStateEpoch(eth.blockchain.Config(), current.Number())) + statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), root, database, nil, 
eth.blockchain.ShadowNodeTree()) if err != nil { return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) } diff --git a/les/server_requests.go b/les/server_requests.go index 7564420ce6..4a5456ba86 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -428,6 +428,7 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { p.bumpInvalid() continue } + // TODO(0xbundler): fix les later trie, err = statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) if trie == nil || err != nil { p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) diff --git a/light/trie.go b/light/trie.go index 56f0451a9b..034e49fabc 100644 --- a/light/trie.go +++ b/light/trie.go @@ -38,7 +38,8 @@ var ( ) func NewState(ctx context.Context, config *params.ChainConfig, head *types.Header, odr OdrBackend) *state.StateDB { - state, _ := state.NewWithEpoch(head.Root, NewStateDatabase(ctx, head, odr), nil, types.GetStateEpoch(config, head.Number)) + tree, _ := trie.NewShadowNodeSnapTree(odr.Database()) + state, _ := state.NewWithEpoch(config, head.Number, head.Root, NewStateDatabase(ctx, head, odr), nil, tree) return state } @@ -64,6 +65,10 @@ func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil } +func (db *odrDatabase) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (state.Trie, error) { + return db.OpenStorageTrie(addrHash, root) +} + func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { switch t := t.(type) { case *odrTrie: @@ -146,6 +151,10 @@ func (t *odrTrie) TryUpdate(key, value []byte) error { }) } +func (t *odrTrie) TryUpdateEpoch(key []byte) error { + return nil +} + func (t *odrTrie) TryDelete(key []byte) error { 
key = crypto.Keccak256(key) return t.do(key, func() error { @@ -179,6 +188,10 @@ func (t *odrTrie) HashKey(key []byte) []byte { return nil } +func (t *odrTrie) Epoch() types.StateEpoch { + return types.StateEpoch0 +} + func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { return errors.New("not implemented, needs client/server interface split") } diff --git a/trie/database.go b/trie/database.go index a4838425c5..c4a3b87aec 100644 --- a/trie/database.go +++ b/trie/database.go @@ -432,7 +432,7 @@ func (db *Database) node(hash common.Hash) node { return mustDecodeNodeUnsafe(hash[:], enc) } -func (db *Database) RootNode(hash common.Hash) *RootNode { +func (db *Database) RootNode(hash common.Hash) *rootNode { return nil } diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 9b8a29f888..4d9abc410a 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -49,6 +49,10 @@ func (t *EmptyTrie) TryGet(key []byte) ([]byte, error) { return nil, nil } +func (t *EmptyTrie) TryUpdateEpoch(key []byte) error { + return nil +} + func (t *EmptyTrie) TryGetNode(path []byte) ([]byte, int, error) { return nil, 0, nil } @@ -74,6 +78,10 @@ func (t *EmptyTrie) GetKey(shaKey []byte) []byte { return nil } +func (t *EmptyTrie) Epoch() types.StateEpoch { + return types.StateEpoch0 +} + func (t *EmptyTrie) HashKey(key []byte) []byte { return nil } diff --git a/trie/node.go b/trie/node.go index 1db3d4310f..a18c3b8dea 100644 --- a/trie/node.go +++ b/trie/node.go @@ -38,6 +38,7 @@ const ( rawNodeType rawShortNodeType rawFullNodeType + rootNodeType ) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} @@ -69,10 +70,35 @@ type ( valueNode []byte ) -type RootNode struct { +type rootNode struct { Epoch types.StateEpoch TrieHash common.Hash ShadowHash common.Hash + flags nodeFlag `rlp:"-" json:"-"` +} + +func (n *rootNode) cache() (hashNode, bool) { + return n.flags.hash, n.flags.dirty +} + +func (n 
*rootNode) encode(w rlp.EncoderBuffer) { + rlp.Encode(w, n) +} + +func (n *rootNode) fstring(s string) string { + return fmt.Sprintf("{%v: %x: %x} ", n.Epoch, n.TrieHash, n.ShadowHash) +} + +func (n *rootNode) nodeType() int { + return rootNodeType +} + +func (n *rootNode) setEpoch(epoch types.StateEpoch) { + n.Epoch = epoch +} + +func (n *rootNode) getEpoch() types.StateEpoch { + return n.Epoch } // nilValueNode is used when collapsing internal trie nodes for hashing, since @@ -87,8 +113,7 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { } func (n *fullNode) GetShadowNode() *shadowBranchNode { - // TODO:get shadow node from cache or disk if shadow node is nil - return &shadowBranchNode{} + return &n.shadowNode } func (n *fullNode) GetChildEpoch(index int) types.StateEpoch { @@ -117,7 +142,7 @@ func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.Sta } func (n *shortNode) GetShadowNode() *shadowExtensionNode { - return &shadowExtensionNode{} + return &n.shadowNode } func (n *fullNode) copy() *fullNode { copy := *n; return &copy } @@ -230,6 +255,9 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { case 2: n, err := decodeShort(hash, elems) return n, wrapError(err, "short") + case 3: + n, err := decodeRootNode(hash, buf) + return n, wrapError(err, "root node") case 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") @@ -263,6 +291,14 @@ func decodeShort(hash, elems []byte) (node, error) { return n, nil } +func decodeRootNode(hash, elems []byte) (node, error) { + n := &rootNode{flags: nodeFlag{hash: hash}} + if err := rlp.DecodeBytes(elems, n); err != nil { + return nil, err + } + return n, nil +} + func decodeFull(hash, elems []byte) (*fullNode, error) { n := &fullNode{flags: nodeFlag{hash: hash}} for i := 0; i < 16; i++ { diff --git a/trie/node_test.go b/trie/node_test.go index d52b0cee24..fa308e97f1 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -20,6 +20,9 @@ import ( "bytes" "testing" + 
"github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) @@ -94,6 +97,37 @@ func TestDecodeFullNode(t *testing.T) { } } +func TestRootNodeEncodeDecode(t *testing.T) { + datas := []struct { + r *rootNode + }{ + { + r: &rootNode{ + Epoch: 100, + TrieHash: makeHash("t1"), + ShadowHash: makeHash("s1"), + flags: nodeFlag{hash: makeHash("h1").Bytes()}, + }, + }, + { + r: &rootNode{ + Epoch: 0, + TrieHash: common.Hash{}, + ShadowHash: common.Hash{}, + flags: nodeFlag{hash: makeHash("h1").Bytes()}, + }, + }, + } + + for _, item := range datas { + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.r.encode(buf) + p, err := decodeNode(makeHash("h1").Bytes(), buf.ToBytes()) + assert.NoError(t, err) + assert.Equal(t, item.r, p) + } +} + // goos: darwin // goarch: arm64 // pkg: github.com/ethereum/go-ethereum/trie diff --git a/trie/secure_trie.go b/trie/secure_trie.go index a2acc0c17a..a37ba4129b 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -79,6 +79,29 @@ func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, return &SecureTrie{trie: *trie}, nil } +func NewStorageSecure(curEpoch types.StateEpoch, root common.Hash, db *Database, sndb ShadowNodeStorage) (*SecureTrie, error) { + if db == nil || sndb == nil { + panic("trie.NewSecure called without a database") + } + + rn := rootNode{ + Epoch: types.StateEpoch0, + TrieHash: root, + } + hash := common.BytesToHash(root[:]) + if n := db.node(hash); n != nil { + if tmp, ok := n.(*rootNode); ok { + rn = *tmp + } + } + + trie, err := NewWithShadowNode(curEpoch, &rn, db, sndb) + if err != nil { + return nil, err + } + return &SecureTrie{trie: *trie}, nil +} + // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. 
func (t *SecureTrie) Get(key []byte) []byte { @@ -96,6 +119,11 @@ func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { return t.trie.TryGet(t.HashKey(key)) } +func (t *SecureTrie) TryUpdateEpoch(key []byte) error { + _, err := t.trie.TryGetAndUpdateEpoch(t.HashKey(key)) + return err +} + // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // possible to use keybyte-encoding as the path might contain odd nibbles. func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) { @@ -171,6 +199,10 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { return t.trie.db.preimage(common.BytesToHash(shaKey)) } +func (t *SecureTrie) Epoch() types.StateEpoch { + return t.trie.currentEpoch +} + // Commit writes all nodes and the secure hash pre-images to the trie's database. // Nodes are stored with their sha3 hash as the key. // diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 6e2f07fcb7..73bf6f1265 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -1,84 +1,162 @@ package trie import ( + "bytes" + "errors" + "math/big" + "sync" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/rlp" ) +type shadowExtensionNode struct { + ShadowHash common.Hash + Epoch types.StateEpoch +} + +func NewShadowExtensionNode(hash common.Hash, epoch types.StateEpoch) shadowExtensionNode { + return shadowExtensionNode{ + ShadowHash: hash, + Epoch: epoch, + } +} + +type shadowBranchNode struct { + ShadowHash common.Hash + EpochMap [16]types.StateEpoch +} + +func NewShadowBranchNode(hash common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { + return shadowBranchNode{hash, epochMap} +} + type ShadowNodeStorage interface { // Get key is the shadow node prefix path - Get(key []byte) ([]byte, error) - Put(key []byte, val []byte) error - Commit(root common.Hash) error + Get(path string) ([]byte, error) + Put(path string, val 
[]byte) error + Delete(path string) error } -type ShadowNodeManager struct { - diskdb ethdb.KeyValueStore - // TODO diff layers - // TODO history states +type shadowNodeStorage4Trie struct { + addr common.Hash + rw *ShadowNodeStorageRW } -// NewShadowNodeManager TODO need reload diff layers and rebuild history metadata -func NewShadowNodeManager(diskdb ethdb.KeyValueStore) *ShadowNodeManager { - return &ShadowNodeManager{ - diskdb: diskdb, +func NewShadowNodeStorage4Trie(addr common.Hash, rw *ShadowNodeStorageRW) ShadowNodeStorage { + return &shadowNodeStorage4Trie{ + addr: addr, + rw: rw, } } -//// OpenStorage parentRoot is block root? or contract root ? later save block history? -//func (s *ShadowNodeManager) OpenStorage(parentRoot, addrHash common.Hash) ShadowNodeStorage { -// // TODO allow RW append on diff layer, only read plainState -// return &shadowNodeStorageReaderWriterMock{ -// s: s, -// parentRoot: parentRoot, -// addrHash: addrHash, -// } -//} +func (s *shadowNodeStorage4Trie) Get(path string) ([]byte, error) { + return s.rw.Get(s.addr, path) +} -//func (s *ShadowNodeManager) OpenHistoryStorage(blockAt uint64, addrHash common.Hash) ShadowNodeStorage { -// // TODO only allow read when access history -//} +func (s *shadowNodeStorage4Trie) Put(path string, val []byte) error { + return s.rw.Put(s.addr, path, val) +} -type shadowNodeStorageReaderWriterMock struct { - mockEpoch types.StateEpoch - nodeMap map[string][]byte +func (s *shadowNodeStorage4Trie) Delete(path string) error { + return s.rw.Delete(s.addr, path) } -func newShadowNodeStorageMock(epoch types.StateEpoch) ShadowNodeStorage { - return &shadowNodeStorageReaderWriterMock{ - mockEpoch: epoch, - nodeMap: make(map[string][]byte), - } +type ShadowNodeStorageRW struct { + snap shadowNodeSnapshot + tree *ShadowNodeSnapTree + dirties map[common.Hash]map[string][]byte + + stale bool + lock sync.RWMutex } -func (s *shadowNodeStorageReaderWriterMock) Get(key []byte) ([]byte, error) { - var err error - 
tmp := string(key) - val, ok := s.nodeMap[tmp] - if !ok { - n := shadowBranchNode{ - ShadowHash: common.Hash{}, - EpochMap: [16]types.StateEpoch{}, +func NewShadowNodeStorageRW(tree *ShadowNodeSnapTree, blockRoot common.Hash) (*ShadowNodeStorageRW, error) { + snap := tree.Snapshot(blockRoot) + if snap == nil { + // try using default snap + if snap = tree.Snapshot(emptyRoot); snap == nil { + return nil, errors.New("cannot find the snap") } - for i := range n.EpochMap { - n.EpochMap[i] = s.mockEpoch - } - val, err = rlp.EncodeToBytes(n) - if err != nil { - return nil, err + } + return &ShadowNodeStorageRW{ + snap: snap, + tree: tree, + dirties: make(map[common.Hash]map[string][]byte), + }, nil +} + +func (s *ShadowNodeStorageRW) Get(addr common.Hash, path string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + if s.stale { + return nil, errors.New("storage has staled") + } + sub, exist := s.dirties[addr] + if exist { + if val, ok := sub[path]; ok { + return val, nil } - s.nodeMap[tmp] = val } - return val, nil + + return s.snap.ShadowNode(addr, path) } -func (s *shadowNodeStorageReaderWriterMock) Put(key []byte, val []byte) error { - s.nodeMap[string(key)] = val +func (s *ShadowNodeStorageRW) Delete(addr common.Hash, path string) error { + s.lock.RLock() + defer s.lock.RUnlock() + if s.stale { + return errors.New("storage has staled") + } + _, ok := s.dirties[addr] + if !ok { + s.dirties[addr] = make(map[string][]byte) + } + + s.dirties[addr][path] = nil return nil } -func (s *shadowNodeStorageReaderWriterMock) Commit(root common.Hash) error { +func (s *ShadowNodeStorageRW) Put(addr common.Hash, path string, val []byte) error { + prev, err := s.Get(addr, path) + if err != nil { + return err + } + if bytes.Equal(prev, val) { + return nil + } + + s.lock.RLock() + defer s.lock.RUnlock() + if s.stale { + return errors.New("storage has staled") + } + + _, ok := s.dirties[addr] + if !ok { + s.dirties[addr] = make(map[string][]byte) + } + s.dirties[addr][path] 
= val return nil } + +func (s *ShadowNodeStorageRW) OpenStorage(addr common.Hash) ShadowNodeStorage { + return NewShadowNodeStorage4Trie(addr, s) +} + +// Commit if you commit to an unknown parent, like deeper than 128 layers, will get error +func (s *ShadowNodeStorageRW) Commit(number *big.Int, blockRoot common.Hash) error { + s.lock.Lock() + defer s.lock.Unlock() + if s.stale { + return errors.New("storage has staled") + } + + s.stale = true + err := s.tree.Update(s.snap.Root(), number, blockRoot, s.dirties) + if err != nil { + return err + } + + return s.tree.Cap(blockRoot) +} diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go index 5271b1ef58..c7c0a5d1ad 100644 --- a/trie/shadow_node_difflayer.go +++ b/trie/shadow_node_difflayer.go @@ -4,20 +4,20 @@ import ( "bytes" "errors" "fmt" + "io" + "math/big" + "sync" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/rlp" - "io" - "math/big" - "sync" ) const ( // MaxShadowNodeDiffDepth default is 128 layers - MaxShadowNodeDiffDepth = 128 - journalVersion = 1 + MaxShadowNodeDiffDepth = 128 + journalVersion uint64 = 1 ) // shadowNodeSnapshot record diff layer and disk layer of shadow nodes, support mini reorg @@ -38,9 +38,9 @@ type shadowNodeSnapshot interface { Journal(buffer *bytes.Buffer) (common.Hash, error) } -// shadowNodeSnapTree maintain all diff layers support reorg, will flush to db when MaxShadowNodeDiffDepth reach +// ShadowNodeSnapTree maintain all diff layers support reorg, will flush to db when MaxShadowNodeDiffDepth reach // every layer response to a block state change set, there no flatten layers operation. 
-type shadowNodeSnapTree struct { +type ShadowNodeSnapTree struct { diskdb ethdb.KeyValueStore // diff layers @@ -54,21 +54,29 @@ type shadowNodeSnapTree struct { lock sync.RWMutex } -func newShadowNodeDiffTree(diskdb ethdb.KeyValueStore) (*shadowNodeSnapTree, error) { +func NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore) (*ShadowNodeSnapTree, error) { diskLayer, err := loadDiskLayer(diskdb) if err != nil { return nil, err } + // if there is no disk layer, will construct a fake disk layer + if diskLayer == nil { + diskLayer, err = newShadowNodeDiskLayer(diskdb, common.Big0, emptyRoot) + if err != nil { + return nil, err + } + } layers, children, err := loadDiffLayers(diskdb, diskLayer) if err != nil { return nil, err } + layers[diskLayer.blockRoot] = diskLayer // check if continuously after disk layer - if len(children[diskLayer.blockRoot]) == 0 { + if len(layers) > 1 && len(children[diskLayer.blockRoot]) == 0 { return nil, errors.New("cannot found any diff layers link to disk layer") } - return &shadowNodeSnapTree{ + return &ShadowNodeSnapTree{ diskdb: diskdb, layers: layers, children: children, @@ -76,20 +84,85 @@ func newShadowNodeDiffTree(diskdb ethdb.KeyValueStore) (*shadowNodeSnapTree, err }, nil } -// Cap TODO(0xbundler): keep tree depth not greater MaxShadowNodeDiffDepth, all forks parent to disk layer will delete +// Cap keep tree depth not greater MaxShadowNodeDiffDepth, all forks parent to disk layer will delete // TODO(0xbundler): store disk layer meta(blockNumber, blockRoot) too -func (s *shadowNodeSnapTree) Cap(blockRoot common.Hash) error { +func (s *ShadowNodeSnapTree) Cap(blockRoot common.Hash) error { + snap := s.Snapshot(blockRoot) + if snap == nil { + return errors.New("snapshot missing") + } + nextDiff, ok := snap.(*shadowNodeDiffLayer) + if !ok { + return nil + } + for i := 0; i < MaxShadowNodeDiffDepth-1; i++ { + nextDiff, ok = nextDiff.Parent().(*shadowNodeDiffLayer) + // if depth less MaxShadowNodeDiffDepth, just return + if !ok { + 
return nil + } + } + + flatten := make([]shadowNodeSnapshot, 0) + parent := nextDiff.Parent() + for parent != nil { + flatten = append(flatten, parent) + parent = parent.Parent() + } + if len(flatten) <= 1 { + return nil + } + + last, ok := flatten[len(flatten)-1].(*shadowNodeDiskLayer) + if !ok { + return errors.New("the diff layers not link to disk layer") + } + + s.lock.Lock() + defer s.lock.Unlock() + newDiskLayer, err := s.flattenDiffs2Disk(flatten[:len(flatten)-1], last) + if err != nil { + return err + } + + // clear forks, but keep latest disk forks + for i := len(flatten) - 1; i > 0; i-- { + var childRoot common.Hash + if i > 0 { + childRoot = flatten[i-1].Root() + } else { + childRoot = nextDiff.Root() + } + root := flatten[i].Root() + s.removeSubLayers(s.children[root], &childRoot) + delete(s.layers, root) + delete(s.children, root) + } + + // reset newDiskLayer and children's parent + s.layers[newDiskLayer.Root()] = newDiskLayer + for _, child := range s.children[newDiskLayer.Root()] { + if diff, exist := s.layers[child].(*shadowNodeDiffLayer); exist { + diff.setParent(newDiskLayer) + } + } return nil } -func (s *shadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error { +func (s *ShadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error { if blockRoot == parentRoot { - return errors.New("snapshot cycle") + return errors.New("there no update in layers") } + // Generate a new snapshot on top of the parent parent := s.Snapshot(parentRoot) if parent == nil { - return fmt.Errorf("parent snapshot missing") + // just point to fake disk layers + parent = s.Snapshot(emptyRoot) + if parent == nil { + return errors.New("cannot find any suitable parent") + } + parentRoot = parent.Root() } snap, err := parent.Update(blockNumber, blockRoot, nodeSet) if err != nil { @@ -104,13 +177,13 @@ func (s 
*shadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int return nil } -func (s *shadowNodeSnapTree) Snapshot(blockRoot common.Hash) shadowNodeSnapshot { +func (s *ShadowNodeSnapTree) Snapshot(blockRoot common.Hash) shadowNodeSnapshot { s.lock.RLock() defer s.lock.RUnlock() return s.layers[blockRoot] } -func (s *shadowNodeSnapTree) Journal() error { +func (s *ShadowNodeSnapTree) Journal() error { s.lock.Lock() defer s.lock.Unlock() @@ -128,15 +201,43 @@ func (s *shadowNodeSnapTree) Journal() error { return nil } +func (s *ShadowNodeSnapTree) removeSubLayers(layers []common.Hash, skip *common.Hash) { + for _, layer := range layers { + if skip != nil && layer == *skip { + continue + } + s.removeSubLayers(s.children[layer], nil) + delete(s.layers, layer) + delete(s.children, layer) + } +} + +// flattenDiffs2Disk delete all flatten and push them to db +func (s *ShadowNodeSnapTree) flattenDiffs2Disk(flatten []shadowNodeSnapshot, diskLayer *shadowNodeDiskLayer) (*shadowNodeDiskLayer, error) { + var err error + for i := len(flatten) - 1; i >= 0; i-- { + diskLayer, err = diskLayer.PushDiff(flatten[i].(*shadowNodeDiffLayer)) + if err != nil { + return nil, err + } + } + + return diskLayer, nil +} + +// loadDiskLayer load from db, could be nil when none in db func loadDiskLayer(db ethdb.KeyValueStore) (*shadowNodeDiskLayer, error) { // TODO(0xbundler): load disk layer meta(blockNumber, blockRoot) - return newShadowNodeDiskLayer(db, common.Big0, common.Hash{}) + return nil, nil } func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map[common.Hash]shadowNodeSnapshot, map[common.Hash][]common.Hash, error) { + layers := make(map[common.Hash]shadowNodeSnapshot) + children := make(map[common.Hash][]common.Hash) + journal := rawdb.ReadShadowNodeSnapshotJournal(db) if len(journal) == 0 { - return nil, nil, errors.New("journal is empty") + return layers, children, nil } r := rlp.NewStream(bytes.NewReader(journal), 0) // Firstly, resolve the 
first element as the journal version @@ -148,7 +249,6 @@ func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map return nil, nil, errors.New("wrong journal version") } - layers := make(map[common.Hash]shadowNodeSnapshot) parents := make(map[common.Hash]common.Hash) for { var ( @@ -163,7 +263,7 @@ func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map if errors.Is(err, io.EOF) { break } - return nil, nil, fmt.Errorf("load diff parent: %v", err) + return nil, nil, fmt.Errorf("load diff number: %v", err) } if err := r.Decode(&parent); err != nil { return nil, nil, fmt.Errorf("load diff parent: %v", err) @@ -193,16 +293,15 @@ func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map layers[root] = newShadowNodeDiffLayer(&number, root, nil, nodeSet) } - children := make(map[common.Hash][]common.Hash) for t, s := range layers { parent := parents[t] children[parent] = append(children[parent], t) if p, ok := layers[parent]; ok { s.(*shadowNodeDiffLayer).parent = p - } else if parent == diskLayer.Root() { + } else if diskLayer != nil && parent == diskLayer.Root() { s.(*shadowNodeDiffLayer).parent = diskLayer } else { - return nil, nil, errors.New("cannot found the snap's parent") + return nil, nil, errors.New("cannot find it's parent") } } return layers, children, nil @@ -232,19 +331,17 @@ func (s *shadowNodeDiffLayer) Root() common.Hash { return s.blockRoot } -func (s *shadowNodeDiffLayer) ShadowNode(addrHash common.Hash, prefix string) ([]byte, error) { +func (s *shadowNodeDiffLayer) ShadowNode(addrHash common.Hash, path string) ([]byte, error) { s.lock.RLock() defer s.lock.RUnlock() cm, exist := s.nodeSet[addrHash] - if !exist { - return nil, errors.New("cannot find the address") - } - - ret, exist := cm[prefix] if exist { - return ret, nil + if ret, ok := cm[path]; ok { + return ret, nil + } } - return s.parent.ShadowNode(addrHash, prefix) + + return s.parent.ShadowNode(addrHash, path) } func (s 
*shadowNodeDiffLayer) Parent() shadowNodeSnapshot { @@ -253,9 +350,10 @@ func (s *shadowNodeDiffLayer) Parent() shadowNodeSnapshot { return s.parent } +// Update append new diff layer onto current, nodeSet when val is []byte{}, it delete the kv func (s *shadowNodeDiffLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { s.lock.RLock() - if s.blockNumber.Cmp(blockNumber) < 0 { + if s.blockNumber.Cmp(blockNumber) >= 0 { return nil, errors.New("update a unordered diff layer") } s.lock.RUnlock() @@ -266,6 +364,23 @@ func (s *shadowNodeDiffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) s.lock.RLock() defer s.lock.RUnlock() + if err := rlp.Encode(buffer, s.blockNumber); err != nil { + return common.Hash{}, err + } + + if s.parent != nil { + if err := rlp.Encode(buffer, s.parent.Root()); err != nil { + return common.Hash{}, err + } + } else { + if err := rlp.Encode(buffer, emptyRoot); err != nil { + return common.Hash{}, err + } + } + + if err := rlp.Encode(buffer, s.blockRoot); err != nil { + return common.Hash{}, err + } storage := make([]journalShadowNode, 0, len(s.nodeSet)) for hash, nodes := range s.nodeSet { keys := make([]string, 0, len(nodes)) @@ -282,6 +397,12 @@ func (s *shadowNodeDiffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) return s.blockRoot, nil } +func (s *shadowNodeDiffLayer) setParent(parent shadowNodeSnapshot) { + s.lock.Lock() + defer s.lock.Unlock() + s.parent = parent +} + type journalShadowNode struct { Hash common.Hash Keys []string @@ -308,10 +429,15 @@ func newShadowNodeDiskLayer(diskdb ethdb.KeyValueReader, blockNumber *big.Int, b } func (s *shadowNodeDiskLayer) Root() common.Hash { + s.lock.RLock() + defer s.lock.RUnlock() return s.blockRoot } func (s *shadowNodeDiskLayer) ShadowNode(addrHash common.Hash, path string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + nodeSet, exist := s.cache[addrHash] if exist { if enc, ok := 
nodeSet[path]; ok { @@ -319,17 +445,19 @@ func (s *shadowNodeDiskLayer) ShadowNode(addrHash common.Hash, path string) ([]b } } //TODO(0xbundler): return history & changeSet later - s.cache[addrHash] = make(map[string][]byte) - n := shadowBranchNode{ - ShadowHash: common.Hash{}, - EpochMap: [16]types.StateEpoch{}, - } - enc, err := rlp.EncodeToBytes(n) - if err != nil { - return nil, err - } - s.cache[addrHash][path] = enc - return enc, nil + //s.cache[addrHash] = make(map[string][]byte) + //n := shadowBranchNode{ + // ShadowHash: common.Hash{}, + // EpochMap: [16]types.StateEpoch{}, + //} + //enc, err := rlp.EncodeToBytes(n) + //if err != nil { + // return nil, err + //} + //s.cache[addrHash][path] = enc + //return enc, nil + + return nil, nil } func (s *shadowNodeDiskLayer) Parent() shadowNodeSnapshot { @@ -338,7 +466,7 @@ func (s *shadowNodeDiskLayer) Parent() shadowNodeSnapshot { func (s *shadowNodeDiskLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { s.lock.RLock() - if s.blockNumber.Cmp(blockNumber) < 0 { + if s.blockNumber.Cmp(blockNumber) >= 0 { return nil, errors.New("update a unordered diff layer") } s.lock.RUnlock() @@ -348,3 +476,29 @@ func (s *shadowNodeDiskLayer) Update(blockNumber *big.Int, blockRoot common.Hash func (s *shadowNodeDiskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { return common.Hash{}, nil } + +func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDiskLayer, error) { + s.lock.Lock() + defer s.lock.Unlock() + + if s.blockNumber.Cmp(diff.blockNumber) >= 0 { + return nil, errors.New("push a lower block to disk") + } + // TODO(0xbundler): store diff to DB + diskLayer, err := newShadowNodeDiskLayer(s.diskdb, diff.blockNumber, diff.blockRoot) + if err != nil { + return nil, err + } + + // reuse cache + diskLayer.cache = s.cache + for addr, nodes := range diff.nodeSet { + if diskLayer.cache[addr] == nil { + 
diskLayer.cache[addr] = make(map[string][]byte) + } + for k, v := range nodes { + diskLayer.cache[addr][k] = v + } + } + return diskLayer, err +} diff --git a/trie/shadow_node_difflayer_test.go b/trie/shadow_node_difflayer_test.go new file mode 100644 index 0000000000..a31ee1ff30 --- /dev/null +++ b/trie/shadow_node_difflayer_test.go @@ -0,0 +1,263 @@ +package trie + +import ( + "math/big" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +var ( + blockRoot0 = makeHash("b0") + blockRoot1 = makeHash("b1") + blockRoot2 = makeHash("b2") + storageRoot0 = makeHash("s0") + storageRoot1 = makeHash("s1") + storageRoot2 = makeHash("s2") + contract1 = makeHash("c1") + contract2 = makeHash("c2") + contract3 = makeHash("c3") +) + +func TestShadowNodeDiffLayer_whenGenesis(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap = tree.Snapshot(blockRoot1) + assert.Nil(t, snap) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, makeNodeSet(contract1, []string{"hello", "world"})) + assert.NoError(t, err) + err = tree.Update(blockRoot1, common.Big2, blockRoot2, makeNodeSet(contract1, []string{"hello2", "world2"})) + assert.NoError(t, err) + err = tree.Cap(blockRoot1) + assert.NoError(t, err) + err = tree.Journal() + assert.NoError(t, err) + + // reload + tree, err = NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + diskLayer := tree.Snapshot(emptyRoot) + assert.NotNil(t, diskLayer) + snap = tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap1 := tree.Snapshot(blockRoot1) + n, err := snap1.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + assert.Equal(t, diskLayer, snap1.Parent()) + assert.Equal(t, blockRoot1, snap1.Root()) + + // read from child + snap2 := 
tree.Snapshot(blockRoot2) + assert.Equal(t, snap1, snap2.Parent()) + assert.Equal(t, blockRoot2, snap2.Root()) + n, err = snap2.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + n, err = snap2.ShadowNode(contract1, "hello2") + assert.NoError(t, err) + assert.Equal(t, []byte("world2"), n) +} + +func TestShadowNodeDiffLayer_crud(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + set1 := makeNodeSet(contract1, []string{"hello", "world", "h1", "w1"}) + appendNodeSet(set1, contract3, []string{"h3", "w3"}) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, set1) + assert.NoError(t, err) + set2 := makeNodeSet(contract1, []string{"hello", "", "h1", ""}) + appendNodeSet(set2, contract2, []string{"hello", "", "h2", "w2"}) + err = tree.Update(blockRoot1, common.Big2, blockRoot2, set2) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot1) + assert.NotNil(t, snap) + val, err := snap.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + val, err = snap.ShadowNode(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte("w1"), val) + val, err = snap.ShadowNode(contract3, "h3") + assert.NoError(t, err) + assert.Equal(t, []byte("w3"), val) + + snap = tree.Snapshot(blockRoot2) + assert.NotNil(t, snap) + val, err = snap.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract2, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract2, "h2") + assert.NoError(t, err) + assert.Equal(t, []byte("w2"), val) + val, err = snap.ShadowNode(contract3, "h3") + assert.NoError(t, err) + assert.Equal(t, []byte("w3"), val) +} + +func TestShadowNodeDiffLayer_capDiffLayers(t *testing.T) 
{ + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + + // push 200 diff layers + count := 1 + for i := 0; i < 200; i++ { + ns := strconv.Itoa(count) + root := makeHash("b" + ns) + parent := makeHash("b" + strconv.Itoa(count-1)) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, + root, makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 10 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + + err = tree.Cap(root) + assert.NoError(t, err) + count++ + } + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 100; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + strconv.Itoa(count-1)) + root := makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 20 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(makeHash("b"+strconv.Itoa(count-1)+"f"+fs), number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + count++ + } + lastRoot := makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 129; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + 
strconv.Itoa(count-1)) + root := makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + count++ + } + lastRoot = makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + + assert.Equal(t, 129, len(tree.layers)) + assert.Equal(t, 128, len(tree.children)) + for parent, children := range tree.children { + if tree.layers[parent] == nil { + t.Log(tree.layers[parent]) + } + assert.NotNil(t, tree.layers[parent]) + for _, child := range children { + if tree.layers[child] == nil { + t.Log(tree.layers[child]) + } + assert.NotNil(t, tree.layers[child]) + } + } + + snap := tree.Snapshot(lastRoot) + assert.NotNil(t, snap) + for i := 1; i < count; i++ { + ns := strconv.Itoa(i) + n, err := snap.ShadowNode(contract1, "hello"+ns) + assert.NoError(t, err) + assert.Equal(t, []byte("world"+ns), n) + } + + // store + err = tree.Journal() + assert.NoError(t, err) +} + +func makeHash(s string) common.Hash { + var ret common.Hash + if len(s) >= 32 { + copy(ret[:], []byte(s)[:hashLen]) + return ret + } + for i := 0; i < hashLen; i++ { + ret[i] = '0' + } + copy(ret[hashLen-len(s):hashLen], s) + return ret +} + +func makeNodeSet(addr common.Hash, kvs []string) map[common.Hash]map[string][]byte { + if len(kvs)%2 != 0 { + panic("makeNodeSet: wrong params") + } + ret := make(map[common.Hash]map[string][]byte) + ret[addr] = make(map[string][]byte) + for i := 0; i < len(kvs); i += 2 { + if len(kvs) == 0 { + ret[addr][kvs[i]] = nil + continue + } + ret[addr][kvs[i]] = []byte(kvs[i+1]) + } + + return ret +} + +func appendNodeSet(ret map[common.Hash]map[string][]byte, addr common.Hash, kvs []string) { + if len(kvs)%2 != 0 { + panic("makeNodeSet: wrong params") + } + if _, ok := ret[addr]; !ok { + ret[addr] = make(map[string][]byte) + } + for i := 0; i < len(kvs); i += 2 { + if len(kvs) == 0 { + ret[addr][kvs[i]] = nil 
+ continue + } + ret[addr][kvs[i]] = []byte(kvs[i+1]) + } +} diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go new file mode 100644 index 0000000000..9a71bc095f --- /dev/null +++ b/trie/shadow_node_test.go @@ -0,0 +1,81 @@ +package trie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +func TestShadowNodeRW_CRUD(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + assert.NoError(t, err) + + err = storageRW.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + err = storageRW.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + val, err := storageRW.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + err = storageRW.Delete(contract1, "hello") + assert.NoError(t, err) + val, err = storageRW.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) +} + +func TestShadowNodeRW_Commit(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + assert.NoError(t, err) + + err = storageRW.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + + err = storageRW.Commit(common.Big1, blockRoot1) + assert.NoError(t, err) + + storageRW, err = NewShadowNodeStorageRW(tree, blockRoot1) + assert.NoError(t, err) + val, err := storageRW.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) +} + +func TestNewShadowNodeStorage4Trie(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + assert.NoError(t, err) + + s1 := storageRW.OpenStorage(contract1) + s2 := 
storageRW.OpenStorage(contract2) + val, err := s1.Get("hello") + assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) + err = s1.Put("hello", []byte("world")) + assert.NoError(t, err) + val, _ = s1.Get("hello") + assert.Equal(t, []byte("world"), val) + val, _ = s2.Get("hello") + assert.Equal(t, []byte(nil), val) + err = s1.Delete("hello") + assert.NoError(t, err) + val, _ = s1.Get("hello") + assert.Equal(t, []byte(nil), val) + + s2.Put("h2", []byte("w2")) + val, _ = s2.Get("h2") + assert.Equal(t, []byte("w2"), val) + + err = storageRW.Commit(common.Big1, blockRoot2) + assert.NoError(t, err) +} diff --git a/trie/shadownodes.go b/trie/shadownodes.go deleted file mode 100644 index 240f241794..0000000000 --- a/trie/shadownodes.go +++ /dev/null @@ -1,27 +0,0 @@ -package trie - -import ( - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" -) - -//type shadowNode interface { -// encode(encoder rlp.EncoderBuffer) -//} - -type shadowExtensionNode struct { - ShadowHash common.Hash -} - -func NewShadowExtensionNode(hash common.Hash) shadowExtensionNode { - return shadowExtensionNode{hash} -} - -type shadowBranchNode struct { - ShadowHash common.Hash - EpochMap [16]types.StateEpoch -} - -func NewShadowBranchNode(hash common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { - return shadowBranchNode{hash, epochMap} -} diff --git a/trie/trie.go b/trie/trie.go index 0717e9d5ff..07af8dbb90 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -88,8 +88,7 @@ func New(root common.Hash, db *Database) (*Trie, error) { panic("trie.New called without a database") } trie := &Trie{ - db: db, - sndb: newShadowNodeStorageMock(0), + db: db, } if root != (common.Hash{}) && root != emptyRoot { rootnode, err := trie.resolveHash(root[:], nil) @@ -101,6 +100,30 @@ func New(root common.Hash, db *Database) (*Trie, error) { return trie, nil } +func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Database, sndb ShadowNodeStorage) 
(*Trie, error) { + if db == nil || sndb == nil { + panic("trie.New called without a database") + } + trie := &Trie{ + db: db, + sndb: sndb, + currentEpoch: curEpoch, + isStorageTrie: true, + shadowHash: rootNode.ShadowHash, + } + if rootNode.TrieHash != (common.Hash{}) && rootNode.TrieHash != emptyRoot { + root, err := trie.resolveHash(rootNode.TrieHash[:], nil) + if err != nil { + return nil, err + } + if err = trie.resolveShadowNode(rootNode.Epoch, root, nil); err != nil { + return nil, err + } + trie.root = root + } + return trie, nil +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. func (t *Trie) NodeIterator(start []byte) NodeIterator { @@ -242,7 +265,10 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S if err != nil { return nil, n, true, err } - child.setEpoch(epoch) + if err = t.resolveShadowNode(epoch, child, key[:pos]); err != nil { + return nil, nil, false, err + } + value, newnode, _, err = t.tryGetWithEpoch(child, key, pos, epoch, updateEpoch) return value, newnode, true, err default: @@ -468,7 +494,10 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - rn.setEpoch(epoch) + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return false, nil, err + } + dirty, nn, err := t.insert(rn, prefix, key, value, epoch) if !dirty || err != nil { return false, rn, err @@ -637,7 +666,10 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, if err != nil { return false, nil, err } - rn.setEpoch(epoch) + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return false, nil, err + } + dirty, nn, err := t.delete(rn, prefix, key, epoch) if !dirty || err != nil { return false, rn, err @@ -903,3 +935,73 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub) (node, bool, error panic(fmt.Sprintf("invalid node: %T", n)) } } 
+ +func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte) error { + if t.currentEpoch < 1 { + return nil + } + + if t.currentEpoch > 0 && t.sndb == nil { + return errors.New("cannot resolve shadow node") + } + + switch n := cur.(type) { + case *shortNode: + n.shadowNode.Epoch = epoch + n.shadowNode.ShadowHash = common.Hash{} + return t.resolveShadowNode(epoch, n.Val, append(prefix, n.Key...)) + case *fullNode: + val, err := t.sndb.Get(string(hexToSuffixCompact(prefix))) + if err != nil { + return err + } + if len(val) == 0 { + // set default epoch map + n.shadowNode.EpochMap = [16]types.StateEpoch{} + n.shadowNode.ShadowHash = common.Hash{} + } else { + if err = rlp.DecodeBytes(val, &n.shadowNode); err != nil { + return err + } + } + for i := byte(0); i < BranchNodeLength-1; i++ { + if err := t.resolveShadowNode(n.shadowNode.EpochMap[i], n.Children[i], append(prefix, i)); err != nil { + return err + } + } + return nil + case valueNode, hashNode: + // just skip + return nil + default: + return errors.New("resolveShadowNode unsupported node type") + } +} + +// SUFFIX-COMPACT encoding is used for encoding trie node path in the trie node +// storage key. The main difference with COMPACT encoding is that the key flag +// is put at the end of the key. +// +// e.g. +// - the key [] is encoded as [0x00] +// - the key [0x1, 0x2, 0x3] is encoded as [0x12, 0x31] +// - the key [0x1, 0x2, 0x3, 0x0] is encoded as [0x12, 0x30, 0x00] +// +// The main benefit of this format is the continuous paths can retain the shared +// path prefix after encoding. 
+func hexToSuffixCompact(hex []byte) []byte { + terminator := byte(0) + if hasTerm(hex) { + terminator = 1 + hex = hex[:len(hex)-1] + } + buf := make([]byte, len(hex)/2+1) + buf[len(buf)-1] = terminator << 1 // the flag byte + if len(hex)&1 == 1 { + buf[len(buf)-1] |= 1 // odd flag + buf[len(buf)-1] |= hex[len(hex)-1] << 4 // last nibble is contained in the last byte + hex = hex[:len(hex)-1] + } + decodeNibbles(hex, buf[:len(buf)-1]) + return buf +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 271f1b376c..3d8fe5e2d7 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -832,6 +832,43 @@ func TestReviveValueAtFullNode(t *testing.T) { } } +func TestTrie_ShadowNodeRW(t *testing.T) { + diskdb := memorydb.New() + database := NewDatabase(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageRW, err := NewShadowNodeStorageRW(tree, blockRoot0) + assert.NoError(t, err) + + tr, err := NewStorageSecure(types.StateEpoch(1), emptyRoot, database, storageRW.OpenStorage(contract1)) + assert.NoError(t, err) + tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) + tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) + val, err := tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + assert.NoError(t, tr.TryDelete(makeHash("k1").Bytes())) + + // TODO(0xbundle): need MPT support commit with shadow node + //nextRoot, _, err := tr.Commit(nil) + //assert.NoError(t, err) + //assert.NoError(t, storageRW.Commit(common.Big1, blockRoot1)) + + // reload + //storageRW, err = NewShadowNodeStorageRW(tree, blockRoot1) + //assert.NoError(t, err) + //tr, err = NewStorageSecure(types.StateEpoch(2), nextRoot, database, storageRW.OpenStorage(contract1)) + //assert.NoError(t, err) + //val, err = tr.TryGet(makeHash("k2").Bytes()) + //assert.NoError(t, err) + //assert.Equal(t, makeHash("v2").Bytes(), val) + //// check expired + //tr, err = NewStorageSecure(types.StateEpoch(3), nextRoot, 
database, storageRW.OpenStorage(contract1)) + //assert.NoError(t, err) + //val, err = tr.TryGet(makeHash("k2").Bytes()) + //assert.Error(t, err) +} + func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } From d41b1e1720111e9b69acb98926756d47598f2408 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 5 May 2023 22:01:18 +0800 Subject: [PATCH 33/51] rawdb: handle shadow node history & changeSet & plainState; trie/shadownode: support shadow node history store & query; trie/shadownode_difflayer: support diff layers flatten to DB; --- core/rawdb/accessors_shadow_node.go | 91 +++++++++++++ core/rawdb/accessors_snapshot.go | 17 --- core/rawdb/schema.go | 7 + core/state/statedb.go | 12 +- go.mod | 3 + go.sum | 6 + trie/shadow_node.go | 58 ++++++-- trie/shadow_node_difflayer.go | 204 ++++++++++++++++++++-------- trie/shadow_node_history.go | 86 ++++++++++++ trie/shadow_node_history_test.go | 55 ++++++++ trie/shadow_node_test.go | 74 ++++++++-- trie/trie.go | 2 +- trie/trie_test.go | 14 +- 13 files changed, 522 insertions(+), 107 deletions(-) create mode 100644 core/rawdb/accessors_shadow_node.go create mode 100644 trie/shadow_node_history.go create mode 100644 trie/shadow_node_history_test.go diff --git a/core/rawdb/accessors_shadow_node.go b/core/rawdb/accessors_shadow_node.go new file mode 100644 index 0000000000..9b0322bbb2 --- /dev/null +++ b/core/rawdb/accessors_shadow_node.go @@ -0,0 +1,91 @@ +package rawdb + +import ( + "encoding/binary" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" +) + +func DeleteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter) { + if err := db.Delete(shadowNodeSnapshotJournalKey); err != nil { + log.Crit("Failed to remove snapshot journal", "err", err) + } +} + +func 
ReadShadowNodeSnapshotJournal(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(shadowNodeSnapshotJournalKey)
+	return data
+}
+
+func WriteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) {
+	if err := db.Put(shadowNodeSnapshotJournalKey, journal); err != nil {
+		log.Crit("Failed to store snapshot journal", "err", err)
+	}
+}
+
+func ReadShadowNodePlainStateMeta(db ethdb.KeyValueReader) []byte {
+	data, _ := db.Get(shadowNodePlainStateMeta)
+	return data
+}
+
+func WriteShadowNodePlainStateMeta(db ethdb.KeyValueWriter, val []byte) error {
+	return db.Put(shadowNodePlainStateMeta, val)
+}
+
+func ReadShadowNodeHistory(db ethdb.KeyValueReader, addr common.Hash, path string, number uint64) []byte {
+	val, _ := db.Get(shadowNodeHistoryKey(addr, path, number))
+	return val
+}
+
+func WriteShadowNodeHistory(db ethdb.KeyValueWriter, addr common.Hash, path string, number uint64, val []byte) error {
+	return db.Put(shadowNodeHistoryKey(addr, path, number), val)
+}
+
+func ReadShadowNodeChangeSet(db ethdb.KeyValueReader, addr common.Hash, number uint64) []byte {
+	val, _ := db.Get(shadowNodeChangeSetKey(addr, number))
+	return val
+}
+
+func WriteShadowNodeChangeSet(db ethdb.KeyValueWriter, addr common.Hash, number uint64, val []byte) error {
+	return db.Put(shadowNodeChangeSetKey(addr, number), val)
+}
+
+func ReadShadowNodePlainState(db ethdb.KeyValueReader, addr common.Hash, path string) []byte {
+	val, _ := db.Get(shadowNodePlainStateKey(addr, path))
+	return val
+}
+
+func WriteShadowNodePlainState(db ethdb.KeyValueWriter, addr common.Hash, path string, val []byte) error {
+	return db.Put(shadowNodePlainStateKey(addr, path), val)
+}
+
+func DeleteShadowNodePlainState(db ethdb.KeyValueWriter, addr common.Hash, path string) error {
+	return db.Delete(shadowNodePlainStateKey(addr, path))
+}
+
+func shadowNodeChangeSetKey(addr common.Hash, number uint64) []byte {
+	key := make([]byte, len(ShadowNodeChangeSetPrefix)+8+len(addr))
+	copy(key[:], ShadowNodeChangeSetPrefix)
+	binary.BigEndian.PutUint64(key[len(ShadowNodeChangeSetPrefix):], number)
+	copy(key[len(ShadowNodeChangeSetPrefix)+8:], addr.Bytes())
+	return key
+}
+
+func shadowNodeHistoryKey(addr common.Hash, path string, number uint64) []byte {
+	key := make([]byte, len(ShadowNodeHistoryPrefix)+len(addr)+len(path)+8)
+	copy(key[:], ShadowNodeHistoryPrefix)
+	copy(key[len(ShadowNodeHistoryPrefix):], addr.Bytes())
+	copy(key[len(ShadowNodeHistoryPrefix)+len(addr):], path)
+	binary.BigEndian.PutUint64(key[len(key)-8:], number)
+	return key
+}
+
+func shadowNodePlainStateKey(addr common.Hash, path string) []byte {
+	key := make([]byte, len(ShadowNodePlainStatePrefix)+len(addr)+len(path))
+	copy(key[:], ShadowNodePlainStatePrefix)
+	copy(key[len(ShadowNodePlainStatePrefix):], addr.Bytes())
+	copy(key[len(ShadowNodePlainStatePrefix)+len(addr):], path)
+	return key
+}
diff --git a/core/rawdb/accessors_snapshot.go b/core/rawdb/accessors_snapshot.go
index 220b41ca6f..1c828662c1 100644
--- a/core/rawdb/accessors_snapshot.go
+++ b/core/rawdb/accessors_snapshot.go
@@ -141,23 +141,6 @@ func DeleteSnapshotJournal(db ethdb.KeyValueWriter) {
 	}
 }
 
-func ReadShadowNodeSnapshotJournal(db ethdb.KeyValueReader) []byte {
-	data, _ := db.Get(shadowNodeSnapshotJournalKey)
-	return data
-}
-
-func WriteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) {
-	if err := db.Put(shadowNodeSnapshotJournalKey, journal); err != nil {
-		log.Crit("Failed to store snapshot journal", "err", err)
-	}
-}
-
-func DeleteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter) {
-	if err := db.Delete(shadowNodeSnapshotJournalKey); err != nil {
-		log.Crit("Failed to remove snapshot journal", "err", err)
-	}
-}
-
 // ReadSnapshotGenerator retrieves the serialized snapshot generator saved at
 // the last shutdown.
func ReadSnapshotGenerator(db ethdb.KeyValueReader) []byte {
diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go
index d1e3aa76d5..9eeddc5efc 100644
--- a/core/rawdb/schema.go
+++ b/core/rawdb/schema.go
@@ -96,6 +96,9 @@ var (
 	// transitionStatusKey tracks the eth2 transition status.
 	transitionStatusKey = []byte("eth2-transition")
 
+	// shadowNodePlainStateMeta saves the disk layer metadata
+	shadowNodePlainStateMeta = []byte("shadowNodePlainStateMeta")
+
 	// Data item prefixes (use single byte to avoid mixing data types, avoid `i`, used for indexes).
 	headerPrefix       = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header
 	headerTDSuffix     = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td
@@ -111,6 +114,10 @@ var (
 	SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value
 	CodePrefix            = []byte("c") // CodePrefix + code hash -> account code
 
+	ShadowNodeHistoryPrefix    = []byte("sh") // ShadowNodeHistoryPrefix + addr hash + path + blockNr -> bitmap, default blockNr = math.MaxUint64
+	ShadowNodeChangeSetPrefix  = []byte("sc") // ShadowNodeChangeSetPrefix + blockNr (uint64 big endian) + addr hash -> changeSet/prev val
+	ShadowNodePlainStatePrefix = []byte("sp") // ShadowNodePlainStatePrefix + addr hash + path -> val
+
 	// difflayer database
 	diffLayerPrefix = []byte("d") // diffLayerPrefix + hash -> diffLayer
 
diff --git a/core/state/statedb.go b/core/state/statedb.go
index cf1c81921e..8dcc0056ef 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -89,7 +89,7 @@ type StateDB struct {
 	fullProcessed bool
 	pipeCommit    bool
 
-	shadowNodeRW *trie.ShadowNodeStorageRW
+	shadowNodeDB trie.ShadowNodeDatabase
 	snaps        *snapshot.Tree
 	snap         snapshot.Snapshot
 	snapAccountMux sync.Mutex // Mutex for snap account access
@@ -164,7 +164,7 @@ func NewWithEpoch(config *params.ChainConfig, targetBlock *big.Int, root common.
// init target block and shadowNodeRW stateDB.targetBlk = targetBlock - stateDB.shadowNodeRW, err = trie.NewShadowNodeStorageRW(sntree, root) + stateDB.shadowNodeDB, err = trie.NewShadowNodeDatabase(sntree, targetBlock, root) if err != nil { return nil, err } @@ -1403,7 +1403,7 @@ func (s *StateDB) LightCommit() (common.Hash, *types.DiffLayer, error) { // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() error) (common.Hash, *types.DiffLayer, error) { - if s.targetEpoch > 0 && s.shadowNodeRW == nil { + if s.targetEpoch > 0 && s.shadowNodeDB == nil { return common.Hash{}, nil, errors.New("cannot commit shadow node") } @@ -1644,8 +1644,8 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } - if s.shadowNodeRW != nil && s.originalRoot != root { - if err := s.shadowNodeRW.Commit(s.targetBlk, root); err != nil { + if s.shadowNodeDB != nil && s.originalRoot != root { + if err := s.shadowNodeDB.Commit(s.targetBlk, root); err != nil { return common.Hash{}, nil, err } } @@ -1833,5 +1833,5 @@ func (s *StateDB) ReviveStorageTrie(witnessList types.WitnessList) error { } func (s *StateDB) openShadowStorage(addr common.Hash) trie.ShadowNodeStorage { - return trie.NewShadowNodeStorage4Trie(addr, s.shadowNodeRW) + return trie.NewShadowNodeStorage4Trie(addr, s.shadowNodeDB) } diff --git a/go.mod b/go.mod index 8405c9661a..f894cd9d85 100644 --- a/go.mod +++ b/go.mod @@ -77,12 +77,14 @@ require ( require ( github.com/Azure/azure-pipeline-go v0.2.2 // indirect github.com/Azure/go-autorest/autorest/adal v0.8.0 // indirect + github.com/RoaringBitmap/roaring v1.2.3 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.2 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.1.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.1.1 // indirect 
github.com/aws/smithy-go v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91 // indirect @@ -101,6 +103,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mitchellh/mapstructure v1.4.1 // indirect github.com/mitchellh/pointerstructure v1.2.0 // indirect + github.com/mschoch/smat v0.2.0 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect github.com/opentracing/opentracing-go v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 5ce59554c8..c161ec913f 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= +github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= @@ -95,6 +97,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod 
h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.2.0 h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjLyS07ChA= +github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bmizerany/pat v0.0.0-20170815010413-6226ea591a40/go.mod h1:8rLXio+WjiTceGBHIoTvn60HIbs7Hm7bcHjyrSqYB9c= github.com/bnb-chain/ics23 v0.1.0 h1:DvjGOts2FBfbxB48384CYD1LbcrfjThFz8kowY/7KxU= github.com/bnb-chain/ics23 v0.1.0/go.mod h1:cU6lTGolbbLFsGCgceNB2AzplH1xecLp6+KXvxM32nI= @@ -402,6 +406,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg= +github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= +github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/naoina/go-stringutil v0.1.0 h1:rCUeRUHjBjGTSHl0VC00jUPLz8/F9dDzYI70Hzifhks= diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 73bf6f1265..fef68eb88f 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -6,6 +6,8 @@ import ( "math/big" "sync" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) @@ -38,28 +40,63 @@ type ShadowNodeStorage interface { Delete(path string) error } +type ShadowNodeDatabase interface { + Get(addr common.Hash, path string) ([]byte, error) + Delete(addr common.Hash, path string) error + Put(addr common.Hash, path string, val 
[]byte) error + OpenStorage(addr common.Hash) ShadowNodeStorage + Commit(number *big.Int, blockRoot common.Hash) error +} + type shadowNodeStorage4Trie struct { addr common.Hash - rw *ShadowNodeStorageRW + db ShadowNodeDatabase } -func NewShadowNodeStorage4Trie(addr common.Hash, rw *ShadowNodeStorageRW) ShadowNodeStorage { +func NewShadowNodeStorage4Trie(addr common.Hash, db ShadowNodeDatabase) ShadowNodeStorage { return &shadowNodeStorage4Trie{ addr: addr, - rw: rw, + db: db, } } func (s *shadowNodeStorage4Trie) Get(path string) ([]byte, error) { - return s.rw.Get(s.addr, path) + return s.db.Get(s.addr, path) } func (s *shadowNodeStorage4Trie) Put(path string, val []byte) error { - return s.rw.Put(s.addr, path, val) + return s.db.Put(s.addr, path, val) } func (s *shadowNodeStorage4Trie) Delete(path string) error { - return s.rw.Delete(s.addr, path) + return s.db.Delete(s.addr, path) +} + +// ShadowNodeStorageRO shadow node only could modify the latest diff layers, +// if you want to modify older state, please unwind to thr older history +type ShadowNodeStorageRO struct { + diskdb ethdb.KeyValueStore + number *big.Int +} + +func (s *ShadowNodeStorageRO) Get(addr common.Hash, path string) ([]byte, error) { + return FindHistory(s.diskdb, s.number.Uint64()+1, addr, path) +} + +func (s *ShadowNodeStorageRO) Delete(addr common.Hash, path string) error { + return errors.New("ShadowNodeStorageRO unsupported") +} + +func (s *ShadowNodeStorageRO) Put(addr common.Hash, path string, val []byte) error { + return errors.New("ShadowNodeStorageRO unsupported") +} + +func (s *ShadowNodeStorageRO) OpenStorage(addr common.Hash) ShadowNodeStorage { + return NewShadowNodeStorage4Trie(addr, s) +} + +func (s *ShadowNodeStorageRO) Commit(number *big.Int, blockRoot common.Hash) error { + return errors.New("ShadowNodeStorageRO unsupported") } type ShadowNodeStorageRW struct { @@ -71,12 +108,17 @@ type ShadowNodeStorageRW struct { lock sync.RWMutex } -func NewShadowNodeStorageRW(tree 
*ShadowNodeSnapTree, blockRoot common.Hash) (*ShadowNodeStorageRW, error) { +// NewShadowNodeDatabase first find snap by blockRoot, if got nil, try using number to instance a read only storage +func NewShadowNodeDatabase(tree *ShadowNodeSnapTree, number *big.Int, blockRoot common.Hash) (ShadowNodeDatabase, error) { snap := tree.Snapshot(blockRoot) if snap == nil { // try using default snap if snap = tree.Snapshot(emptyRoot); snap == nil { - return nil, errors.New("cannot find the snap") + // open read only history + return &ShadowNodeStorageRO{ + diskdb: tree.DB(), + number: number, + }, nil } } return &ShadowNodeStorageRW{ diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go index c7c0a5d1ad..615fe1a595 100644 --- a/trie/shadow_node_difflayer.go +++ b/trie/shadow_node_difflayer.go @@ -8,6 +8,10 @@ import ( "math/big" "sync" + "github.com/RoaringBitmap/roaring/roaring64" + "github.com/ethereum/go-ethereum/common/math" + lru "github.com/hashicorp/golang-lru" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" @@ -16,8 +20,9 @@ import ( const ( // MaxShadowNodeDiffDepth default is 128 layers - MaxShadowNodeDiffDepth = 128 - journalVersion uint64 = 1 + MaxShadowNodeDiffDepth = 128 + journalVersion uint64 = 1 + defaultDiskLayerCacheSize = 100000 ) // shadowNodeSnapshot record diff layer and disk layer of shadow nodes, support mini reorg @@ -43,14 +48,10 @@ type shadowNodeSnapshot interface { type ShadowNodeSnapTree struct { diskdb ethdb.KeyValueStore - // diff layers + // diffLayers + diskLayer, disk layer, always not nil layers map[common.Hash]shadowNodeSnapshot children map[common.Hash][]common.Hash - // disk layer - // TODO(0xbundler): add disk layer for history & changeSet - diskCache map[common.Hash]map[string][]byte - lock sync.RWMutex } @@ -59,13 +60,6 @@ func NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore) (*ShadowNodeSnapTree, err if err != nil { return nil, err 
} - // if there is no disk layer, will construct a fake disk layer - if diskLayer == nil { - diskLayer, err = newShadowNodeDiskLayer(diskdb, common.Big0, emptyRoot) - if err != nil { - return nil, err - } - } layers, children, err := loadDiffLayers(diskdb, diskLayer) if err != nil { return nil, err @@ -77,15 +71,13 @@ func NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore) (*ShadowNodeSnapTree, err return nil, errors.New("cannot found any diff layers link to disk layer") } return &ShadowNodeSnapTree{ - diskdb: diskdb, - layers: layers, - children: children, - diskCache: make(map[common.Hash]map[string][]byte), + diskdb: diskdb, + layers: layers, + children: children, }, nil } // Cap keep tree depth not greater MaxShadowNodeDiffDepth, all forks parent to disk layer will delete -// TODO(0xbundler): store disk layer meta(blockNumber, blockRoot) too func (s *ShadowNodeSnapTree) Cap(blockRoot common.Hash) error { snap := s.Snapshot(blockRoot) if snap == nil { @@ -183,6 +175,12 @@ func (s *ShadowNodeSnapTree) Snapshot(blockRoot common.Hash) shadowNodeSnapshot return s.layers[blockRoot] } +func (s *ShadowNodeSnapTree) DB() ethdb.KeyValueStore { + s.lock.RLock() + defer s.lock.RUnlock() + return s.diskdb +} + func (s *ShadowNodeSnapTree) Journal() error { s.lock.Lock() defer s.lock.Unlock() @@ -227,8 +225,25 @@ func (s *ShadowNodeSnapTree) flattenDiffs2Disk(flatten []shadowNodeSnapshot, dis // loadDiskLayer load from db, could be nil when none in db func loadDiskLayer(db ethdb.KeyValueStore) (*shadowNodeDiskLayer, error) { - // TODO(0xbundler): load disk layer meta(blockNumber, blockRoot) - return nil, nil + val := rawdb.ReadShadowNodePlainStateMeta(db) + // if there is no disk layer, will construct a fake disk layer + if len(val) == 0 { + diskLayer, err := newShadowNodeDiskLayer(db, common.Big0, emptyRoot) + if err != nil { + return nil, err + } + return diskLayer, nil + } + var meta shadowNodePlainMeta + if err := rlp.DecodeBytes(val, &meta); err != nil { + return nil, 
err + } + + layer, err := newShadowNodeDiskLayer(db, meta.BlockNumber, meta.BlockRoot) + if err != nil { + return nil, err + } + return layer, nil } func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map[common.Hash]shadowNodeSnapshot, map[common.Hash][]common.Hash, error) { @@ -313,6 +328,7 @@ type shadowNodeDiffLayer struct { parent shadowNodeSnapshot nodeSet map[common.Hash]map[string][]byte + // TODO(0xbundler): add destruct handle later? lock sync.RWMutex } @@ -350,7 +366,7 @@ func (s *shadowNodeDiffLayer) Parent() shadowNodeSnapshot { return s.parent } -// Update append new diff layer onto current, nodeSet when val is []byte{}, it delete the kv +// Update append new diff layer onto current, nodeChgRecord when val is []byte{}, it delete the kv func (s *shadowNodeDiffLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { s.lock.RLock() if s.blockNumber.Cmp(blockNumber) >= 0 { @@ -403,28 +419,43 @@ func (s *shadowNodeDiffLayer) setParent(parent shadowNodeSnapshot) { s.parent = parent } +func (s *shadowNodeDiffLayer) getNodeSet() map[common.Hash]map[string][]byte { + s.lock.Lock() + defer s.lock.Unlock() + return s.nodeSet +} + type journalShadowNode struct { Hash common.Hash Keys []string Vals [][]byte } +type shadowNodePlainMeta struct { + BlockNumber *big.Int + BlockRoot common.Hash +} + type shadowNodeDiskLayer struct { // TODO(0xbundler): add history & changeSet later - diskdb ethdb.KeyValueReader + diskdb ethdb.KeyValueStore blockNumber *big.Int blockRoot common.Hash - cache map[common.Hash]map[string][]byte + cache *lru.Cache lock sync.RWMutex } -func newShadowNodeDiskLayer(diskdb ethdb.KeyValueReader, blockNumber *big.Int, blockRoot common.Hash) (*shadowNodeDiskLayer, error) { +func newShadowNodeDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, blockRoot common.Hash) (*shadowNodeDiskLayer, error) { + cache, err := 
lru.New(defaultDiskLayerCacheSize) + if err != nil { + return nil, err + } return &shadowNodeDiskLayer{ diskdb: diskdb, blockNumber: blockNumber, blockRoot: blockRoot, - cache: make(map[common.Hash]map[string][]byte), + cache: cache, }, nil } @@ -434,30 +465,23 @@ func (s *shadowNodeDiskLayer) Root() common.Hash { return s.blockRoot } -func (s *shadowNodeDiskLayer) ShadowNode(addrHash common.Hash, path string) ([]byte, error) { +func (s *shadowNodeDiskLayer) ShadowNode(addr common.Hash, path string) ([]byte, error) { s.lock.RLock() defer s.lock.RUnlock() - nodeSet, exist := s.cache[addrHash] + cacheKey := shadowNodeCacheKey(addr, path) + cached, exist := s.cache.Get(cacheKey) if exist { - if enc, ok := nodeSet[path]; ok { - return enc, nil - } + return cached.([]byte), nil } - //TODO(0xbundler): return history & changeSet later - //s.cache[addrHash] = make(map[string][]byte) - //n := shadowBranchNode{ - // ShadowHash: common.Hash{}, - // EpochMap: [16]types.StateEpoch{}, - //} - //enc, err := rlp.EncodeToBytes(n) - //if err != nil { - // return nil, err - //} - //s.cache[addrHash][path] = enc - //return enc, nil - return nil, nil + val, err := FindHistory(s.diskdb, s.blockNumber.Uint64()+1, addr, path) + if err != nil { + return nil, err + } + + s.cache.Add(cacheKey, val) + return val, err } func (s *shadowNodeDiskLayer) Parent() shadowNodeSnapshot { @@ -481,24 +505,98 @@ func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDi s.lock.Lock() defer s.lock.Unlock() - if s.blockNumber.Cmp(diff.blockNumber) >= 0 { + number := diff.blockNumber + if s.blockNumber.Cmp(number) >= 0 { return nil, errors.New("push a lower block to disk") } - // TODO(0xbundler): store diff to DB - diskLayer, err := newShadowNodeDiskLayer(s.diskdb, diff.blockNumber, diff.blockRoot) + batch := s.diskdb.NewBatch() + nodeSet := diff.getNodeSet() + for addr, subSet := range nodeSet { + changeSet := make([]nodeChgRecord, 0, len(subSet)) + for path, val := range subSet { + if 
err := refreshShadowNodeHistory(s.diskdb, batch, addr, path, number.Uint64()); err != nil { + return nil, err + } + prev := rawdb.ReadShadowNodePlainState(s.diskdb, addr, path) + // refresh plain state + if len(val) == 0 { + if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { + return nil, err + } + } else { + if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { + return nil, err + } + } + + changeSet = append(changeSet, nodeChgRecord{ + Path: path, + Prev: prev, + }) + } + enc, err := rlp.EncodeToBytes(changeSet) + if err != nil { + return nil, err + } + if err = rawdb.WriteShadowNodeChangeSet(batch, addr, number.Uint64(), enc); err != nil { + return nil, err + } + } + + // update meta + meta := shadowNodePlainMeta{ + BlockNumber: number, + BlockRoot: diff.blockRoot, + } + enc, err := rlp.EncodeToBytes(meta) if err != nil { return nil, err } + if err = rawdb.WriteShadowNodePlainStateMeta(batch, enc); err != nil { + return nil, err + } + + if err = batch.Write(); err != nil { + return nil, err + } + diskLayer := &shadowNodeDiskLayer{ + diskdb: s.diskdb, + blockNumber: number, + blockRoot: diff.blockRoot, + cache: s.cache, + } // reuse cache - diskLayer.cache = s.cache for addr, nodes := range diff.nodeSet { - if diskLayer.cache[addr] == nil { - diskLayer.cache[addr] = make(map[string][]byte) + for path, val := range nodes { + diskLayer.cache.Add(shadowNodeCacheKey(addr, path), val) } - for k, v := range nodes { - diskLayer.cache[addr][k] = v + } + return diskLayer, nil +} + +func shadowNodeCacheKey(addr common.Hash, path string) string { + key := make([]byte, len(addr)+len(path)) + copy(key[:], addr.Bytes()) + copy(key[len(addr):], path) + return string(key) +} + +func refreshShadowNodeHistory(db ethdb.KeyValueReader, batch ethdb.Batch, addr common.Hash, path string, number uint64) error { + enc := rawdb.ReadShadowNodeHistory(db, addr, path, math.MaxUint64) + index := roaring64.New() + if len(enc) > 0 { + if _, err 
:= index.ReadFrom(bytes.NewReader(enc)); err != nil { + return err } } - return diskLayer, err + index.Add(number) + enc, err := index.ToBytes() + if err != nil { + return err + } + if err = rawdb.WriteShadowNodeHistory(batch, addr, path, math.MaxUint64, enc); err != nil { + return err + } + return nil } diff --git a/trie/shadow_node_history.go b/trie/shadow_node_history.go new file mode 100644 index 0000000000..6457e88ca3 --- /dev/null +++ b/trie/shadow_node_history.go @@ -0,0 +1,86 @@ +package trie + +import ( + "bytes" + "errors" + + "github.com/RoaringBitmap/roaring/roaring64" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" +) + +func FindHistory(db ethdb.KeyValueStore, number uint64, addr common.Hash, path string) ([]byte, error) { + val, found, err := findHistory(db, number, addr, path) + if err != nil { + return nil, err + } + + if found { + return val, nil + } + + // query from plain state + return rawdb.ReadShadowNodePlainState(db, addr, path), nil +} + +type nodeChgRecord struct { + Path string + Prev []byte +} + +func findHistory(db ethdb.KeyValueStore, number uint64, addr common.Hash, path string) ([]byte, bool, error) { + // TODO(0xbundler): split shards according bitmap size later, less than 1mb + hbytes := rawdb.ReadShadowNodeHistory(db, addr, path, math.MaxUint64) + if len(hbytes) == 0 { + return nil, false, nil + } + + index := roaring64.New() + if _, err := index.ReadFrom(bytes.NewReader(hbytes)); err != nil { + return nil, false, err + } + found, ok := SeekInBitmap64(index, number) + if !ok { + return nil, false, nil + } + + // TODO(0xbundler): using mdbx's DupSort? 
+ changeSet := rawdb.ReadShadowNodeChangeSet(db, addr, found) + if len(changeSet) == 0 { + return nil, false, errors.New("cannot find target changeSet") + } + var ns []nodeChgRecord + if err := rlp.DecodeBytes(changeSet, &ns); err != nil { + return nil, false, err + } + nodeSetMap := make(map[string][]byte) + for _, n := range ns { + nodeSetMap[n.Path] = n.Prev + } + + val, exist := nodeSetMap[path] + if !exist { + return nil, false, errors.New("cannot find path's change val") + } + + return val, true, nil +} + +// SeekInBitmap - returns value in bitmap which is >= n +func SeekInBitmap64(m *roaring64.Bitmap, n uint64) (found uint64, ok bool) { + if m.IsEmpty() { + return 0, false + } + if n == 0 { + return m.Minimum(), true + } + searchRank := m.Rank(n - 1) + if searchRank >= m.GetCardinality() { + return 0, false + } + found, _ = m.Select(searchRank) + return found, true +} diff --git a/trie/shadow_node_history_test.go b/trie/shadow_node_history_test.go new file mode 100644 index 0000000000..d502cbcc8d --- /dev/null +++ b/trie/shadow_node_history_test.go @@ -0,0 +1,55 @@ +package trie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +func TestShadowNodeHistory_Diff2Disk(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb) + assert.NoError(t, err) + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + _, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + // find history + val, err := FindHistory(diskdb, common.Big2.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + val, err = FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + + // reload disk layer + diskLayer, err = loadDiskLayer(diskdb) + assert.NoError(t, err) + val, err = 
diskLayer.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) +} + +func TestShadowNodeHistory_case2(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb) + assert.NoError(t, err) + + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + diff = newShadowNodeDiffLayer(common.Big2, blockRoot2, nil, makeNodeSet(contract1, []string{"hello", "world1"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + val, err := FindHistory(diskdb, common.Big2.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + + val, err = FindHistory(diskdb, common.Big3.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world1"), val) +} diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go index 9a71bc095f..2d8853d941 100644 --- a/trie/shadow_node_test.go +++ b/trie/shadow_node_test.go @@ -1,8 +1,12 @@ package trie import ( + "math/big" "testing" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/stretchr/testify/assert" @@ -12,39 +16,79 @@ func TestShadowNodeRW_CRUD(t *testing.T) { diskdb := memorydb.New() tree, err := NewShadowNodeSnapTree(diskdb) assert.NoError(t, err) - storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) - err = storageRW.Put(contract1, "hello", []byte("world")) + err = storageDB.Put(contract1, "hello", []byte("world")) assert.NoError(t, err) - err = storageRW.Put(contract1, "hello", []byte("world")) + err = storageDB.Put(contract1, "hello", []byte("world")) assert.NoError(t, err) - val, err := storageRW.Get(contract1, "hello") + 
val, err := storageDB.Get(contract1, "hello") assert.NoError(t, err) assert.Equal(t, []byte("world"), val) - err = storageRW.Delete(contract1, "hello") + err = storageDB.Delete(contract1, "hello") + assert.NoError(t, err) + val, err = storageDB.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) +} + +func TestShadowNodeRO_Get(t *testing.T) { + diskdb := memorydb.New() + makeDiskLayer(diskdb, common.Big2, blockRoot2, contract1, []string{"k1", "v1"}) + + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageRO, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) - val, err = storageRW.Get(contract1, "hello") + + err = storageRO.Put(contract1, "hello", []byte("world")) + assert.Error(t, err) + err = storageRO.Delete(contract1, "hello") + assert.Error(t, err) + err = storageRO.Commit(common.Big2, blockRoot2) + assert.Error(t, err) + + val, err := storageRO.Get(contract1, "hello") assert.NoError(t, err) assert.Equal(t, []byte(nil), val) + val, err = storageRO.Get(contract1, "k1") + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), val) +} + +func makeDiskLayer(diskdb *memorydb.Database, number *big.Int, root common.Hash, addr common.Hash, kv []string) { + if len(kv)%2 != 0 { + panic("wrong kv") + } + meta := shadowNodePlainMeta{ + BlockNumber: number, + BlockRoot: root, + } + enc, _ := rlp.EncodeToBytes(&meta) + rawdb.WriteShadowNodePlainStateMeta(diskdb, enc) + + for i := 0; i < len(kv); i += 2 { + rawdb.WriteShadowNodePlainState(diskdb, addr, kv[i], []byte(kv[i+1])) + } } func TestShadowNodeRW_Commit(t *testing.T) { diskdb := memorydb.New() tree, err := NewShadowNodeSnapTree(diskdb) assert.NoError(t, err) - storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) - err = storageRW.Put(contract1, "hello", []byte("world")) + err = storageDB.Put(contract1, "hello", []byte("world")) 
assert.NoError(t, err) - err = storageRW.Commit(common.Big1, blockRoot1) + err = storageDB.Commit(common.Big1, blockRoot1) assert.NoError(t, err) - storageRW, err = NewShadowNodeStorageRW(tree, blockRoot1) + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) - val, err := storageRW.Get(contract1, "hello") + val, err := storageDB.Get(contract1, "hello") assert.NoError(t, err) assert.Equal(t, []byte("world"), val) } @@ -53,11 +97,11 @@ func TestNewShadowNodeStorage4Trie(t *testing.T) { diskdb := memorydb.New() tree, err := NewShadowNodeSnapTree(diskdb) assert.NoError(t, err) - storageRW, err := NewShadowNodeStorageRW(tree, blockRoot1) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) - s1 := storageRW.OpenStorage(contract1) - s2 := storageRW.OpenStorage(contract2) + s1 := storageDB.OpenStorage(contract1) + s2 := storageDB.OpenStorage(contract2) val, err := s1.Get("hello") assert.NoError(t, err) assert.Equal(t, []byte(nil), val) @@ -76,6 +120,6 @@ func TestNewShadowNodeStorage4Trie(t *testing.T) { val, _ = s2.Get("h2") assert.Equal(t, []byte("w2"), val) - err = storageRW.Commit(common.Big1, blockRoot2) + err = storageDB.Commit(common.Big1, blockRoot2) assert.NoError(t, err) } diff --git a/trie/trie.go b/trie/trie.go index 07af8dbb90..750b06a7f3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -941,7 +941,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte return nil } - if t.currentEpoch > 0 && t.sndb == nil { + if t.sndb == nil { return errors.New("cannot resolve shadow node") } diff --git a/trie/trie_test.go b/trie/trie_test.go index 3d8fe5e2d7..f297c5d794 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -837,10 +837,10 @@ func TestTrie_ShadowNodeRW(t *testing.T) { database := NewDatabase(diskdb) tree, err := NewShadowNodeSnapTree(diskdb) assert.NoError(t, err) - storageRW, err := NewShadowNodeStorageRW(tree, blockRoot0) + storageDB, err 
:= NewShadowNodeDatabase(tree, common.Big0, blockRoot0) assert.NoError(t, err) - tr, err := NewStorageSecure(types.StateEpoch(1), emptyRoot, database, storageRW.OpenStorage(contract1)) + tr, err := NewStorageSecure(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) assert.NoError(t, err) tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) @@ -849,21 +849,21 @@ func TestTrie_ShadowNodeRW(t *testing.T) { assert.Equal(t, makeHash("v1").Bytes(), val) assert.NoError(t, tr.TryDelete(makeHash("k1").Bytes())) - // TODO(0xbundle): need MPT support commit with shadow node + // TODO(0xbundler): need MPT support commit with shadow node //nextRoot, _, err := tr.Commit(nil) //assert.NoError(t, err) - //assert.NoError(t, storageRW.Commit(common.Big1, blockRoot1)) + //assert.NoError(t, storageDB.Commit(common.Big1, blockRoot1)) // reload - //storageRW, err = NewShadowNodeStorageRW(tree, blockRoot1) + //storageDB, err = NewShadowNodestorageDB(tree, blockRoot1) //assert.NoError(t, err) - //tr, err = NewStorageSecure(types.StateEpoch(2), nextRoot, database, storageRW.OpenStorage(contract1)) + //tr, err = NewStorageSecure(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) //assert.NoError(t, err) //val, err = tr.TryGet(makeHash("k2").Bytes()) //assert.NoError(t, err) //assert.Equal(t, makeHash("v2").Bytes(), val) //// check expired - //tr, err = NewStorageSecure(types.StateEpoch(3), nextRoot, database, storageRW.OpenStorage(contract1)) + //tr, err = NewStorageSecure(types.StateEpoch(3), nextRoot, database, storageDB.OpenStorage(contract1)) //assert.NoError(t, err) //val, err = tr.TryGet(makeHash("k2").Bytes()) //assert.Error(t, err) From 3b9f70e73a8c3ba0da4d3126ad197312660eb182 Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 27 Apr 2023 14:41:43 +0800 Subject: [PATCH 34/51] feat(RPC): add EstimateGasAndReviveState RPC method complete estimate gas and revive state add test 
and function comments remove TODO --- ethclient/ethclient.go | 18 +++ ethclient/ethclient_test.go | 23 ++++ internal/ethapi/api.go | 260 ++++++++++++++++++++++++++++++++++++ 3 files changed, 301 insertions(+) diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index db388a3d9d..d708064b34 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -559,6 +559,24 @@ func (ec *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64 return uint64(hex), nil } +// Result struct for EstimateGasAndReviveState +type EstimateGasAndReviveStateResult struct { + Hex hexutil.Uint64 `json:"gas"` + ReviveWitness []types.ReviveWitness `json:"reviveWitness"` +} + +// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It also attempts to revive expired +// storage trie and returns the revive witness list. +func (ec *Client) EstimateGasAndReviveState(ctx context.Context, msg ethereum.CallMsg) (*EstimateGasAndReviveStateResult, error) { + var result EstimateGasAndReviveStateResult + err := ec.c.CallContext(ctx, &result, "eth_estimateGasAndReviveState", toCallArg(msg)) + if err != nil { + return &result, err + } + return &result, nil +} + // SendTransaction injects a signed transaction into the pending pool for execution. 
// // If the transaction was a contract creation use the TransactionReceipt method to get the diff --git a/ethclient/ethclient_test.go b/ethclient/ethclient_test.go index b488487620..b0b4074997 100644 --- a/ethclient/ethclient_test.go +++ b/ethclient/ethclient_test.go @@ -376,6 +376,9 @@ func TestEthClient(t *testing.T) { "TestDiffAccounts": { func(t *testing.T) { testDiffAccounts(t, client) }, }, + "TestEstimateGasAndReviveState": { + func(t *testing.T) { testEstimateGasAndReviveState(t, client) }, + }, // DO not have TestAtFunctions now, because we do not have pending block now } @@ -620,6 +623,26 @@ func testCallContract(t *testing.T, client *rpc.Client) { } } +func testEstimateGasAndReviveState(t *testing.T, client *rpc.Client) { + ec := NewClient(client) + + // EstimateGas + msg := ethereum.CallMsg{ + From: testAddr, + To: &common.Address{}, + Gas: 21000, + Value: big.NewInt(1), + } + result, err := ec.EstimateGasAndReviveState(context.Background(), msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Hex != 21000 { + t.Fatalf("unexpected gas price: %v", result.Hex) + } + // TODO(asyukii): add more test cases here +} + func testDiffAccounts(t *testing.T, client *rpc.Client) { ec := NewClient(client) ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond) diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 8c63b94442..cb07d39a7d 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -957,6 +957,61 @@ func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash return result, nil } +func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, []*vm.EVMError, error) { + defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) + + state, header, err := 
b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if state == nil || err != nil { + return nil, nil, err + } + if err := overrides.Apply(state); err != nil { + return nil, nil, err + } + // Setup context so it may be cancelled the call has completed + // or, in case of unmetered gas, setup a context with a timeout. + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + // Make sure the context is cancelled when the call has completed + // this makes sure resources are cleaned up. + defer cancel() + + // Get a new instance of the EVM. + msg, err := args.ToMessage(globalGasCap, header.BaseFee) + if err != nil { + return nil, nil, err + } + evm, vmError, err := b.GetEVM(ctx, msg, state, header, &vm.Config{NoBaseFee: true}) + if err != nil { + return nil, nil, err + } + // Wait for the context to be done and cancel the evm. Even if the + // EVM has finished, cancelling may be done (repeatedly) + gopool.Submit(func() { + <-ctx.Done() + evm.Cancel() + }) + + // Execute the message. 
+ gp := new(core.GasPool).AddGas(math.MaxUint64) + result, err := core.ApplyMessage(evm, msg, gp) + if err := vmError(); err != nil { + return nil, evm.ErrorCollection, err + } + + // If the timer caused an abort, return an appropriate error message + if evm.Cancelled() { + return nil, evm.ErrorCollection, fmt.Errorf("execution aborted (timeout = %v)", timeout) + } + if err != nil { + return result, evm.ErrorCollection, fmt.Errorf("err: %w (supplied gas %d)", err, msg.Gas()) + } + return result, evm.ErrorCollection, nil +} + func newRevertError(result *core.ExecutionResult) *revertError { reason, errUnpack := abi.UnpackRevert(result.Revert()) err := errors.New("execution reverted") @@ -1135,6 +1190,211 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args TransactionA return DoEstimateGas(ctx, s.b, args, bNrOrHash, s.b.RPCGasCap()) } +// Result structs for EstimateGasAndReviveState +type EstimateGasAndReviveStateResult struct { + Hex hexutil.Uint64 `json:"gas"` + ReviveWitness []types.ReviveWitness `json:"reviveWitness"` +} + +// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It also attempts to revive expired +// storage trie and returns the revive witness list. +func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, gasCap uint64) (*EstimateGasAndReviveStateResult, error) { + + var result EstimateGasAndReviveStateResult + var stateDb *state.StateDB + + // Initialize witnessList + var witnessList []types.ReviveWitness + if args.WitnessList != nil { + witnessList = *args.WitnessList + } + witLen := len(witnessList) + + // Binary search the gas requirement, as it may be higher than the amount used + var ( + lo uint64 = params.TxGas - 1 + hi uint64 + cap uint64 + ) + // Use zero address if sender unspecified. 
+ if args.From == nil { + args.From = new(common.Address) + } + // Determine the highest gas limit can be used during the estimation. + if args.Gas != nil && uint64(*args.Gas) >= params.TxGas { + hi = uint64(*args.Gas) + } else { + // Retrieve the block to act as the gas ceiling + block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) + if err != nil { + return nil, err + } + if block == nil { + return nil, errors.New("block not found") + } + hi = block.GasLimit() + } + // Normalize the max fee per gas the call is willing to spend. + var feeCap *big.Int + if args.GasPrice != nil && (args.MaxFeePerGas != nil || args.MaxPriorityFeePerGas != nil) { + return nil, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } else if args.GasPrice != nil { + feeCap = args.GasPrice.ToInt() + } else if args.MaxFeePerGas != nil { + feeCap = args.MaxFeePerGas.ToInt() + } else { + feeCap = common.Big0 + } + // Recap the highest gas limit with account's available balance. + if feeCap.BitLen() != 0 { + stateDb, _, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if err != nil { + return nil, err + } + balance := stateDb.GetBalance(*args.From) // from can't be nil + available := new(big.Int).Set(balance) + if args.Value != nil { + if args.Value.ToInt().Cmp(available) >= 0 { + return nil, errors.New("insufficient funds for transfer") + } + available.Sub(available, args.Value.ToInt()) + } + allowance := new(big.Int).Div(available, feeCap) + + // If the allowance is larger than maximum uint64, skip checking + if allowance.IsUint64() && hi > allowance.Uint64() { + transfer := args.Value + if transfer == nil { + transfer = new(hexutil.Big) + } + log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance, + "sent", transfer.ToInt(), "maxFeePerGas", feeCap, "fundable", allowance) + hi = allowance.Uint64() + } + } + // Recap the highest gas allowance with specified gascap. 
+ if gasCap != 0 && hi > gasCap { + log.Debug("Caller gas above allowance, capping", "requested", hi, "cap", gasCap) + hi = gasCap + } + cap = hi + + // Create a helper to check if a gas allowance results in an executable transaction + executable := func(gas uint64, witnessList []types.ReviveWitness) (bool, *core.ExecutionResult, bool, error) { + args.Gas = (*hexutil.Uint64)(&gas) + + result, evmErrors, err := DoCallExpired(ctx, b, args, blockNrOrHash, nil, 0, gasCap) // TODO (asyukii): Use a different call function to return EVM errors + + // Create MPTProof + isExpiredError := false + if len(evmErrors) > 0 { + addressToProofMap := make(map[common.Address][]types.MPTProof) + for _, evmErr := range evmErrors { + if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { + isExpiredError = true + proof, err := stateDb.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) + if err != nil { + return true, nil, isExpiredError, err + } + addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], types.MPTProof{ + RootKeyHex: stateErr.Path, + Proof: proof, + }) + } + } + + // Create a ReviveWitness object for each address and add it to witnessList + for addr, proofs := range addressToProofMap { + // Build a storageTrieWitness + storageTrieWitness := types.StorageTrieWitness{ + Address: addr, + ProofList: proofs, + } + // Encode StorageTrieWitness + enc, err := rlp.EncodeToBytes(storageTrieWitness) + if err != nil { + return true, nil, isExpiredError, err + } + // Create a ReviveWitness + reviveWitness := types.ReviveWitness{ + WitnessType: types.StorageTrieWitnessType, + Data: enc, + } + // Append to witness list + witnessList = append(witnessList, reviveWitness) + } + } + if err != nil { + if errors.Is(err, core.ErrIntrinsicGas) { + return true, nil, isExpiredError, nil // Special case, raise gas limit + } + return true, nil, isExpiredError, err // Bail out + } + return result.Failed(), result, isExpiredError, nil + } + // Execute the binary 
search and hone in on an executable gas limit + for lo+1 < hi { + mid := (hi + lo) / 2 + failed, _, isExpiredError, err := executable(mid, witnessList) + + if isExpiredError { + if witLen == len(witnessList) { + // If witnessList is not updated, it means that the proofs are not + // sufficient to revive the state. + return nil, fmt.Errorf("cannot generate enough proofs to revive the state") + } + witLen = len(witnessList) + continue + } + + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. Return the error directly, don't struggle any more. + if err != nil { + return nil, err + } + if failed { + lo = mid + } else { + hi = mid + } + } + // Reject the transaction as invalid if it still fails at the highest allowance + if hi == cap { + failed, result, _, err := executable(hi, witnessList) + if err != nil { + return nil, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + if len(result.Revert()) > 0 { + return nil, newRevertError(result) + } + return nil, result.Err + } + // Otherwise, the specified gas cap is too low + return nil, fmt.Errorf("gas required exceeds allowance (%d)", cap) + } + } + result = EstimateGasAndReviveStateResult{ + Hex: hexutil.Uint64(hi), + ReviveWitness: witnessList, + } + return &result, nil +} + +// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It will also revive the state +// temporarily to estimate the gas. 
+func (s *PublicBlockChainAPI) EstimateGasAndReviveState(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash) (*EstimateGasAndReviveStateResult, error) { + bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash + } + return DoEstimateGasAndReviveState(ctx, s.b, args, bNrOrHash, s.b.RPCGasCap()) +} + // GetDiffAccounts returns changed accounts in a specific block number. func (s *PublicBlockChainAPI) GetDiffAccounts(ctx context.Context, blockNr rpc.BlockNumber) ([]common.Address, error) { if s.b.Chain() == nil { From 9f844da972c4de9f547ee9aedcd6d9eaa3bc94e7 Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 26 Apr 2023 16:11:26 +0800 Subject: [PATCH 35/51] refactor(trie): modify ReviveTrie to include shadow logic everything is fine at this point refactor(trie): add shadow logic to MPT Revive add shadow logic --- trie/proof.go | 4 +- trie/proof_test.go | 31 +------- trie/trie.go | 118 +++++++++++++++++++++++------ trie/trie_test.go | 182 +++++++++++++++++++++++++++++++++++---------- 4 files changed, 242 insertions(+), 93 deletions(-) diff --git a/trie/proof.go b/trie/proof.go index cec1564ab9..be492c1db8 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -113,10 +113,10 @@ func (t *Trie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethd defer returnHasherToPool(hasher) // construct the proof - for i, n := range nodes { + for _, n := range nodes { var hn node n, hn = hasher.proofHash(n) - if hash, ok := hn.(hashNode); ok || i == 0 { + if hash, ok := hn.(hashNode); ok { enc := nodeToBytes(n) if !ok { hash = hasher.hashData(enc) diff --git a/trie/proof_test.go b/trie/proof_test.go index 10eeb5cb32..b05f2ec446 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -923,33 +923,6 @@ func TestStorageProof(t *testing.T) { } } -// TestOneElementStorageProof tests the storage proof generation and verification -// for a trie with only one element. 
-func TestOneElementStorageProof(t *testing.T) { - trie := new(Trie) - updateString(trie, "k", "v") - - proof := memorydb.New() - key := []byte("k") - err := trie.ProveStorageWitness(key, nil, proof) - if err != nil { - t.Fatalf("missing key %x while constructing proof", key) - } - - if proof.Len() != 1 { - t.Errorf("proof should have one element") - } - - val, err := VerifyProof(trie.Hash(), []byte("k"), proof) - if err != nil { - t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proof) - } - - if !bytes.Equal(val, []byte("v")) { - t.Fatalf("verified value mismatch: have %x, want 'v'", val) - } -} - // TestEmptyStorageProof tests storage verification with empty proof. // The verifier should nil for both value and error. func TestEmptyStorageProof(t *testing.T) { @@ -1300,6 +1273,8 @@ func randBytes(n int) []byte { func nonRandomTrie(n int) (*Trie, map[string]*kv) { trie := new(Trie) + trie.isStorageTrie = true + trie.currentEpoch = 10 // TODO (asyukii): might need to change this vals := make(map[string]*kv) max := uint64(0xffffffffffffffff) for i := uint64(0); i < uint64(n); i++ { @@ -1309,7 +1284,7 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) { binary.LittleEndian.PutUint64(value, i-max) //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} elem := &kv{key, value, false} - trie.Update(elem.k, elem.v) + trie.Update(elem.k, elem.v) // TODO (asyukii): this is not working, the shadow branch node is not being updated vals[string(elem.k)] = elem } return trie, vals diff --git a/trie/trie.go b/trie/trie.go index 750b06a7f3..79268fc70f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -465,6 +465,9 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } // else, set its epoch to current epoch. n.setEpoch(t.currentEpoch) + // TODO (asyukii) + // This code block has a problem. 
When inserting a new node, it will check childExpired, so + // by default, it will always return expired, so it will never insert a new node. if t.currentEpoch >= 2 { // if child is expired, return err if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { @@ -484,7 +487,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE return true, n, nil case nil: - return true, &shortNode{Key: key, Val: value, flags: t.newFlag()}, nil + return true, &shortNode{Key: key, Val: value, flags: t.newFlag(), epoch: t.currentEpoch}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -681,8 +684,12 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, } } +// ExpireByPrefix is used to simulate the expiration of a trie by prefix key. +// It is not used in the actual trie implementation. ExpireByPrefix makes sure +// only a child node of a full node is expired, if not an error is returned. 
func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { - hn, err := t.expireByPrefix(t.root, prefixKeyHex) + // TODO (asyukii): Add support for RootNode + hn, _, err := t.expireByPrefix(t.root, prefixKeyHex) if prefixKeyHex == nil && hn != nil { t.root = hn } @@ -692,7 +699,7 @@ func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { return nil } -func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { +func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, bool, error) { // Loop through prefix key // When prefix key is empty, generate the hash node of the current node // Replace current node with the hash node @@ -704,30 +711,30 @@ func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { var hn node _, hn = hasher.proofHash(n) if _, ok := hn.(hashNode); ok { - return hn, nil + return hn, false, nil } - return nil, nil + return nil, true, nil } switch n := n.(type) { case *shortNode: matchLen := prefixLen(prefixKeyHex, n.Key) - hn, err := t.expireByPrefix(n.Val, prefixKeyHex[matchLen:]) + hn, didUpdateEpoch, err := t.expireByPrefix(n.Val, prefixKeyHex[matchLen:]) if err != nil { - return nil, err + return nil, didUpdateEpoch, err } - // Replace child node with hash node if hn != nil { - n.Val = hn + return nil, didUpdateEpoch, fmt.Errorf("can only expire child short node") } - return nil, err + return nil, didUpdateEpoch, err case *fullNode: - hn, err := t.expireByPrefix(n.Children[prefixKeyHex[0]], prefixKeyHex[1:]) + childIndex := int(prefixKeyHex[0]) + hn, didUpdateEpoch, err := t.expireByPrefix(n.Children[childIndex], prefixKeyHex[1:]) if err != nil { - return nil, err + return nil, didUpdateEpoch, err } // Replace child node with hash node @@ -735,9 +742,15 @@ func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, error) { n.Children[prefixKeyHex[0]] = hn } - return nil, err + // Update the epoch so that it is expired + if !didUpdateEpoch { + n.UpdateChildEpoch(childIndex, 0) + didUpdateEpoch = 
true + } + + return nil, didUpdateEpoch, err default: - return nil, fmt.Errorf("invalid node type") + return nil, false, fmt.Errorf("invalid node type") } } @@ -855,6 +868,10 @@ func (t *Trie) Size() int { return estimateSize(t.root) } +// ReviveTrie attempts to revive a trie from a list of MPTProofNubs. +// ReviveTrie performs full or partial revive and returns a list of successful +// nubs. ReviveTrie does not guarantee that a value will be revived completely, +// if the proof is not fully valid. func (t *Trie) ReviveTrie(proof []*MPTProofNub) (successNubs []*MPTProofNub) { successNubs, err := t.TryRevive(proof) if err != nil { @@ -865,14 +882,26 @@ func (t *Trie) ReviveTrie(proof []*MPTProofNub) (successNubs []*MPTProofNub) { func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err error) { + // Revive trie with each proof nub for _, nub := range proof { - newNode, didResolve, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub) + // TODO (asyukii): Check if MPT Root is a RootNode object + // Get root node + // var root node + // if rt, ok := t.root.(RootNode); ok { + // root = rt + // } else { + // root = t.root + // } + root := t.root + path := []byte{} + newNode, didResolve, err := t.tryRevive(root, nub.n1PrefixKey, *nub, path, false) // TODO (asyukii): change isExpired to false if didResolve && err == nil { successNubs = append(successNubs, nub) t.root = newNode } } + // If no nubs were successful, return error if len(successNubs) == 0 && len(proof) != 0 { return successNubs, fmt.Errorf("all nubs failed to revive trie") } @@ -880,14 +909,17 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err return successNubs, nil } -func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub) (node, bool, error) { - if len(key) == 0 { +func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, path []byte, isExpired bool) (node, bool, error) { + // TODO (asyukii) + // 2 conditions must be met: 1) key is empty, 2) node 
is expired + if len(key) == 0 && isExpired { if hashNode, ok := n.(hashNode); ok { cachedHash, _ := nub.n1.cache() if bytes.Equal(cachedHash, hashNode) { - + nub.n1.setEpoch(t.currentEpoch) if nub.n2 != nil { + nub.n2.setEpoch(t.currentEpoch) switch n1 := nub.n1.(type) { case *shortNode: n1.Val = nub.n2 @@ -899,15 +931,46 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub) (node, bool, error return nub.n1, true, nil } } - return nil, false, nil + } else if len(key) == 0 && !isExpired { + return nil, false, fmt.Errorf("key %v not found", key) + } else if len(key) != 0 && isExpired { + return nil, false, &ExpiredNodeError{ + ExpiredNode: n, + Path: path, + Epoch: 0, // Set default value, will change later + } } + + // if len(key) == 0 && isExpired { + // if hashNode, ok := n.(hashNode); ok { + // cachedHash, _ := nub.n1.cache() + // if bytes.Equal(cachedHash, hashNode) { + // if n1, ok := nub.n1.(*fullNode); ok { + // n1.setEpoch(t.currentEpoch) + // } + // if nub.n2 != nil { + // switch n1 := nub.n1.(type) { + // case *shortNode: + // n1.Val = nub.n2 + // default: + // return nil, false, fmt.Errorf("invalid node type") + // } + // } + + // return nub.n1, true, nil + // } + // } + + // return nil, false, nil + // } + switch n := n.(type) { case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { return nil, false, fmt.Errorf("key %v not found", key) } - newNode, didResolve, err := t.tryRevive(n.Val, key[len(n.Key):], nub) + newNode, didResolve, err := t.tryRevive(n.Val, key[len(n.Key):], nub, append(path, key[:len(n.Key)]...), isExpired) if didResolve && err == nil { n = n.copy() n.Val = newNode @@ -915,18 +978,27 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub) (node, bool, error return n, didResolve, err case *fullNode: childIndex := int(key[0]) - newNode, didResolve, err := t.tryRevive(n.Children[childIndex], key[1:], nub) + // TODO (asyukii): Check if child has expired + isExpired, _ := n.ChildExpired(nil, 
childIndex, t.currentEpoch) // TODO (asyukii): t.currentEpoch or t.root.getEpoch()? + newNode, didResolve, err := t.tryRevive(n.Children[childIndex], key[1:], nub, append(path, key[0]), isExpired) if didResolve && err == nil { n = n.copy() n.Children[childIndex] = newNode + n.UpdateChildEpoch(childIndex, t.currentEpoch) } + + if e, ok := err.(*ExpiredNodeError); ok { + e.Epoch = n.GetChildEpoch(childIndex) + return n, didResolve, e + } + return n, didResolve, err case hashNode: - tn, err := t.resolveHash(n, nil) // TODO (asyukii): Revisit epoch index + tn, err := t.resolveHash(n, nil) if err != nil { return nil, false, err } - return t.tryRevive(tn, key, nub) + return t.tryRevive(tn, key, nub, path, isExpired) case valueNode: return nil, false, nil case nil: diff --git a/trie/trie_test.go b/trie/trie_test.go index f297c5d794..e6c921c2e2 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -490,7 +490,7 @@ func TestExpireByPrefix(t *testing.T) { "degi": "H", } - trie := createCustomTrie(data) + trie := createCustomTrie(data, 0) rootHash := trie.Hash() for k := range data { @@ -502,13 +502,14 @@ func TestExpireByPrefix(t *testing.T) { assert.Equal(t, rootHash, currHash, "Root hash mismatch, got %x, expected %x", currHash, rootHash) // Reset trie - trie = createCustomTrie(data) + trie = createCustomTrie(data, 0) } } } -func createCustomTrie(data map[string]string) *Trie { +func createCustomTrie(data map[string]string, epoch types.StateEpoch) *Trie { trie := new(Trie) + trie.currentEpoch = epoch for k, v := range data { trie.Update([]byte(k), []byte(v)) } @@ -536,6 +537,44 @@ func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { } } +func getFullNodePrefixKeys(t *Trie, key []byte) [][]byte { + var prefixKeys [][]byte + key = keybytesToHex(key) + tn := t.root + currPath := []byte{} + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie 
doesn't contain the key. + tn = nil + } else { + tn = n.Val + prefixKeys = append(prefixKeys, currPath) + currPath = append(currPath, n.Key...) + key = key[len(n.Key):] + } + case *fullNode: + tn = n.Children[key[0]] + currPath = append(currPath, key[0]) + key = key[1:] + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + return nil + } + default: + return nil + } + } + + // Remove the first item in prefixKeys, which is the empty key + prefixKeys = prefixKeys[1:] + + return prefixKeys +} + // TestTryRevive tests that a trie can be revived from a proof func TestTryRevive(t *testing.T) { @@ -546,7 +585,7 @@ func TestTryRevive(t *testing.T) { for _, kv := range vals { key := kv.k val := kv.v - prefixKeys := getPrefixKeysHex(trie, key) + prefixKeys := getFullNodePrefixKeys(trie, key) for _, prefixKey := range prefixKeys { // Generate proof var proof proofList @@ -565,11 +604,11 @@ func TestTryRevive(t *testing.T) { // Revive trie _, err = trie.TryRevive(proofCache.cacheNubs) - assert.NoError(t, err) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) // Verify value exists after revive v := trie.Get(key) - assert.Equal(t, val, v) + assert.Equal(t, val, v, "value mismatch, got %x, expected %x. 
key %x, prefixKey %x", v, val, key, prefixKey) // Verify root hash currRootHash := trie.Hash() @@ -581,6 +620,94 @@ func TestTryRevive(t *testing.T) { } } +// TODO (asyukii): Delete this +func TestTestTryRevive(t *testing.T) { + + trie, _ := nonRandomTrie(500) + + oriRootHash := trie.Hash() + + key := []byte{0xf2, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + val := []byte{0xf3, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + prefixKey := []byte{0x0f, 0x02, 0x00, 0x01} + // Generate proof + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x", key, prefixKey) + + // Verify value exists after revive + v := trie.Get(key) + assert.Equal(t, val, v, "value mismatch, got %x, expected %x. 
key %x, prefixKey %x", v, val, key, prefixKey) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) +} + +// TestTryReviveCustomData tests that a trie can be revived from a proof +func TestTryReviveCustomData(t *testing.T) { + + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data, 10) + + oriRootHash := trie.Hash() + + for k, v := range data { + key := []byte(k) + val := []byte(v) + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err, "verify proof failed, key %x, prefixKey %x", key, prefixKey) + + // Revive trie + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err, "try revive failed, key %x, prefixKey %x", key, prefixKey) + + // Verify value exists after revive + v := trie.Get(key) + assert.Equal(t, val, v) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash, key, prefixKey) + + // Reset trie + trie = createCustomTrie(data, 10) + } + } +} + // TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when // the whole trie is expired. 
This test will fail because the parent node in // TryRevive is nil, set to RootNode when available @@ -640,8 +767,8 @@ func TestReviveBadProof(t *testing.T) { "abcd": "E", "abce": "F", "abde": "G", "abdf": "H", } - trieA := createCustomTrie(dataA) - trieB := createCustomTrie(dataB) + trieA := createCustomTrie(dataA, 0) + trieB := createCustomTrie(dataB, 0) var proofB proofList @@ -668,39 +795,13 @@ func TestReviveBadProof(t *testing.T) { } -// TestReviveAlreadyExists tests that a path cannot be revived -// again if it already exists -func TestReviveAlreadyExists(t *testing.T) { - trie := new(Trie) - key := []byte("k") - val := []byte("v") - trie.Update(key, val) - - // Generate proof - var proof proofList - - err := trie.ProveStorageWitness(key, nil, &proof) - assert.NoError(t, err) - - proofCache := makeRawMPTProofCache(nil, proof) - - err = proofCache.VerifyProof() - assert.NoError(t, err) - - _, err = trie.TryRevive(proofCache.cacheNubs) - assert.Error(t, err) - - v := trie.Get(key) - assert.Equal(t, val, v) -} - // TestReviveBadProofAfterUpdate tests that after reviving a path and // then update the value, old proof should be invalid func TestReviveBadProofAfterUpdate(t *testing.T) { trie, vals := nonRandomTrie(500) for _, kv := range vals { key := kv.k - prefixKeys := getPrefixKeysHex(trie, key) + prefixKeys := getFullNodePrefixKeys(trie, key) for _, prefixKey := range prefixKeys { var proof proofList err := trie.ProveStorageWitness(key, prefixKey, &proof) @@ -717,9 +818,9 @@ func TestReviveBadProofAfterUpdate(t *testing.T) { assert.NoError(t, err) // Revive first - _, err = trie.TryRevive(proofCache.cacheNubs) - assert.NoError(t, err) + trie.TryRevive(proofCache.cacheNubs) + // Update value trie.Update(key, []byte("new value")) // Revive again with old proof @@ -741,7 +842,7 @@ func TestPartialReviveFullProof(t *testing.T) { "defg": "E", "defh": "F", "degh": "G", "degi": "H", } - trie := createCustomTrie(data) + trie := createCustomTrie(data, 10) // Get proof 
var proof proofList @@ -791,13 +892,14 @@ func TestReviveValueAtFullNode(t *testing.T) { "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", } - trie := new(Trie) for i, hexKey := range hexKeys { hexKey = hexToKeybytes(hexKey) byteKeys[i] = hexKey } // Insert keys into trie + trie := new(Trie) + trie.currentEpoch = 10 for i, hexKey := range byteKeys { trie.Update(hexKey, []byte(vals[i])) } @@ -805,7 +907,7 @@ func TestReviveValueAtFullNode(t *testing.T) { key := byteKeys[9] val := vals[9] - prefixKeys := getPrefixKeysHex(trie, key) + prefixKeys := getFullNodePrefixKeys(trie, key) for _, prefixKey := range prefixKeys { var proof proofList From 12b95d010655d11193e54e6ed4a92a0e5c9f4fd4 Mon Sep 17 00:00:00 2001 From: asyukii Date: Sat, 6 May 2023 13:24:46 +0800 Subject: [PATCH 36/51] fix(trie): resolve insert error chore: remove comment refactor(trie): change variable naming add epoch parameter to tryRevive refactor(trie): restructure if-else statements refactor(trie): restructure if-else statements deconstruct if layers minor comment minor typo --- trie/trie.go | 132 ++++++++++++++++++++-------------------------- trie/trie_test.go | 84 ----------------------------- 2 files changed, 58 insertions(+), 158 deletions(-) diff --git a/trie/trie.go b/trie/trie.go index 79268fc70f..3019df07fb 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -465,13 +465,17 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } // else, set its epoch to current epoch. n.setEpoch(t.currentEpoch) - // TODO (asyukii) - // This code block has a problem. When inserting a new node, it will check childExpired, so - // by default, it will always return expired, so it will never insert a new node. 
if t.currentEpoch >= 2 { - // if child is expired, return err - if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { - return false, n.Children[key[0]], err + child := n.Children[key[0]] + childKey := key[1:] + // if inserting a new node to this full node, there is no need to check whether this child is expired. + if len(childKey) != 0 { + if child != nil { + // if child is expired, return err + if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { + return false, n.Children[key[0]], err + } + } } } // else, set child node's epoch to current epoch @@ -688,7 +692,6 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, // It is not used in the actual trie implementation. ExpireByPrefix makes sure // only a child node of a full node is expired, if not an error is returned. func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { - // TODO (asyukii): Add support for RootNode hn, _, err := t.expireByPrefix(t.root, prefixKeyHex) if prefixKeyHex == nil && hn != nil { t.root = hn @@ -884,18 +887,10 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err // Revive trie with each proof nub for _, nub := range proof { - // TODO (asyukii): Check if MPT Root is a RootNode object - // Get root node - // var root node - // if rt, ok := t.root.(RootNode); ok { - // root = rt - // } else { - // root = t.root - // } - root := t.root path := []byte{} - newNode, didResolve, err := t.tryRevive(root, nub.n1PrefixKey, *nub, path, false) // TODO (asyukii): change isExpired to false - if didResolve && err == nil { + rootExpired, _ := t.nodeExpired(t.root, nil) + newNode, didRevive, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub, t.currentEpoch, path, rootExpired) + if didRevive && err == nil { successNubs = append(successNubs, nub) t.root = newNode } @@ -909,32 +904,42 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs 
[]*MPTProofNub, err return successNubs, nil } -func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, path []byte, isExpired bool) (node, bool, error) { - - // TODO (asyukii) - // 2 conditions must be met: 1) key is empty, 2) node is expired - if len(key) == 0 && isExpired { - if hashNode, ok := n.(hashNode); ok { - cachedHash, _ := nub.n1.cache() - if bytes.Equal(cachedHash, hashNode) { - nub.n1.setEpoch(t.currentEpoch) - if nub.n2 != nil { - nub.n2.setEpoch(t.currentEpoch) - switch n1 := nub.n1.(type) { - case *shortNode: - n1.Val = nub.n2 - default: - return nil, false, fmt.Errorf("invalid node type") - } - } +func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateEpoch, path []byte, isExpired bool) (node, bool, error) { + + // To revive a node, few conditions must be met: + // 1. key length must be 0, indicating that the targeted node is reached + // 2. the node must be expired + // 3. the node must be a hash node + // 4. the node hash must match the hash value of nub + + if len(key) == 0 { + if !isExpired { + return nil, false, fmt.Errorf("key %v not found", key) + } - return nub.n1, true, nil + hn, ok := n.(hashNode) + if !ok { + return nil, false, fmt.Errorf("node is not a hash node") + } + + cachedHash, _ := nub.n1.cache() + if !bytes.Equal(cachedHash, hn) { + return nil, false, fmt.Errorf("hash values does not match") + } + + nub.n1.setEpoch(t.currentEpoch) + if nub.n2 != nil { + nub.n2.setEpoch(t.currentEpoch) + if n1, ok := nub.n1.(*shortNode); ok { // n2 can only be followed by a short node + n1.Val = nub.n2 + } else { + return nil, false, fmt.Errorf("invalid node type") } } - return nil, false, nil - } else if len(key) == 0 && !isExpired { - return nil, false, fmt.Errorf("key %v not found", key) - } else if len(key) != 0 && isExpired { + return nub.n1, true, nil + } + + if isExpired { // the node is expired but targeted node is not reached return nil, false, &ExpiredNodeError{ ExpiredNode: n, Path: path, @@ -942,46 
+947,22 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, path []byte, isExp } } - // if len(key) == 0 && isExpired { - // if hashNode, ok := n.(hashNode); ok { - // cachedHash, _ := nub.n1.cache() - // if bytes.Equal(cachedHash, hashNode) { - // if n1, ok := nub.n1.(*fullNode); ok { - // n1.setEpoch(t.currentEpoch) - // } - // if nub.n2 != nil { - // switch n1 := nub.n1.(type) { - // case *shortNode: - // n1.Val = nub.n2 - // default: - // return nil, false, fmt.Errorf("invalid node type") - // } - // } - - // return nub.n1, true, nil - // } - // } - - // return nil, false, nil - // } - switch n := n.(type) { case *shortNode: if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { return nil, false, fmt.Errorf("key %v not found", key) } - newNode, didResolve, err := t.tryRevive(n.Val, key[len(n.Key):], nub, append(path, key[:len(n.Key)]...), isExpired) - if didResolve && err == nil { + newNode, didRevive, err := t.tryRevive(n.Val, key[len(n.Key):], nub, epoch, append(path, key[:len(n.Key)]...), isExpired) + if didRevive && err == nil { n = n.copy() n.Val = newNode } - return n, didResolve, err + return n, didRevive, err case *fullNode: childIndex := int(key[0]) - // TODO (asyukii): Check if child has expired isExpired, _ := n.ChildExpired(nil, childIndex, t.currentEpoch) // TODO (asyukii): t.currentEpoch or t.root.getEpoch()? 
- newNode, didResolve, err := t.tryRevive(n.Children[childIndex], key[1:], nub, append(path, key[0]), isExpired) - if didResolve && err == nil { + newNode, didRevive, err := t.tryRevive(n.Children[childIndex], key[1:], nub, epoch, append(path, key[0]), isExpired) + if didRevive && err == nil { n = n.copy() n.Children[childIndex] = newNode n.UpdateChildEpoch(childIndex, t.currentEpoch) @@ -989,16 +970,19 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, path []byte, isExp if e, ok := err.(*ExpiredNodeError); ok { e.Epoch = n.GetChildEpoch(childIndex) - return n, didResolve, e + return n, didRevive, e } - return n, didResolve, err + return n, didRevive, err case hashNode: - tn, err := t.resolveHash(n, nil) + tn, err := t.resolveHash(n, path) // TODO(asyukii): may need to copy resolved hash node if err != nil { return nil, false, err } - return t.tryRevive(tn, key, nub, path, isExpired) + if err = t.resolveShadowNode(epoch, tn, path); err != nil { + return nil, false, err + } + return t.tryRevive(tn, key, nub, epoch, path, isExpired) case valueNode: return nil, false, nil case nil: diff --git a/trie/trie_test.go b/trie/trie_test.go index e6c921c2e2..10dc28e4a0 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -620,44 +620,6 @@ func TestTryRevive(t *testing.T) { } } -// TODO (asyukii): Delete this -func TestTestTryRevive(t *testing.T) { - - trie, _ := nonRandomTrie(500) - - oriRootHash := trie.Hash() - - key := []byte{0xf2, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - val := []byte{0xf3, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - prefixKey := []byte{0x0f, 0x02, 0x00, 0x01} - // Generate proof - var proof proofList - err := trie.ProveStorageWitness(key, prefixKey, 
&proof) - assert.NoError(t, err) - - // Expire trie - trie.ExpireByPrefix(prefixKey) - - // Construct MPTProofCache - proofCache := makeRawMPTProofCache(prefixKey, proof) - - // VerifyProof - err = proofCache.VerifyProof() - assert.NoError(t, err) - - // Revive trie - _, err = trie.TryRevive(proofCache.cacheNubs) - assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x", key, prefixKey) - - // Verify value exists after revive - v := trie.Get(key) - assert.Equal(t, val, v, "value mismatch, got %x, expected %x. key %x, prefixKey %x", v, val, key, prefixKey) - - // Verify root hash - currRootHash := trie.Hash() - assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) -} - // TestTryReviveCustomData tests that a trie can be revived from a proof func TestTryReviveCustomData(t *testing.T) { @@ -708,52 +670,6 @@ func TestTryReviveCustomData(t *testing.T) { } } -// TODO (asyukii): TestReviveAtRoot tests that a key can be revived at root when -// the whole trie is expired. 
This test will fail because the parent node in -// TryRevive is nil, set to RootNode when available -// func TestReviveAtRoot(t *testing.T) { -// trie, vals := nonRandomTrie(500) - -// oriRootHash := trie.Hash() - -// for _, kv := range vals { -// key := []byte(kv.k) -// val := []byte(kv.v) - -// fmt.Printf("key: %x, val: %x", key, val) -// var proof proofList - -// err := trie.ProveStorageWitness(key, nil, &proof) -// assert.NoError(t, err) - -// // Expire trie -// trie.ExpireByPrefix(nil) - -// // Construct MPTProofCache -// proofCache := makeRawMPTProofCache(nil, proof) - -// // VerifyProof -// err = proofCache.VerifyProof() -// assert.NoError(t, err) - -// // Revive trie -// err = trie.TryRevive(proofCache) -// assert.NoError(t, err) - -// // Verify value exists after revive -// v := trie.Get(key) -// assert.Equal(t, val, v) - -// // Verify root hash -// currRootHash := trie.Hash() -// assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) - -// // Reset trie -// trie, _ = nonRandomTrie(500) -// } - -// } - // TestReviveBadProof tests that a trie cannot be revived from a bad proof func TestReviveBadProof(t *testing.T) { From dc1ebbeb22539716a9009fffaeb2d01ca5636673 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Tue, 9 May 2023 19:36:41 +0800 Subject: [PATCH 37/51] trie/shadownode: support shadow node hash calculation; --- trie/hasher.go | 60 +++++++++++++++++++++ trie/node.go | 19 +++++-- trie/node_test.go | 16 +++--- trie/proof_test.go | 2 +- trie/secure_trie.go | 17 +++--- trie/shadow_node.go | 10 ++-- trie/trie.go | 129 ++++++++++++++++++++++++++++++++++++-------- trie/trie_test.go | 53 +++++++++++++++--- 8 files changed, 247 insertions(+), 59 deletions(-) diff --git a/trie/hasher.go b/trie/hasher.go index e9f45f8341..3b0761caf1 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -19,6 +19,8 @@ package trie import ( "sync" + 
"github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" @@ -139,6 +141,55 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached return collapsed, cached } +// shadowExtendNodeToHash hash shadowExtendNode +func (h *hasher) shadowExtendNodeToHash(n *shadowExtensionNode) *common.Hash { + if n.ShadowHash == nil { + return nil + } + w := h.encbuf + offset := w.List() + w.WriteBytes(n.ShadowHash[:]) + w.ListEnd(offset) + + enc := h.encodedBytes() + return h.hashCommon(enc) +} + +// shadowFullNodeToHash hash shadowBranchNode +func (h *hasher) shadowBranchNodeToHash(n *shadowBranchNode) *common.Hash { + w := h.encbuf + outerList := w.List() + if n.ShadowHash == nil { + w.WriteBytes(rlp.EmptyString) + } else { + w.WriteBytes(n.ShadowHash[:]) + } + epochList := w.List() + for _, epoch := range n.EpochMap { + w.WriteUint64(uint64(epoch)) + } + w.ListEnd(epochList) + w.ListEnd(outerList) + + enc := h.encodedBytes() + return h.hashCommon(enc) +} + +func (h *hasher) shadowNodeHashListToHash(hashList []*common.Hash) *common.Hash { + if len(hashList) == 0 { + return nil + } + w := h.encbuf + offset := w.List() + for _, hash := range hashList { + w.WriteBytes(hash[:]) + } + w.ListEnd(offset) + + enc := h.encodedBytes() + return h.hashCommon(enc) +} + // shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode // should have hex-type Key, which will be converted (without modification) // into compact form for RLP encoding. 
@@ -190,6 +241,15 @@ func (h *hasher) hashData(data []byte) hashNode { return n } +// hashCommon hashes the provided data +func (h *hasher) hashCommon(data []byte) *common.Hash { + var n common.Hash + h.sha.Reset() + h.sha.Write(data) + h.sha.Read(n[:]) + return &n +} + // proofHash is used to construct trie proofs, and returns the 'collapsed' // node (for later RLP encoding) aswell as the hashed node -- unless the // node is smaller than 32 bytes, in which case it will be returned as is. diff --git a/trie/node.go b/trie/node.go index a18c3b8dea..8abe7ef860 100644 --- a/trie/node.go +++ b/trie/node.go @@ -71,10 +71,19 @@ type ( ) type rootNode struct { - Epoch types.StateEpoch - TrieHash common.Hash - ShadowHash common.Hash - flags nodeFlag `rlp:"-" json:"-"` + Epoch types.StateEpoch + TrieRoot common.Hash + ShadowTreeRoot common.Hash + flags nodeFlag `rlp:"-" json:"-"` +} + +func newEpoch0RootNode(trieRoot common.Hash) *rootNode { + return &rootNode{ + Epoch: types.StateEpoch0, + TrieRoot: trieRoot, + ShadowTreeRoot: emptyRoot, + flags: nodeFlag{dirty: true}, + } } func (n *rootNode) cache() (hashNode, bool) { @@ -86,7 +95,7 @@ func (n *rootNode) encode(w rlp.EncoderBuffer) { } func (n *rootNode) fstring(s string) string { - return fmt.Sprintf("{%v: %x: %x} ", n.Epoch, n.TrieHash, n.ShadowHash) + return fmt.Sprintf("{%v: %x: %x} ", n.Epoch, n.TrieRoot, n.ShadowTreeRoot) } func (n *rootNode) nodeType() int { diff --git a/trie/node_test.go b/trie/node_test.go index fa308e97f1..62cc2e2874 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -103,18 +103,18 @@ func TestRootNodeEncodeDecode(t *testing.T) { }{ { r: &rootNode{ - Epoch: 100, - TrieHash: makeHash("t1"), - ShadowHash: makeHash("s1"), - flags: nodeFlag{hash: makeHash("h1").Bytes()}, + Epoch: 100, + TrieRoot: makeHash("t1"), + ShadowTreeRoot: makeHash("s1"), + flags: nodeFlag{hash: makeHash("h1").Bytes()}, }, }, { r: &rootNode{ - Epoch: 0, - TrieHash: common.Hash{}, - ShadowHash: common.Hash{}, - flags: 
nodeFlag{hash: makeHash("h1").Bytes()}, + Epoch: 0, + TrieRoot: common.Hash{}, + ShadowTreeRoot: common.Hash{}, + flags: nodeFlag{hash: makeHash("h1").Bytes()}, }, }, } diff --git a/trie/proof_test.go b/trie/proof_test.go index b05f2ec446..0b914bcef8 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -1273,7 +1273,7 @@ func randBytes(n int) []byte { func nonRandomTrie(n int) (*Trie, map[string]*kv) { trie := new(Trie) - trie.isStorageTrie = true + trie.useShadowTree = true trie.currentEpoch = 10 // TODO (asyukii): might need to change this vals := make(map[string]*kv) max := uint64(0xffffffffffffffff) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index a37ba4129b..d176845c20 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -60,9 +60,9 @@ func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, shadowHash := common.Hash{} if isStorageTrie { if rootNode := db.RootNode(root); rootNode != nil { - root = rootNode.TrieHash + root = rootNode.TrieRoot epoch = rootNode.Epoch - shadowHash = rootNode.ShadowHash + shadowHash = rootNode.ShadowTreeRoot } } trie, err := New(root, db) @@ -73,9 +73,9 @@ func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, if trie.root != nil { trie.root.setEpoch(epoch) } - trie.shadowHash = shadowHash + trie.shadowTreeRoot = shadowHash } - trie.isStorageTrie = isStorageTrie + trie.useShadowTree = isStorageTrie return &SecureTrie{trie: *trie}, nil } @@ -84,18 +84,15 @@ func NewStorageSecure(curEpoch types.StateEpoch, root common.Hash, db *Database, panic("trie.NewSecure called without a database") } - rn := rootNode{ - Epoch: types.StateEpoch0, - TrieHash: root, - } + rn := newEpoch0RootNode(root) hash := common.BytesToHash(root[:]) if n := db.node(hash); n != nil { if tmp, ok := n.(*rootNode); ok { - rn = *tmp + *rn = *tmp } } - trie, err := NewWithShadowNode(curEpoch, &rn, db, sndb) + trie, err := NewWithShadowNode(curEpoch, rn, db, sndb) if err != nil { return nil, 
err } diff --git a/trie/shadow_node.go b/trie/shadow_node.go index fef68eb88f..55ea9d19c2 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -13,23 +13,21 @@ import ( ) type shadowExtensionNode struct { - ShadowHash common.Hash - Epoch types.StateEpoch + ShadowHash *common.Hash } -func NewShadowExtensionNode(hash common.Hash, epoch types.StateEpoch) shadowExtensionNode { +func NewShadowExtensionNode(hash *common.Hash) shadowExtensionNode { return shadowExtensionNode{ ShadowHash: hash, - Epoch: epoch, } } type shadowBranchNode struct { - ShadowHash common.Hash + ShadowHash *common.Hash EpochMap [16]types.StateEpoch } -func NewShadowBranchNode(hash common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { +func NewShadowBranchNode(hash *common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { return shadowBranchNode{hash, epochMap} } diff --git a/trie/trie.go b/trie/trie.go index 3019df07fb..511a20891f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -66,10 +66,13 @@ type Trie struct { // Keep track of the number leafs which have been inserted since the last // hashing operation. This number will not directly map to the number of // actually unhashed nodes - unhashed int - currentEpoch types.StateEpoch - isStorageTrie bool - shadowHash common.Hash + unhashed int + + // fields for shadow tree and rootNode + useShadowTree bool + currentEpoch types.StateEpoch + shadowTreeRoot common.Hash + rootEpoch types.StateEpoch } // newFlag returns the cache flag value for a newly created node. 
@@ -104,15 +107,20 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa if db == nil || sndb == nil { panic("trie.New called without a database") } + if curEpoch < rootNode.Epoch { + return nil, errors.New("open trie at a wrong epoch") + } + trie := &Trie{ - db: db, - sndb: sndb, - currentEpoch: curEpoch, - isStorageTrie: true, - shadowHash: rootNode.ShadowHash, + db: db, + sndb: sndb, + currentEpoch: curEpoch, + useShadowTree: false, + shadowTreeRoot: rootNode.ShadowTreeRoot, + rootEpoch: rootNode.Epoch, } - if rootNode.TrieHash != (common.Hash{}) && rootNode.TrieHash != emptyRoot { - root, err := trie.resolveHash(rootNode.TrieHash[:], nil) + if rootNode.TrieRoot != (common.Hash{}) && rootNode.TrieRoot != emptyRoot { + root, err := trie.resolveHash(rootNode.TrieRoot[:], nil) if err != nil { return nil, err } @@ -121,6 +129,10 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa } trie.root = root } + // only enable after first state expiry's hard fork + if curEpoch > types.StateEpoch0 { + trie.useShadowTree = true + } return trie, nil } @@ -146,7 +158,7 @@ func (t *Trie) Get(key []byte) []byte { func (t *Trie) TryGet(key []byte) (value []byte, err error) { var newroot node var didResolve bool - if t.isStorageTrie { + if t.useShadowTree { var nextEpoch types.StateEpoch if t.root != nil { nextEpoch = t.root.getEpoch() @@ -413,7 +425,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } switch n := n.(type) { case *shortNode: - if t.isStorageTrie && t.currentEpoch >= 2 { + if t.useShadowTree && t.currentEpoch >= 2 { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -436,7 +448,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - if t.isStorageTrie { + if t.useShadowTree { branch.setEpoch(t.currentEpoch) branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) } 
@@ -444,7 +456,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - if t.isStorageTrie { + if t.useShadowTree { branch.setEpoch(t.currentEpoch) branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) } @@ -456,7 +468,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil case *fullNode: - if t.isStorageTrie { + if t.useShadowTree { if t.currentEpoch >= 2 { // this full node is expired, return err if expired, err := t.nodeExpired(n, prefix); expired { @@ -546,7 +558,7 @@ func (t *Trie) TryDelete(key []byte) error { func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { switch n := n.(type) { case *shortNode: - if t.isStorageTrie && t.currentEpoch >= 2 { + if t.useShadowTree && t.currentEpoch >= 2 { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -581,7 +593,7 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, } case *fullNode: - if t.isStorageTrie { + if t.useShadowTree { if t.currentEpoch >= 2 { // this full node is expired, return err if expired, err := t.nodeExpired(n, prefix); expired { @@ -813,6 +825,19 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { // Derive the hash for all dirty nodes first. We hold the assumption // in the following procedure that all nodes are hashed. 
rootHash := t.Hash() + if t.useShadowTree { + shadowRoot, err := t.ShadowHash() + if err != nil { + return common.Hash{}, 0, err + } + // replace shadowTreeRoot for rootNode + if shadowRoot != nil { + t.shadowTreeRoot = *shadowRoot + } else { + t.shadowTreeRoot = emptyRoot + } + } + h := newCommitter() defer returnCommitterToPool(h) @@ -993,7 +1018,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE } func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte) error { - if t.currentEpoch < 1 { + if !t.useShadowTree { return nil } @@ -1003,8 +1028,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte switch n := cur.(type) { case *shortNode: - n.shadowNode.Epoch = epoch - n.shadowNode.ShadowHash = common.Hash{} + n.shadowNode.ShadowHash = nil return t.resolveShadowNode(epoch, n.Val, append(prefix, n.Key...)) case *fullNode: val, err := t.sndb.Get(string(hexToSuffixCompact(prefix))) @@ -1014,7 +1038,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte if len(val) == 0 { // set default epoch map n.shadowNode.EpochMap = [16]types.StateEpoch{} - n.shadowNode.ShadowHash = common.Hash{} + n.shadowNode.ShadowHash = nil } else { if err = rlp.DecodeBytes(val, &n.shadowNode); err != nil { return err @@ -1026,7 +1050,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte } } return nil - case valueNode, hashNode: + case valueNode, hashNode, nil: // just skip return nil default: @@ -1034,6 +1058,65 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte } } +func (t *Trie) ShadowHash() (*common.Hash, error) { + if t.root == nil { + return nil, nil + } + h := newHasher(true) + defer returnHasherToPool(h) + return t.shadowHash(t.root, h, nil, t.root.getEpoch()) +} + +// shadowHash calculate node's shadow node hash, recalculate needn't a copy +func (t *Trie) shadowHash(origin node, h *hasher, prefix 
[]byte, epoch types.StateEpoch) (*common.Hash, error) { + switch n := origin.(type) { + case *shortNode: + var err error + if n.shadowNode.ShadowHash, err = t.shadowHash(n.Val, h, append(prefix, n.Key...), epoch); err != nil { + return nil, err + } + return h.shadowExtendNodeToHash(&n.shadowNode), nil + case *fullNode: + epochSelf := n.epoch + epochMap := n.shadowNode.EpochMap + hashList := make([]*common.Hash, 0, BranchNodeLength-1) + for i := byte(0); i < BranchNodeLength-1; i++ { + child := n.Children[i] + if child == nil { + continue + } + // skip expired node. + if epochSelf >= epochMap[i]+2 { + continue + } + + subHash, err := t.shadowHash(child, h, append(prefix, i), epochMap[i]) + if err != nil { + return nil, err + } + if subHash != nil { + hashList = append(hashList, subHash) + } + } + n.shadowNode.ShadowHash = h.shadowNodeHashListToHash(hashList) + return h.shadowBranchNodeToHash(&n.shadowNode), nil + case valueNode: + return nil, nil + case hashNode: + // resolve temporary, not add to trie + rn, err := t.resolveHash(n, prefix) + if err != nil { + return nil, err + } + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return nil, err + } + return t.shadowHash(rn, h, prefix, epoch) + default: + return nil, errors.New("cannot get shortNode's child shadow node") + } +} + // SUFFIX-COMPACT encoding is used for encoding trie node path in the trie node // storage key. The main difference with COMPACT encoding is that the key flag // is put at the end of the key. 
diff --git a/trie/trie_test.go b/trie/trie_test.go index 10dc28e4a0..f14cf607fb 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -851,12 +851,7 @@ func TestReviveValueAtFullNode(t *testing.T) { } func TestTrie_ShadowNodeRW(t *testing.T) { - diskdb := memorydb.New() - database := NewDatabase(diskdb) - tree, err := NewShadowNodeSnapTree(diskdb) - assert.NoError(t, err) - storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) - assert.NoError(t, err) + database, storageDB := makeStorageTrieDatabase(t) tr, err := NewStorageSecure(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) assert.NoError(t, err) @@ -887,6 +882,52 @@ func TestTrie_ShadowNodeRW(t *testing.T) { //assert.Error(t, err) } +func TestTrie_ShadowHash(t *testing.T) { + database, storageDB := makeStorageTrieDatabase(t) + tr, err := NewWithShadowNode(types.StateEpoch0, newEpoch0RootNode(emptyRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + + batchUpdateTrie(t, tr, []string{"a711355", "450", "a77d337", "100", "a7f9365", "110", "a77d397", "012"}) + sh1, err := tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0xdb27cdd3bb63fff59572d38425e82df16706590fe1daa704d7b66a6171b38216"), *sh1) + + // commit and shadow hash again + newRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + tr, err = NewWithShadowNode(types.StateEpoch(1), newEpoch0RootNode(newRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + sh1, err = tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0xdb27cdd3bb63fff59572d38425e82df16706590fe1daa704d7b66a6171b38216"), *sh1) + + err = tr.TryUpdate(common.Hex2Bytes("a711355"), common.Hex2Bytes("800")) + assert.NoError(t, err) + sh1, err = tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0x5afa84184d90b266a63ac857641f4679a10817bdc48f66c91150c30d92ff9e1b"), *sh1) +} + +func batchUpdateTrie(t *testing.T, tr *Trie, kvs 
[]string) { + if len(kvs)%2 != 0 { + panic("wrong kvs") + } + for i := 0; i < len(kvs); i += 2 { + err := tr.TryUpdate(common.Hex2Bytes(kvs[i]), common.Hex2Bytes(kvs[i+1])) + assert.NoError(t, err) + } +} + +func makeStorageTrieDatabase(t *testing.T) (*Database, ShadowNodeDatabase) { + diskdb := memorydb.New() + database := NewDatabase(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb) + assert.NoError(t, err) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) + assert.NoError(t, err) + return database, storageDB +} + func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } From ed0011668f673ccb1bcddc735ebcd10f32fb53b4 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Wed, 10 May 2023 17:26:10 +0800 Subject: [PATCH 38/51] trie: fix some epoch update bugs; trie/shadownode: support commit root node & shadow nodes; --- core/state/database.go | 2 +- trie/hasher.go | 20 +----- trie/node.go | 54 +-------------- trie/node_test.go | 34 ---------- trie/secure_trie.go | 11 ++-- trie/shadow_node.go | 123 ++++++++++++++++++++++++++++++++++ trie/shadow_node_test.go | 120 +++++++++++++++++++++++++++++++++ trie/trie.go | 139 +++++++++++++++++++++++++-------------- trie/trie_test.go | 113 +++++++++++++++++++++++-------- 9 files changed, 426 insertions(+), 190 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index a90d03ff8f..137e694350 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -270,7 +270,7 @@ func (db *cachingDB) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, c } } - tr, err := trie.NewStorageSecure(curEpoch, root, db.db, sndb) + tr, err := trie.NewSecureWithShadowNodes(curEpoch, root, db.db, sndb) if err != nil { return nil, err } diff --git a/trie/hasher.go b/trie/hasher.go index 3b0761caf1..589f4417a5 100644 --- 
a/trie/hasher.go +++ b/trie/hasher.go @@ -146,31 +146,15 @@ func (h *hasher) shadowExtendNodeToHash(n *shadowExtensionNode) *common.Hash { if n.ShadowHash == nil { return nil } - w := h.encbuf - offset := w.List() - w.WriteBytes(n.ShadowHash[:]) - w.ListEnd(offset) + n.encode(h.encbuf) enc := h.encodedBytes() return h.hashCommon(enc) } // shadowFullNodeToHash hash shadowBranchNode func (h *hasher) shadowBranchNodeToHash(n *shadowBranchNode) *common.Hash { - w := h.encbuf - outerList := w.List() - if n.ShadowHash == nil { - w.WriteBytes(rlp.EmptyString) - } else { - w.WriteBytes(n.ShadowHash[:]) - } - epochList := w.List() - for _, epoch := range n.EpochMap { - w.WriteUint64(uint64(epoch)) - } - w.ListEnd(epochList) - w.ListEnd(outerList) - + n.encode(h.encbuf) enc := h.encodedBytes() return h.hashCommon(enc) } diff --git a/trie/node.go b/trie/node.go index 8abe7ef860..4ead424510 100644 --- a/trie/node.go +++ b/trie/node.go @@ -38,7 +38,6 @@ const ( rawNodeType rawShortNodeType rawFullNodeType - rootNodeType ) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} @@ -70,46 +69,6 @@ type ( valueNode []byte ) -type rootNode struct { - Epoch types.StateEpoch - TrieRoot common.Hash - ShadowTreeRoot common.Hash - flags nodeFlag `rlp:"-" json:"-"` -} - -func newEpoch0RootNode(trieRoot common.Hash) *rootNode { - return &rootNode{ - Epoch: types.StateEpoch0, - TrieRoot: trieRoot, - ShadowTreeRoot: emptyRoot, - flags: nodeFlag{dirty: true}, - } -} - -func (n *rootNode) cache() (hashNode, bool) { - return n.flags.hash, n.flags.dirty -} - -func (n *rootNode) encode(w rlp.EncoderBuffer) { - rlp.Encode(w, n) -} - -func (n *rootNode) fstring(s string) string { - return fmt.Sprintf("{%v: %x: %x} ", n.Epoch, n.TrieRoot, n.ShadowTreeRoot) -} - -func (n *rootNode) nodeType() int { - return rootNodeType -} - -func (n *rootNode) setEpoch(epoch types.StateEpoch) { - n.Epoch = epoch -} - -func (n *rootNode) getEpoch() 
types.StateEpoch { - return n.Epoch -} - // nilValueNode is used when collapsing internal trie nodes for hashing, since // unset children need to serialize correctly. var nilValueNode = valueNode(nil) @@ -140,7 +99,7 @@ func (n *fullNode) UpdateChildEpoch(index int, epoch types.StateEpoch) { func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.StateEpoch) (bool, error) { childEpoch := n.GetChildEpoch(index) - if currentEpoch-childEpoch >= 2 { + if types.EpochExpired(childEpoch, currentEpoch) { return true, &ExpiredNodeError{ ExpiredNode: n.Children[index], Path: prefix, @@ -264,9 +223,6 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { case 2: n, err := decodeShort(hash, elems) return n, wrapError(err, "short") - case 3: - n, err := decodeRootNode(hash, buf) - return n, wrapError(err, "root node") case 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") @@ -300,14 +256,6 @@ func decodeShort(hash, elems []byte) (node, error) { return n, nil } -func decodeRootNode(hash, elems []byte) (node, error) { - n := &rootNode{flags: nodeFlag{hash: hash}} - if err := rlp.DecodeBytes(elems, n); err != nil { - return nil, err - } - return n, nil -} - func decodeFull(hash, elems []byte) (*fullNode, error) { n := &fullNode{flags: nodeFlag{hash: hash}} for i := 0; i < 16; i++ { diff --git a/trie/node_test.go b/trie/node_test.go index 62cc2e2874..d52b0cee24 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -20,9 +20,6 @@ import ( "bytes" "testing" - "github.com/ethereum/go-ethereum/common" - "github.com/stretchr/testify/assert" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) @@ -97,37 +94,6 @@ func TestDecodeFullNode(t *testing.T) { } } -func TestRootNodeEncodeDecode(t *testing.T) { - datas := []struct { - r *rootNode - }{ - { - r: &rootNode{ - Epoch: 100, - TrieRoot: makeHash("t1"), - ShadowTreeRoot: makeHash("s1"), - flags: nodeFlag{hash: makeHash("h1").Bytes()}, - }, - }, - { - r: 
&rootNode{ - Epoch: 0, - TrieRoot: common.Hash{}, - ShadowTreeRoot: common.Hash{}, - flags: nodeFlag{hash: makeHash("h1").Bytes()}, - }, - }, - } - - for _, item := range datas { - buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) - item.r.encode(buf) - p, err := decodeNode(makeHash("h1").Bytes(), buf.ToBytes()) - assert.NoError(t, err) - assert.Equal(t, item.r, p) - } -} - // goos: darwin // goarch: arm64 // pkg: github.com/ethereum/go-ethereum/trie diff --git a/trie/secure_trie.go b/trie/secure_trie.go index d176845c20..4f52ad648e 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -79,17 +79,14 @@ func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, return &SecureTrie{trie: *trie}, nil } -func NewStorageSecure(curEpoch types.StateEpoch, root common.Hash, db *Database, sndb ShadowNodeStorage) (*SecureTrie, error) { +func NewSecureWithShadowNodes(curEpoch types.StateEpoch, root common.Hash, db *Database, sndb ShadowNodeStorage) (*SecureTrie, error) { if db == nil || sndb == nil { panic("trie.NewSecure called without a database") } - rn := newEpoch0RootNode(root) - hash := common.BytesToHash(root[:]) - if n := db.node(hash); n != nil { - if tmp, ok := n.(*rootNode); ok { - *rn = *tmp - } + rn, err := resolveRootNode(sndb, root) + if err != nil { + return nil, err } trie, err := NewWithShadowNode(curEpoch, rn, db, sndb) diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 55ea9d19c2..5cde68966c 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -6,12 +6,64 @@ import ( "math/big" "sync" + "github.com/ethereum/go-ethereum/rlp" + "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" ) +const ( + ShadowTreeRootNodePath = "s" +) + +type rootNode struct { + Epoch types.StateEpoch + TrieRoot common.Hash + ShadowTreeRoot common.Hash + cachedHash common.Hash `rlp:"-" json:"-"` + cachedEnc []byte `rlp:"-" json:"-"` +} + +func 
newEpoch0RootNode(trieRoot common.Hash) *rootNode { + return newRootNode(types.StateEpoch0, trieRoot, emptyRoot) +} + +func newRootNode(epoch types.StateEpoch, trieRoot, shadowTreeRoot common.Hash) *rootNode { + n := &rootNode{ + Epoch: epoch, + TrieRoot: trieRoot, + ShadowTreeRoot: shadowTreeRoot, + } + n.resolveCache() + return n +} +func (n *rootNode) encode(w rlp.EncoderBuffer) { + rlp.Encode(w, n) +} +func (n *rootNode) resolveCache() { + buf := rlp.NewEncoderBuffer(nil) + n.encode(buf) + n.cachedEnc = buf.ToBytes() + + // cache hash + h := newHasher(false) + h.sha.Reset() + h.sha.Write(n.cachedEnc) + h.sha.Read(n.cachedHash[:]) + returnHasherToPool(h) +} + +func decodeRootNode(enc []byte) (*rootNode, error) { + n := &rootNode{} + if err := rlp.DecodeBytes(enc, n); err != nil { + return nil, err + } + n.resolveCache() + return n, nil +} + type shadowExtensionNode struct { ShadowHash *common.Hash } @@ -22,6 +74,37 @@ func NewShadowExtensionNode(hash *common.Hash) shadowExtensionNode { } } +func (n *shadowExtensionNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + if n.ShadowHash == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteBytes(n.ShadowHash[:]) + } + w.ListEnd(offset) +} + +func decodeShadowExtensionNode(enc []byte) (*shadowExtensionNode, error) { + var n shadowExtensionNode + elems, _, err := rlp.SplitList(enc) + if err != nil { + return nil, err + } + + sh, _, err := rlp.SplitString(elems) + if err != nil { + return nil, err + } + + if len(sh) == 0 { + n.ShadowHash = nil + } else { + hash := common.BytesToHash(sh) + n.ShadowHash = &hash + } + return &n, nil +} + type shadowBranchNode struct { ShadowHash *common.Hash EpochMap [16]types.StateEpoch @@ -30,6 +113,46 @@ type shadowBranchNode struct { func NewShadowBranchNode(hash *common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { return shadowBranchNode{hash, epochMap} } +func (n *shadowBranchNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + if n.ShadowHash == nil { + 
w.Write(rlp.EmptyString) + } else { + w.WriteBytes(n.ShadowHash[:]) + } + epochsList := w.List() + for _, e := range n.EpochMap { + w.WriteUint64(uint64(e)) + } + w.ListEnd(epochsList) + w.ListEnd(offset) +} + +func decodeShadowBranchNode(enc []byte) (*shadowBranchNode, error) { + var n shadowBranchNode + elems, _, err := rlp.SplitList(enc) + if err != nil { + return nil, err + } + + sh, rest, err := rlp.SplitString(elems) + if err != nil { + return nil, err + } + + if len(sh) == 0 { + n.ShadowHash = nil + } else { + hash := common.BytesToHash(sh) + n.ShadowHash = &hash + } + + if err = rlp.DecodeBytes(rest, &n.EpochMap); err != nil { + return nil, err + } + + return &n, nil +} type ShadowNodeStorage interface { // Get key is the shadow node prefix path diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go index 2d8853d941..1626635f94 100644 --- a/trie/shadow_node_test.go +++ b/trie/shadow_node_test.go @@ -1,9 +1,12 @@ package trie import ( + "bytes" "math/big" "testing" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/rlp" @@ -123,3 +126,120 @@ func TestNewShadowNodeStorage4Trie(t *testing.T) { err = storageDB.Commit(common.Big1, blockRoot2) assert.NoError(t, err) } + +func TestShadowExtendNode_encodeDecode(t *testing.T) { + dt := []struct { + n shadowExtensionNode + }{ + { + n: shadowExtensionNode{ + ShadowHash: nil, + }, + }, + { + n: shadowExtensionNode{ + ShadowHash: &blockRoot0, + }, + }, + { + n: shadowExtensionNode{ + ShadowHash: &blockRoot1, + }, + }, + } + for _, item := range dt { + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := decodeShadowExtensionNode(enc) + assert.NoError(t, err) + assert.Equal(t, &item.n, rn) + } +} + +func TestShadowBranchNode_encodeDecode(t *testing.T) { + dt := []struct { + n shadowBranchNode + }{ + { + n: shadowBranchNode{ + ShadowHash: nil, + EpochMap: 
[16]types.StateEpoch{}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: nil, + EpochMap: [16]types.StateEpoch{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: &blockRoot0, + EpochMap: [16]types.StateEpoch{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: &blockRoot1, + EpochMap: [16]types.StateEpoch{}, + }, + }, + } + for _, item := range dt { + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := decodeShadowBranchNode(enc) + assert.NoError(t, err) + assert.Equal(t, &item.n, rn) + } +} + +func TestRootNode_encodeDecode(t *testing.T) { + dt := []struct { + n rootNode + isEqual bool + }{ + { + n: rootNode{ + Epoch: 10, + TrieRoot: blockRoot0, + ShadowTreeRoot: blockRoot1, + }, + isEqual: true, + }, + { + n: rootNode{}, + isEqual: true, + }, + { + n: rootNode{ + Epoch: 100, + TrieRoot: blockRoot2, + ShadowTreeRoot: common.Hash{}, + }, + isEqual: true, + }, + { + n: rootNode{}, + }, + } + + for _, item := range dt { + item.n.resolveCache() + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := decodeRootNode(enc) + assert.NoError(t, err) + if !item.isEqual { + assert.NotEqual(t, item.n, rn) + continue + } + assert.Equal(t, &item.n, rn) + } +} diff --git a/trie/trie.go b/trie/trie.go index 511a20891f..65d80ba525 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -111,11 +111,17 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa return nil, errors.New("open trie at a wrong epoch") } + useShadowTree := false + // only enable after first state expiry's hard fork + if curEpoch > types.StateEpoch0 { + useShadowTree = true + } + trie := &Trie{ db: db, sndb: sndb, currentEpoch: curEpoch, - useShadowTree: false, + useShadowTree: useShadowTree, shadowTreeRoot: rootNode.ShadowTreeRoot, rootEpoch: rootNode.Epoch, } 
@@ -129,10 +135,6 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa } trie.root = root } - // only enable after first state expiry's hard fork - if curEpoch > types.StateEpoch0 { - trie.useShadowTree = true - } return trie, nil } @@ -425,7 +427,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } switch n := n.(type) { case *shortNode: - if t.useShadowTree && t.currentEpoch >= 2 { + if t.useShadowTree { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -469,25 +471,17 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE case *fullNode: if t.useShadowTree { - if t.currentEpoch >= 2 { - // this full node is expired, return err - if expired, err := t.nodeExpired(n, prefix); expired { - return false, n, err - } + // this full node is expired, return err + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err } // else, set its epoch to current epoch. n.setEpoch(t.currentEpoch) - if t.currentEpoch >= 2 { - child := n.Children[key[0]] - childKey := key[1:] - // if inserting a new node to this full node, there is no need to check whether this child is expired. - if len(childKey) != 0 { - if child != nil { - // if child is expired, return err - if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { - return false, n.Children[key[0]], err - } - } + // if inserting a new node to this full node, there is no need to check whether this child is expired. 
+ if len(key) > 0 && n.Children[key[0]] != nil { + // if child is expired, return err + if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { + return false, n.Children[key[0]], err } } // else, set child node's epoch to current epoch @@ -558,7 +552,7 @@ func (t *Trie) TryDelete(key []byte) error { func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { switch n := n.(type) { case *shortNode: - if t.useShadowTree && t.currentEpoch >= 2 { + if t.useShadowTree { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -594,19 +588,15 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, case *fullNode: if t.useShadowTree { - if t.currentEpoch >= 2 { - // this full node is expired, return err - if expired, err := t.nodeExpired(n, prefix); expired { - return false, n, err - } + // this full node is expired, return err + if expired, err := t.nodeExpired(n, prefix); expired { + return false, n, err } // else, set its epoch to current epoch. n.setEpoch(t.currentEpoch) - if t.currentEpoch >= 2 { - // if child is expired, return err - if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { - return false, n.Children[key[0]], err - } + // if child is expired, return err + if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { + return false, n.Children[key[0]], err } // else, set child node's epoch to current epoch n.UpdateChildEpoch(int(key[0]), t.currentEpoch) @@ -661,12 +651,12 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, } if cnode, ok := cnode.(*shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) 
- return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag()}, nil + return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag()}, nil + return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag(), epoch: t.currentEpoch}, nil } // n still contains at least two values and cannot be reduced. return true, n, nil @@ -795,7 +785,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { } func (t *Trie) nodeExpired(n node, prefix []byte) (bool, error) { - if t.currentEpoch-n.getEpoch() >= 2 { + if types.EpochExpired(n.getEpoch(), t.currentEpoch) { return true, &ExpiredNodeError{ ExpiredNode: n, Path: prefix, @@ -824,17 +814,16 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { } // Derive the hash for all dirty nodes first. We hold the assumption // in the following procedure that all nodes are hashed. - rootHash := t.Hash() + newRootHash := t.Hash() + newShadowTreeRoot := emptyRoot if t.useShadowTree { - shadowRoot, err := t.ShadowHash() + shadowTreeRoot, err := t.ShadowHash() if err != nil { return common.Hash{}, 0, err } // replace shadowTreeRoot for rootNode - if shadowRoot != nil { - t.shadowTreeRoot = *shadowRoot - } else { - t.shadowTreeRoot = emptyRoot + if shadowTreeRoot != nil { + newShadowTreeRoot = *shadowTreeRoot } } @@ -845,7 +834,14 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { // up goroutines. This can happen e.g. if we load a trie for reading storage // values, but don't write to it. 
if _, dirty := t.root.cache(); !dirty { - return rootHash, 0, nil + if t.useShadowTree { + rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) + if err != nil { + return common.Hash{}, 0, err + } + return rootNodeHash, 0, nil + } + return newRootHash, 0, nil } var wg sync.WaitGroup if onleaf != nil { @@ -869,8 +865,16 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if err != nil { return common.Hash{}, 0, err } + if t.useShadowTree { + rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) + if err != nil { + return common.Hash{}, 0, err + } + t.root = newRoot + return rootNodeHash, committed, nil + } t.root = newRoot - return rootHash, committed, nil + return newRootHash, committed, nil } // hashRoot calculates the root hash of the given trie @@ -1017,7 +1021,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE } } -func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte) error { +func (t *Trie) resolveShadowNode(epoch types.StateEpoch, origin node, prefix []byte) error { if !t.useShadowTree { return nil } @@ -1026,11 +1030,13 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte return errors.New("cannot resolve shadow node") } - switch n := cur.(type) { + switch n := origin.(type) { case *shortNode: + n.setEpoch(epoch) n.shadowNode.ShadowHash = nil return t.resolveShadowNode(epoch, n.Val, append(prefix, n.Key...)) case *fullNode: + n.setEpoch(epoch) val, err := t.sndb.Get(string(hexToSuffixCompact(prefix))) if err != nil { return err @@ -1040,9 +1046,11 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, cur node, prefix []byte n.shadowNode.EpochMap = [16]types.StateEpoch{} n.shadowNode.ShadowHash = nil } else { - if err = rlp.DecodeBytes(val, &n.shadowNode); err != nil { - return err + tmp, decErr := decodeShadowBranchNode(val) + if decErr != nil { + return decErr } + n.shadowNode = *tmp } for i := byte(0); i < 
BranchNodeLength-1; i++ { if err := t.resolveShadowNode(n.shadowNode.EpochMap[i], n.Children[i], append(prefix, i)); err != nil { @@ -1099,6 +1107,12 @@ func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.Sta } } n.shadowNode.ShadowHash = h.shadowNodeHashListToHash(hashList) + // TODO(0xbundler): just save shadowNode, will revert in later cryyl version + encBuf := rlp.NewEncoderBuffer(nil) + n.shadowNode.encode(encBuf) + if err := t.sndb.Put(string(hexToSuffixCompact(prefix)), encBuf.ToBytes()); err != nil { + return nil, err + } return h.shadowBranchNodeToHash(&n.shadowNode), nil case valueNode: return nil, nil @@ -1117,6 +1131,35 @@ func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.Sta } } +func (t *Trie) storeRootNode(newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { + rn := newRootNode(t.root.getEpoch(), newRootHash, newShadowTreeRoot) + if err := t.sndb.Put(ShadowTreeRootNodePath, rn.cachedEnc); err != nil { + return common.Hash{}, err + } + return rn.cachedHash, nil +} + +func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { + + expectHash := common.BytesToHash(root[:]) + val, err := sndb.Get(ShadowTreeRootNodePath) + if err != nil { + return nil, err + } + if len(val) == 0 { + return newEpoch0RootNode(root), nil + } + n, err := decodeRootNode(val) + if err != nil { + return nil, err + } + + if n.cachedHash != expectHash { + return nil, errors.New("found the wrong rootNode") + } + return n, nil +} + // SUFFIX-COMPACT encoding is used for encoding trie node path in the trie node // storage key. The main difference with COMPACT encoding is that the key flag // is put at the end of the key. 
diff --git a/trie/trie_test.go b/trie/trie_test.go index f14cf607fb..f66f9213be 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -850,10 +850,12 @@ func TestReviveValueAtFullNode(t *testing.T) { } } -func TestTrie_ShadowNodeRW(t *testing.T) { - database, storageDB := makeStorageTrieDatabase(t) +func TestTrie_ShadowNodeRW_expired(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) + assert.NoError(t, err) - tr, err := NewStorageSecure(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) + tr, err := NewSecureWithShadowNodes(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) assert.NoError(t, err) tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) @@ -862,50 +864,105 @@ func TestTrie_ShadowNodeRW(t *testing.T) { assert.Equal(t, makeHash("v1").Bytes(), val) assert.NoError(t, tr.TryDelete(makeHash("k1").Bytes())) - // TODO(0xbundler): need MPT support commit with shadow node - //nextRoot, _, err := tr.Commit(nil) - //assert.NoError(t, err) - //assert.NoError(t, storageDB.Commit(common.Big1, blockRoot1)) - - // reload - //storageDB, err = NewShadowNodestorageDB(tree, blockRoot1) - //assert.NoError(t, err) - //tr, err = NewStorageSecure(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) - //assert.NoError(t, err) - //val, err = tr.TryGet(makeHash("k2").Bytes()) - //assert.NoError(t, err) - //assert.Equal(t, makeHash("v2").Bytes(), val) - //// check expired - //tr, err = NewStorageSecure(types.StateEpoch(3), nextRoot, database, storageDB.OpenStorage(contract1)) - //assert.NoError(t, err) - //val, err = tr.TryGet(makeHash("k2").Bytes()) - //assert.Error(t, err) + // commit + nextRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big1, nextRoot)) + + // reload in epoch2 + storageDB, err = 
NewShadowNodeDatabase(tree, common.Big1, nextRoot) + assert.NoError(t, err) + tr, err = NewSecureWithShadowNodes(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + val, err = tr.TryGet(makeHash("k2").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v2").Bytes(), val) + + // reload in epoch3, check expired + tr, err = NewSecureWithShadowNodes(types.StateEpoch(3), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + _, err = tr.TryGet(makeHash("k2").Bytes()) + assert.Error(t, err) +} + +func TestTrie_ShadowNodeRW_accessStates(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) + assert.NoError(t, err) + + tr, err := NewSecureWithShadowNodes(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) + tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) + val, err := tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + + // commit + nextRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big1, nextRoot)) + + // reload in epoch2, access K1, k2 + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, nextRoot) + assert.NoError(t, err) + tr, err = NewSecureWithShadowNodes(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + err = tr.TryUpdateEpoch(makeHash("k1").Bytes()) + assert.NoError(t, err) + err = tr.TryUpdateEpoch(makeHash("k2").Bytes()) + assert.NoError(t, err) + + // commit + nextRoot, _, err = tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big2, nextRoot)) + + // reload in epoch3 + storageDB, err = NewShadowNodeDatabase(tree, common.Big2, nextRoot) + assert.NoError(t, err) + tr, err = 
NewSecureWithShadowNodes(types.StateEpoch(3), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + val, err = tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + val, err = tr.TryGet(makeHash("k2").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v2").Bytes(), val) } func TestTrie_ShadowHash(t *testing.T) { - database, storageDB := makeStorageTrieDatabase(t) + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, emptyRoot) + assert.NoError(t, err) + tr, err := NewWithShadowNode(types.StateEpoch0, newEpoch0RootNode(emptyRoot), database, storageDB.OpenStorage(contract1)) assert.NoError(t, err) batchUpdateTrie(t, tr, []string{"a711355", "450", "a77d337", "100", "a7f9365", "110", "a77d397", "012"}) sh1, err := tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, common.HexToHash("0xdb27cdd3bb63fff59572d38425e82df16706590fe1daa704d7b66a6171b38216"), *sh1) + assert.Equal(t, common.HexToHash("0x73325476298d27129c8b8d64e8d0abd66d6cc26601c9a012304170432ad3a00d"), *sh1) // commit and shadow hash again newRoot, _, err := tr.Commit(nil) assert.NoError(t, err) + err = storageDB.Commit(common.Big1, newRoot) + assert.NoError(t, err) + + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, newRoot) + assert.NoError(t, err) tr, err = NewWithShadowNode(types.StateEpoch(1), newEpoch0RootNode(newRoot), database, storageDB.OpenStorage(contract1)) assert.NoError(t, err) sh1, err = tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, common.HexToHash("0xdb27cdd3bb63fff59572d38425e82df16706590fe1daa704d7b66a6171b38216"), *sh1) + assert.Equal(t, common.HexToHash("0x73325476298d27129c8b8d64e8d0abd66d6cc26601c9a012304170432ad3a00d"), *sh1) err = tr.TryUpdate(common.Hex2Bytes("a711355"), common.Hex2Bytes("800")) assert.NoError(t, err) sh1, err = tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, 
common.HexToHash("0x5afa84184d90b266a63ac857641f4679a10817bdc48f66c91150c30d92ff9e1b"), *sh1) + assert.Equal(t, common.HexToHash("0x3e17653305330721f2a9792b510247a411efedf8d687fa3fb1ae32d2c4325511"), *sh1) } func batchUpdateTrie(t *testing.T, tr *Trie, kvs []string) { @@ -918,14 +975,12 @@ func batchUpdateTrie(t *testing.T, tr *Trie, kvs []string) { } } -func makeStorageTrieDatabase(t *testing.T) (*Database, ShadowNodeDatabase) { +func makeStorageTrieDatabase(t *testing.T) (*Database, *ShadowNodeSnapTree) { diskdb := memorydb.New() database := NewDatabase(diskdb) tree, err := NewShadowNodeSnapTree(diskdb) assert.NoError(t, err) - storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) - assert.NoError(t, err) - return database, storageDB + return database, tree } func BenchmarkGet(b *testing.B) { benchGet(b, false) } From c8145b39f45b1c4dab54131993c015f92e92ca29 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Tue, 9 May 2023 23:00:26 +0800 Subject: [PATCH 39/51] state/statedb: opt state epoch check and logic; fix: fix some broken UTs; parlia: add state epoch config; --- consensus/parlia/parlia.go | 9 +++++++-- core/chain_makers.go | 11 +++++++---- core/state/database.go | 2 +- core/state/state_object.go | 36 +++++++++++++++++++++++++++--------- core/state/statedb.go | 12 ++++++++++++ core/types/state_epoch.go | 8 ++++---- params/config.go | 20 ++++++++++++-------- trie/secure_trie.go | 2 ++ trie/trie.go | 3 +++ 9 files changed, 75 insertions(+), 28 deletions(-) diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index b1abc1221f..fb338d4ebc 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -46,8 +46,9 @@ const ( inMemorySnapshots = 128 // Number of recent snapshots to keep in memory inMemorySignatures = 4096 // Number of recent block signatures to keep in memory - checkpointInterval = 1024 // Number of blocks after which to save the snapshot to the database - 
defaultEpochLength = uint64(100) // Default number of blocks of checkpoint to update validatorSet from contract + checkpointInterval = 1024 // Number of blocks after which to save the snapshot to the database + defaultEpochLength = uint64(100) // Default number of blocks of checkpoint to update validatorSet from contract + defaultStateEpochPeriod = uint64(7_008_000) extraVanity = 32 // Fixed number of extra-data prefix bytes reserved for signer vanity extraSeal = 65 // Fixed number of extra-data suffix bytes reserved for signer seal @@ -230,6 +231,10 @@ func New( if parliaConfig != nil && parliaConfig.Epoch == 0 { parliaConfig.Epoch = defaultEpochLength } + if parliaConfig != nil && parliaConfig.StateEpochPeriod == 0 { + parliaConfig.StateEpochPeriod = defaultStateEpochPeriod + } + log.Info("instance parlia with config", "period", parliaConfig.Period, "epoch", parliaConfig.Epoch, "stateEpochPeriod", parliaConfig.StateEpochPeriod) // Allocate the snapshot caches and create the engine recentSnaps, err := lru.NewARC(inMemorySnapshots) diff --git a/core/chain_makers.go b/core/chain_makers.go index 24b103eb6e..4272bd6b38 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -276,12 +276,12 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse } return nil, nil } + tree, err := trie.NewShadowNodeSnapTree(db) + if err != nil { + panic(err) + } for i := 0; i < n; i++ { number := new(big.Int).Add(parent.Number(), common.Big1) - tree, err := trie.NewShadowNodeSnapTree(db) - if err != nil { - panic(err) - } statedb, err := state.NewWithEpoch(config, number, parent.Root(), state.NewDatabase(db), nil, tree) if err != nil { panic(err) @@ -291,6 +291,9 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse receipts[i] = receipt parent = block } + if err = tree.Journal(); err != nil { + panic(err) + } return blocks, receipts } diff --git a/core/state/database.go b/core/state/database.go index 
137e694350..00b407c799 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -247,7 +247,7 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { } } - tr, err := trie.NewSecure(root, db.db, true) + tr, err := trie.NewSecure(root, db.db, false) if err != nil { return nil, err } diff --git a/core/state/state_object.go b/core/state/state_object.go index 7cda8c0d3e..8916a261f2 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -194,13 +194,24 @@ func (s *StateObject) getTrie(db Database) Trie { // prefetcher s.trie = prefetcher.trie(s.data.Root) } - if s.trie == nil { - var err error + if s.trie != nil { + return s.trie + } + var err error + // check if enable state epoch + if s.db.enableStateEpoch(false) { s.trie, err = db.OpenStorageTrieWithShadowNode(s.addrHash, s.data.Root, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) if err != nil { s.trie, _ = db.OpenStorageTrieWithShadowNode(s.addrHash, common.Hash{}, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) - s.setError(fmt.Errorf("can't create storage trie: %v", err)) + s.setError(fmt.Errorf("can't create storage trie with shadowNode: %v", err)) } + return s.trie + } + + s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + if err != nil { + s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } return s.trie @@ -231,9 +242,11 @@ func (s *StateObject) GetState(db Database, key common.Hash) (common.Hash, error s.accessState(key) return value, nil } - if revived, revive := s.queryFromReviveState(db, s.dirtyReviveState, key); revive { - s.accessState(key) - return revived, nil + if s.db.enableStateEpoch(true) { + if revived, revive := s.queryFromReviveState(db, s.dirtyReviveState, key); revive { + s.accessState(key) + return revived, nil + } } // Otherwise return the entry's original value @@ -278,8 +291,10 @@ func (s *StateObject) GetCommittedState(db Database, key 
common.Hash) (common.Ha if value, pending := s.pendingStorage[key]; pending { return value, nil } - if revived, revive := s.queryFromReviveState(db, s.pendingReviveState, key); revive { - return revived, nil + if s.db.enableStateEpoch(true) { + if revived, revive := s.queryFromReviveState(db, s.pendingReviveState, key); revive { + return revived, nil + } } if value, cached := s.getOriginStorage(key); cached { @@ -360,7 +375,7 @@ func (s *StateObject) SetState(db Database, key, value common.Hash) error { return nil } // when state insert, check if valid to insert new state - if prev != (common.Hash{}) { + if s.db.enableStateEpoch(true) && prev != (common.Hash{}) { _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { @@ -786,6 +801,9 @@ func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { } func (s *StateObject) accessState(key common.Hash) { + if !s.db.enableStateEpoch(false) { + return + } s.db.journal.append(accessedStorageStateChange{ address: &s.address, slot: &key, diff --git a/core/state/statedb.go b/core/state/statedb.go index 8dcc0056ef..7aef7f0a64 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -1794,6 +1794,9 @@ func (s *StateDB) GetStorage(address common.Address) *sync.Map { // ReviveTrie revive a trie with a given witness list func (s *StateDB) ReviveStorageTrie(witnessList types.WitnessList) error { + if !s.enableStateEpoch(true) { + return errors.New("cannot revive any state before epoch2") + } for i := range witnessList { wit := witnessList[i] // got specify witness, verify proof and check if revive success @@ -1835,3 +1838,12 @@ func (s *StateDB) ReviveStorageTrie(witnessList types.WitnessList) error { func (s *StateDB) openShadowStorage(addr common.Hash) trie.ShadowNodeStorage { return trie.NewShadowNodeStorage4Trie(addr, s.shadowNodeDB) } + +// enableStateEpoch reports whether the state expiry hard fork is enabled; when inExpired is true, it reports whether the current target epoch is past epoch1
+func (s *StateDB) enableStateEpoch(inExpired bool) bool { + if !inExpired { + return s.targetEpoch > types.StateEpoch0 + } + + return s.targetEpoch > types.StateEpoch1 +} diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go index 157f41d7b7..dd97f0f826 100644 --- a/core/types/state_epoch.go +++ b/core/types/state_epoch.go @@ -8,9 +8,8 @@ import ( ) var ( - // EpochPeriod indicates the state rotate epoch block length - EpochPeriod = big.NewInt(7_008_000) StateEpoch0 = StateEpoch(0) + StateEpoch1 = StateEpoch(1) ) type StateEpoch uint16 @@ -19,12 +18,13 @@ type StateEpoch uint16 // state epoch will indicate if the state is accessible or expiry. // Before ClaudeBlock indicates state epoch0. // ClaudeBlock indicates start state epoch1. -// ElwoodBlock indicates start state epoch2 and start epoch rotate by EpochPeriod. +// ElwoodBlock indicates start state epoch2 and start epoch rotate by StateEpochPeriod. // When N>=2 and epochN started, epoch(N-2)'s state will expire. func GetStateEpoch(config *params.ChainConfig, blockNumber *big.Int) StateEpoch { if config.IsElwood(blockNumber) { + epochPeriod := new(big.Int).SetUint64(config.Parlia.StateEpochPeriod) ret := new(big.Int).Sub(blockNumber, config.ElwoodBlock) - ret.Div(ret, EpochPeriod) + ret.Div(ret, epochPeriod) ret.Add(ret, common.Big2) return StateEpoch(ret.Uint64()) } else if config.IsClaude(blockNumber) { diff --git a/params/config.go b/params/config.go index 22812c7ec3..10e420d75f 100644 --- a/params/config.go +++ b/params/config.go @@ -119,8 +119,9 @@ var ( //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + StateEpochPeriod: 7_008_000, }, } @@ -148,8 +149,9 @@ var ( //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + StateEpochPeriod: 7_008_000, }, } @@ -177,8 +179,9 @@ var ( //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + 
StateEpochPeriod: 7_008_000, }, } @@ -323,8 +326,9 @@ func (c *CliqueConfig) String() string { // ParliaConfig is the consensus engine configs for proof-of-staked-authority based sealing. type ParliaConfig struct { - Period uint64 `json:"period"` // Number of seconds between blocks to enforce - Epoch uint64 `json:"epoch"` // Epoch length to update validatorSet + Period uint64 `json:"period"` // Number of seconds between blocks to enforce + Epoch uint64 `json:"epoch"` // Epoch length to update validatorSet + StateEpochPeriod uint64 `json:"stateEpochPeriod"` // StateEpochPeriod indicates the length of a state epoch, default 7_008_000 } // String implements the stringer interface, returning the consensus engine details. diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 4f52ad648e..6456dae840 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -79,6 +79,8 @@ func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, return &SecureTrie{trie: *trie}, nil } +// NewSecureWithShadowNodes opens a trie with shadow nodes; it needs to know the current epoch to check whether a node is expired. +// It uses sndb to query or store shadow nodes; if you use NewSecure instead, it opens the storage trie at epoch0.
func NewSecureWithShadowNodes(curEpoch types.StateEpoch, root common.Hash, db *Database, sndb ShadowNodeStorage) (*SecureTrie, error) { if db == nil || sndb == nil { panic("trie.NewSecure called without a database") diff --git a/trie/trie.go b/trie/trie.go index 65d80ba525..757abba8d2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -1070,6 +1070,9 @@ func (t *Trie) ShadowHash() (*common.Hash, error) { if t.root == nil { return nil, nil } + if t.sndb == nil { + return nil, errors.New("ShadowHash sndb is nil") + } h := newHasher(true) defer returnHasherToPool(h) return t.shadowHash(t.root, h, nil, t.root.getEpoch()) From 551b63d96f7c8baf744072364bc5693f6f687dc2 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Thu, 11 May 2023 20:55:25 +0800 Subject: [PATCH 40/51] state/statedb: opt some function naming; trie: fix trie hash root node bug; state/statedb: add system contracts to expiry white list; state/state_object: using expiry white list; state/state_object: opt shadow node diff layers update; --- accounts/abi/bind/backends/simulated.go | 6 +- cmd/evm/internal/t8ntool/execution.go | 4 +- cmd/geth/snapshot.go | 10 +- consensus/parlia/parlia.go | 13 ++- core/blockchain.go | 8 +- core/blockchain_reader.go | 2 +- core/chain_makers.go | 2 +- core/state/database.go | 5 +- core/state/errors.go | 11 +- core/state/pruner/pruner.go | 5 +- core/state/snapshot/generate_test.go | 35 ++++--- core/state/state_object.go | 16 +-- core/state/statedb.go | 46 ++++++++- core/state_processor.go | 2 +- core/types/state_epoch.go | 13 ++- core/types/state_epoch_test.go | 14 +-- core/vm/evm.go | 8 ++ eth/api.go | 4 +- eth/protocols/snap/handler.go | 5 +- eth/protocols/snap/sync_test.go | 3 +- eth/state_accessor.go | 8 +- les/downloader/downloader_test.go | 2 +- light/trie.go | 2 +- trie/database.go | 4 - trie/iterator_test.go | 2 +- trie/proof_test.go | 19 +++- trie/secure_trie.go | 18 +--- trie/secure_trie_test.go | 4 +- trie/shadow_node.go | 7 +- 
trie/shadow_node_difflayer.go | 3 +- trie/sync_test.go | 6 +- trie/trie.go | 128 +++++++++++++++++------- trie/trie_test.go | 4 +- 33 files changed, 276 insertions(+), 143 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 91a40206ff..32809d2a46 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -132,7 +132,7 @@ func (b *SimulatedBackend) rollback(parent *types.Block) { b.pendingBlock = blocks[0] blockNum := new(big.Int).Add(parent.Number(), common.Big1) - b.pendingState, _ = state.NewWithEpoch(b.config, blockNum, b.pendingBlock.Root(), b.blockchain.StateCache(), nil, b.blockchain.ShadowNodeTree()) + b.pendingState, _ = state.NewWithStateEpoch(b.config, blockNum, b.pendingBlock.Root(), b.blockchain.StateCache(), nil, b.blockchain.ShadowNodeTree()) } // Fork creates a side-chain that can be used to simulate reorgs. @@ -901,7 +901,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.NewWithEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) + b.pendingState, _ = state.NewWithStateEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) b.pendingReceipts = receipts[0] return nil } @@ -1017,7 +1017,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.NewWithEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) + b.pendingState, _ = state.NewWithStateEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) return nil } diff --git a/cmd/evm/internal/t8ntool/execution.go 
b/cmd/evm/internal/t8ntool/execution.go index 4b7f2fbe6f..ba656da33d 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -273,7 +273,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) *state.StateDB { sdb := state.NewDatabase(db) tree, _ := trie.NewShadowNodeSnapTree(db) - statedb, _ := state.NewWithEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1), common.Hash{}, sdb, nil, tree) + statedb, _ := state.NewWithStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1), common.Hash{}, sdb, nil, tree) for addr, a := range pre.Pre { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) @@ -286,7 +286,7 @@ func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) statedb.Finalise(false) statedb.AccountsIntermediateRoot() root, _, _ := statedb.Commit(nil) - statedb, _ = state.NewWithEpoch(config, new(big.Int).SetUint64(pre.Env.Number), root, sdb, nil, tree) + statedb, _ = state.NewWithStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number), root, sdb, nil, tree) return statedb } diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 32898a2045..f2243ff390 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -542,7 +542,7 @@ func traverseState(ctx *cli.Context) error { log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } triedb := trie.NewDatabase(chaindb) - t, err := trie.NewSecure(root, triedb, false) + t, err := trie.NewSecure(root, triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) return err @@ -563,7 +563,8 @@ func traverseState(ctx *cli.Context) error { return err } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + storageTrie, err := trie.NewSecure(acc.Root, triedb) if err != 
nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return err @@ -631,7 +632,7 @@ func traverseRawState(ctx *cli.Context) error { log.Info("Start traversing the state", "root", root, "number", headBlock.NumberU64()) } triedb := trie.NewDatabase(chaindb) - t, err := trie.NewSecure(root, triedb, false) + t, err := trie.NewSecure(root, triedb) if err != nil { log.Error("Failed to open trie", "root", root, "err", err) return err @@ -667,7 +668,8 @@ func traverseRawState(ctx *cli.Context) error { return errors.New("invalid account") } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + storageTrie, err := trie.NewSecure(acc.Root, triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) return errors.New("missing storage trie") diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index fb338d4ebc..161141bca5 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -46,9 +46,8 @@ const ( inMemorySnapshots = 128 // Number of recent snapshots to keep in memory inMemorySignatures = 4096 // Number of recent block signatures to keep in memory - checkpointInterval = 1024 // Number of blocks after which to save the snapshot to the database - defaultEpochLength = uint64(100) // Default number of blocks of checkpoint to update validatorSet from contract - defaultStateEpochPeriod = uint64(7_008_000) + checkpointInterval = 1024 // Number of blocks after which to save the snapshot to the database + defaultEpochLength = uint64(100) // Default number of blocks of checkpoint to update validatorSet from contract extraVanity = 32 // Fixed number of extra-data prefix bytes reserved for signer vanity extraSeal = 65 // Fixed number of extra-data suffix bytes reserved for signer seal @@ -149,7 +148,7 @@ var ( type SignerFn func(accounts.Account, string, []byte) ([]byte, error) type 
SignerTxFn func(accounts.Account, *types.Transaction, *big.Int) (*types.Transaction, error) -func isToSystemContract(to common.Address) bool { +func IsToSystemContract(to common.Address) bool { return systemContracts[to] } @@ -232,7 +231,7 @@ func New( parliaConfig.Epoch = defaultEpochLength } if parliaConfig != nil && parliaConfig.StateEpochPeriod == 0 { - parliaConfig.StateEpochPeriod = defaultStateEpochPeriod + parliaConfig.StateEpochPeriod = types.DefaultStateEpochPeriod } log.Info("instance parlia with config", "period", parliaConfig.Period, "epoch", parliaConfig.Epoch, "stateEpochPeriod", parliaConfig.StateEpochPeriod) @@ -278,7 +277,7 @@ func (p *Parlia) IsSystemTransaction(tx *types.Transaction, header *types.Header if err != nil { return false, errors.New("UnAuthorized transaction") } - if sender == header.Coinbase && isToSystemContract(*tx.To()) && tx.GasPrice().Cmp(big.NewInt(0)) == 0 { + if sender == header.Coinbase && IsToSystemContract(*tx.To()) && tx.GasPrice().Cmp(big.NewInt(0)) == 0 { return true, nil } return false, nil @@ -288,7 +287,7 @@ func (p *Parlia) IsSystemContract(to *common.Address) bool { if to == nil { return false } - return isToSystemContract(*to) + return IsToSystemContract(*to) } // Author implements consensus.Engine, returning the SystemAddress diff --git a/core/blockchain.go b/core/blockchain.go index 223ff4a4cc..f2e69e6e32 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -371,7 +371,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par // Make sure the state associated with the block is available head := bc.CurrentBlock() - if _, err := state.NewWithEpoch(chainConfig, head.Number(), head.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { + if _, err := state.NewWithStateEpoch(chainConfig, head.Number(), head.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { // Head state is missing, before the state recovery, find out the // disk layer point of snapshot(if 
it's enabled). Make sure the // rewound point is lower than disk layer. @@ -722,7 +722,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo enoughBeyondCount = beyondCount > maxBeyondBlocks - if _, err := state.NewWithEpoch(bc.chainConfig, newHeadBlock.Number(), newHeadBlock.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { + if _, err := state.NewWithStateEpoch(bc.chainConfig, newHeadBlock.Number(), newHeadBlock.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) @@ -846,7 +846,7 @@ func (bc *BlockChain) SnapSyncCommitHead(hash common.Hash) error { if block == nil { return fmt.Errorf("non existent block [%x..]", hash[:4]) } - if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), false); err != nil { + if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB()); err != nil { return err } @@ -1894,7 +1894,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) if parent == nil { parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1) } - statedb, err := state.NewWithEpoch(bc.chainConfig, block.Number(), parent.Root, bc.stateCache, bc.snaps, bc.ShadowNodeTree()) + statedb, err := state.NewWithStateEpoch(bc.chainConfig, block.Number(), parent.Root, bc.stateCache, bc.snaps, bc.ShadowNodeTree()) if err != nil { return it.index, err } diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index 70a5138e12..203f2473f6 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -316,7 +316,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) { // StateAt returns a new mutable state based on a particular point in time. 
func (bc *BlockChain) StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) { - return state.NewWithEpoch(bc.chainConfig, number, root, bc.stateCache, bc.snaps, bc.shadowNodeTree) + return state.NewWithStateEpoch(bc.chainConfig, number, root, bc.stateCache, bc.snaps, bc.shadowNodeTree) } // Config retrieves the chain's fork configuration. diff --git a/core/chain_makers.go b/core/chain_makers.go index 4272bd6b38..1e86f9a245 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -282,7 +282,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse } for i := 0; i < n; i++ { number := new(big.Int).Add(parent.Number(), common.Big1) - statedb, err := state.NewWithEpoch(config, number, parent.Root(), state.NewDatabase(db), nil, tree) + statedb, err := state.NewWithStateEpoch(config, number, parent.Root(), state.NewDatabase(db), nil, tree) if err != nil { panic(err) } diff --git a/core/state/database.go b/core/state/database.go index 00b407c799..744f605e0a 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -224,7 +224,7 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { return tr.(Trie).(*trie.SecureTrie).Copy(), nil } } - tr, err := trie.NewSecure(root, db.db, false) + tr, err := trie.NewSecure(root, db.db) if err != nil { return nil, err } @@ -247,7 +247,8 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { } } - tr, err := trie.NewSecure(root, db.db, false) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + tr, err := trie.NewSecure(root, db.db) if err != nil { return nil, err } diff --git a/core/state/errors.go b/core/state/errors.go index ac7649c5c3..3b941cacd6 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -1,6 +1,8 @@ package state import ( + "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/trie" @@ -13,6 +15,7 @@ type 
ExpiredStateError struct { Path []byte Epoch types.StateEpoch isInsert bool // when true it through expired path, must recovery the expired path + reason string } func NewPlainExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { @@ -22,6 +25,7 @@ func NewPlainExpiredStateError(addr common.Address, key common.Hash, epoch types Path: []byte{}, Epoch: epoch, isInsert: false, + reason: "snap query", } } @@ -32,6 +36,7 @@ func NewExpiredStateError(addr common.Address, key common.Hash, err *trie.Expire Path: err.Path, Epoch: err.Epoch, isInsert: false, + reason: "query", } } @@ -42,12 +47,10 @@ func NewInsertExpiredStateError(addr common.Address, key common.Hash, err *trie. Path: err.Path, Epoch: err.Epoch, isInsert: true, + reason: "insert", } } func (e *ExpiredStateError) Error() string { - if e.isInsert { - return "Insert state through expired path" - } - return "Access expired state" + return fmt.Sprintf("Access expired state, addr: %v, key: %v, epoch: %v, reason: %v", e.Addr, e.Key, e.Epoch, e.reason) } diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 7b70b2b310..8032d7294c 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -736,7 +736,7 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { if genesis == nil { return errors.New("missing genesis block") } - t, err := trie.NewSecure(genesis.Root(), trie.NewDatabase(db), false) + t, err := trie.NewSecure(genesis.Root(), trie.NewDatabase(db)) if err != nil { return err } @@ -756,7 +756,8 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { return err } if acc.Root != emptyRoot { - storageTrie, err := trie.NewSecure(acc.Root, trie.NewDatabase(db), true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + storageTrie, err := trie.NewSecure(acc.Root, trie.NewDatabase(db)) if err != nil { return err } diff --git a/core/state/snapshot/generate_test.go 
b/core/state/snapshot/generate_test.go index b95bd49ed4..1f45fe72a6 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -42,13 +42,14 @@ func TestGeneration(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -99,13 +100,14 @@ func TestGenerateExistentState(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 
0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -179,7 +181,7 @@ type testHelper struct { func newHelper() *testHelper { diskdb := memorydb.New() triedb := trie.NewDatabase(diskdb) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) return &testHelper{ diskdb: diskdb, triedb: triedb, @@ -211,7 +213,8 @@ func (t *testHelper) addSnapStorage(accKey string, keys []string, vals []string) } func (t *testHelper) makeStorageTrie(keys []string, vals []string) []byte { - stTrie, _ := trie.NewSecure(common.Hash{}, t.triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, t.triedb) for i, k := range keys { stTrie.Update([]byte(k), []byte(vals[i])) } @@ -384,7 +387,7 @@ func TestGenerateCorruptAccountTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - tr, _ := trie.NewSecure(common.Hash{}, triedb, false) + tr, _ := trie.NewSecure(common.Hash{}, triedb) acc := &Account{Balance: big.NewInt(1), Root: emptyRoot.Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) tr.Update([]byte("acc-1"), val) // 0xc7a30f39aff471c95d8a837497ad0e49b65be475cc0953540f80cfcdbdcd9074 @@ -428,13 +431,14 @@ func TestGenerateMissingStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 
0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -487,13 +491,15 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) - stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) + + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 stTrie.Update([]byte("key-3"), []byte("val-3")) // 0x51c71a47af0695957647fb68766d0becee77e953df17c29b3c2f25436f055c78 stTrie.Commit(nil) // Root: 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67 - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e @@ -537,7 +543,8 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { } func getStorageTrie(n int, triedb *trie.Database) 
*trie.SecureTrie { - stTrie, _ := trie.NewSecure(common.Hash{}, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, _ := trie.NewSecure(common.Hash{}, triedb) for i := 0; i < n; i++ { k := fmt.Sprintf("key-%d", i) v := fmt.Sprintf("val-%d", i) @@ -554,7 +561,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) { triedb = trie.NewDatabase(diskdb) stTrie = getStorageTrie(5, triedb) ) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) { // Account one in the trie acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) @@ -618,7 +625,7 @@ func TestGenerateWithManyExtraAccounts(t *testing.T) { triedb = trie.NewDatabase(diskdb) stTrie = getStorageTrie(3, triedb) ) - accTrie, _ := trie.NewSecure(common.Hash{}, triedb, false) + accTrie, _ := trie.NewSecure(common.Hash{}, triedb) { // Account one in the trie acc := &Account{Balance: big.NewInt(1), Root: stTrie.Hash().Bytes(), CodeHash: emptyCode.Bytes()} val, _ := rlp.EncodeToBytes(acc) diff --git a/core/state/state_object.go b/core/state/state_object.go index 8916a261f2..5b137143ba 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -24,6 +24,8 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/trie" @@ -199,9 +201,10 @@ func (s *StateObject) getTrie(db Database) Trie { } var err error // check if enable state epoch - if s.db.enableStateEpoch(false) { + if s.db.enableAccStateEpoch(false, s.address) { s.trie, err = db.OpenStorageTrieWithShadowNode(s.addrHash, s.data.Root, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) if err != nil { + log.Error("OpenStorageTrieWithShadowNode err", "targetEpoch", s.targetEpoch, "err", err) s.trie, _ = db.OpenStorageTrieWithShadowNode(s.addrHash, common.Hash{}, s.targetEpoch, 
s.db.openShadowStorage(s.addrHash)) s.setError(fmt.Errorf("can't create storage trie with shadowNode: %v", err)) } @@ -242,7 +245,7 @@ func (s *StateObject) GetState(db Database, key common.Hash) (common.Hash, error s.accessState(key) return value, nil } - if s.db.enableStateEpoch(true) { + if s.db.enableAccStateEpoch(true, s.address) { if revived, revive := s.queryFromReviveState(db, s.dirtyReviveState, key); revive { s.accessState(key) return revived, nil @@ -291,7 +294,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha if value, pending := s.pendingStorage[key]; pending { return value, nil } - if s.db.enableStateEpoch(true) { + if s.db.enableAccStateEpoch(true, s.address) { if revived, revive := s.queryFromReviveState(db, s.pendingReviveState, key); revive { return revived, nil } @@ -318,7 +321,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha } if err == nil { if snapVal, err := snapshot.ParseSnapValFromBytes(enc); err == nil { - if types.EpochExpired(snapVal.Epoch, s.targetEpoch) { + if s.db.enableAccStateEpoch(true, s.address) && types.EpochExpired(snapVal.Epoch, s.targetEpoch) { return common.Hash{}, NewPlainExpiredStateError(s.address, key, snapVal.Epoch) } return snapVal.Val, nil @@ -375,7 +378,7 @@ func (s *StateObject) SetState(db Database, key, value common.Hash) error { return nil } // when state insert, check if valid to insert new state - if s.db.enableStateEpoch(true) && prev != (common.Hash{}) { + if s.db.enableAccStateEpoch(true, s.address) && prev != (common.Hash{}) { _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { @@ -603,6 +606,7 @@ func (s *StateObject) CommitTrie(db Database) (int, error) { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } root, committed, err := s.trie.Commit(nil) + log.Info("obj CommitTrie", "addr", s.address, "root", root) if err == nil { 
s.data.Root = root } @@ -801,7 +805,7 @@ func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { } func (s *StateObject) accessState(key common.Hash) { - if !s.db.enableStateEpoch(false) { + if !s.db.enableAccStateEpoch(false, s.address) { return } s.db.journal.append(accessedStorageStateChange{ diff --git a/core/state/statedb.go b/core/state/statedb.go index 7aef7f0a64..86558b2e10 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -154,14 +154,15 @@ type StateDB struct { StorageDeleted int } -// NewWithEpoch creates a new state from a given trie. -func NewWithEpoch(config *params.ChainConfig, targetBlock *big.Int, root common.Hash, db Database, snaps *snapshot.Tree, sntree *trie.ShadowNodeSnapTree) (*StateDB, error) { +// NewWithStateEpoch creates a new state from a given trie. +func NewWithStateEpoch(config *params.ChainConfig, targetBlock *big.Int, root common.Hash, db Database, snaps *snapshot.Tree, sntree *trie.ShadowNodeSnapTree) (*StateDB, error) { targetEpoch := types.GetStateEpoch(config, targetBlock) stateDB, err := newStateDB(root, db, snaps, targetEpoch) if err != nil { return nil, err } + log.Info("NewWithStateEpoch", "targetBlock", targetBlock, "targetEpoch", targetEpoch, "root", root) // init target block and shadowNodeRW stateDB.targetBlk = targetBlock stateDB.shadowNodeDB, err = trie.NewShadowNodeDatabase(sntree, targetBlock, root) @@ -1644,11 +1645,12 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } - if s.shadowNodeDB != nil && s.originalRoot != root { + if s.shadowNodeDB != nil { if err := s.shadowNodeDB.Commit(s.targetBlk, root); err != nil { return common.Hash{}, nil, err } } + log.Info("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) return root, diffLayer, nil } @@ -1847,3 +1849,41 @@ func (s *StateDB) enableStateEpoch(inExpired bool) bool { return s.targetEpoch > 
types.StateEpoch1 } + +// enableAccStateEpoch return if enable account state expiry hard fork, if inExpired, return if after epoch1 +func (s *StateDB) enableAccStateEpoch(inExpired bool, addr common.Address) bool { + // TODO(0xbundler): temporary code, add IsToSystemContract in whitelist for avoid expiry system contract, + // it uses in testnet, because there no crossChain msg to update them + if systemContracts[addr] { + return false + } + return s.enableStateEpoch(inExpired) +} + +// TODO(0xbundler): temporary code, remove in release version +const ( + // genesis contracts + ValidatorContract = "0x0000000000000000000000000000000000001000" + SlashContract = "0x0000000000000000000000000000000000001001" + SystemRewardContract = "0x0000000000000000000000000000000000001002" + LightClientContract = "0x0000000000000000000000000000000000001003" + TokenHubContract = "0x0000000000000000000000000000000000001004" + RelayerIncentivizeContract = "0x0000000000000000000000000000000000001005" + RelayerHubContract = "0x0000000000000000000000000000000000001006" + GovHubContract = "0x0000000000000000000000000000000000001007" + TokenManagerContract = "0x0000000000000000000000000000000000001008" + CrossChainContract = "0x0000000000000000000000000000000000002000" + StakingContract = "0x0000000000000000000000000000000000002001" +) + +var systemContracts = map[common.Address]bool{ + common.HexToAddress(ValidatorContract): true, + common.HexToAddress(SlashContract): true, + common.HexToAddress(SystemRewardContract): true, + common.HexToAddress(LightClientContract): true, + common.HexToAddress(RelayerHubContract): true, + common.HexToAddress(GovHubContract): true, + common.HexToAddress(TokenHubContract): true, + common.HexToAddress(RelayerIncentivizeContract): true, + common.HexToAddress(CrossChainContract): true, +} diff --git a/core/state_processor.go b/core/state_processor.go index 869f617bcc..0dc1d72607 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -122,7 +122,7 
@@ func (p *LightStateProcessor) Process(block *types.Block, statedb *state.StateDB // prepare new statedb statedb.StopPrefetcher() parent := p.bc.GetHeader(block.ParentHash(), block.NumberU64()-1) - statedb, err = state.NewWithEpoch(p.config, block.Number(), parent.Root, p.bc.stateCache, p.bc.snaps, p.bc.shadowNodeTree) + statedb, err = state.NewWithStateEpoch(p.config, block.Number(), parent.Root, p.bc.stateCache, p.bc.snaps, p.bc.shadowNodeTree) if err != nil { return statedb, nil, nil, 0, err } diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go index dd97f0f826..eb677009d7 100644 --- a/core/types/state_epoch.go +++ b/core/types/state_epoch.go @@ -8,8 +8,9 @@ import ( ) var ( - StateEpoch0 = StateEpoch(0) - StateEpoch1 = StateEpoch(1) + DefaultStateEpochPeriod = uint64(7_008_000) + StateEpoch0 = StateEpoch(0) + StateEpoch1 = StateEpoch(1) ) type StateEpoch uint16 @@ -21,10 +22,14 @@ type StateEpoch uint16 // ElwoodBlock indicates start state epoch2 and start epoch rotate by StateEpochPeriod. // When N>=2 and epochN started, epoch(N-2)'s state will expire. 
func GetStateEpoch(config *params.ChainConfig, blockNumber *big.Int) StateEpoch { + epochPeriod := DefaultStateEpochPeriod + if config.Parlia != nil && config.Parlia.StateEpochPeriod != 0 { + epochPeriod = config.Parlia.StateEpochPeriod + } if config.IsElwood(blockNumber) { - epochPeriod := new(big.Int).SetUint64(config.Parlia.StateEpochPeriod) + epochPeriodInt := new(big.Int).SetUint64(epochPeriod) ret := new(big.Int).Sub(blockNumber, config.ElwoodBlock) - ret.Div(ret, epochPeriod) + ret.Div(ret, epochPeriodInt) ret.Add(ret, common.Big2) return StateEpoch(ret.Uint64()) } else if config.IsClaude(blockNumber) { diff --git a/core/types/state_epoch_test.go b/core/types/state_epoch_test.go index 8286d21c4a..892ae89bef 100644 --- a/core/types/state_epoch_test.go +++ b/core/types/state_epoch_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" ) +var epochPeriod = new(big.Int).SetUint64(DefaultStateEpochPeriod) + func TestStateForkConfig(t *testing.T) { temp := ¶ms.ChainConfig{} assert.NoError(t, temp.CheckConfigForkOrder()) @@ -65,8 +67,8 @@ func TestSimpleStateEpoch(t *testing.T) { assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(19999))) assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(20000))) - assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), EpochPeriod))) - assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) } func TestNoZeroStateEpoch(t *testing.T) { @@ -80,8 +82,8 @@ func TestNoZeroStateEpoch(t *testing.T) { assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(1))) 
assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(2))) assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10000))) - assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), EpochPeriod))) - assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) } func TestNearestStateEpoch(t *testing.T) { @@ -94,6 +96,6 @@ func TestNearestStateEpoch(t *testing.T) { assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(0))) assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10001))) - assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), EpochPeriod))) - assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), new(big.Int).Mul(big.NewInt(100), EpochPeriod)))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) } diff --git a/core/vm/evm.go b/core/vm/evm.go index 3408ba40ab..f59d23526b 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -22,6 +22,8 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/log" + "github.com/holiman/uint256" "github.com/ethereum/go-ethereum/common" @@ -262,7 +264,13 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // TODO: consider clearing up unused snapshots: //} else { // evm.StateDB.DiscardSnapshot(snapshot) + + errors := evm.Errors() + for _, e := range errors { + log.Error("call err", "addr", addr, 
"from", caller.Address(), "err", e.Error()) + } } + return ret, gas, err } diff --git a/eth/api.go b/eth/api.go index 9c2589553c..b1bea41d87 100644 --- a/eth/api.go +++ b/eth/api.go @@ -525,11 +525,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc } triedb := api.eth.BlockChain().StateCache().TrieDB() - oldTrie, err := trie.NewSecure(startBlock.Root(), triedb, false) + oldTrie, err := trie.NewSecure(startBlock.Root(), triedb) if err != nil { return nil, err } - newTrie, err := trie.NewSecure(endBlock.Root(), triedb, false) + newTrie, err := trie.NewSecure(endBlock.Root(), triedb) if err != nil { return nil, err } diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index e5d210cf00..030992d2fd 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -487,7 +487,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s // Make sure we have the state associated with the request triedb := chain.StateCache().TrieDB() - accTrie, err := trie.NewSecure(req.Root, triedb, false) + accTrie, err := trie.NewSecure(req.Root, triedb) if err != nil { // We don't have the requested state available, bail out return nil, nil @@ -529,7 +529,8 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s if err != nil || account == nil { break } - stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb) loads++ // always account database reads, even for failures if err != nil { break diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index e2865afd0d..dd10edea67 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -1604,7 +1604,8 @@ func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { } accounts++ if 
acc.Root != emptyRoot { - storeTrie, err := trie.NewSecure(acc.Root, triedb, true) + // TODO default using epoch0 trie, but need sndb to query shadow nodes + storeTrie, err := trie.NewSecure(acc.Root, triedb) if err != nil { t.Fatal(err) } diff --git a/eth/state_accessor.go b/eth/state_accessor.go index 41459f2eeb..b7cf4c0db2 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -66,7 +66,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // Create an ephemeral trie.Database for isolating the live one. Otherwise // the internal junks created by tracing will be persisted into the disk. database = state.NewDatabaseWithConfig(eth.chainDb, &trie.Config{Cache: 16}) - if statedb, err = state.NewWithEpoch(eth.blockchain.Config(), block.Number(), block.Root(), database, nil, eth.blockchain.ShadowNodeTree()); err == nil { + if statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), block.Number(), block.Root(), database, nil, eth.blockchain.ShadowNodeTree()); err == nil { log.Info("Found disk backend for state trie", "root", block.Root(), "number", block.Number()) return statedb, nil } @@ -89,7 +89,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // we would rewind past a persisted block (specific corner case is chain // tracing from the genesis). 
if !checkLive { - statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { return statedb, nil } @@ -105,7 +105,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state } current = parent - statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { break } @@ -148,7 +148,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state return nil, fmt.Errorf("stateAtBlock commit failed, number %d root %v: %w", current.NumberU64(), current.Root().Hex(), err) } - statedb, err = state.NewWithEpoch(eth.blockchain.Config(), current.Number(), root, database, nil, eth.blockchain.ShadowNodeTree()) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), root, database, nil, eth.blockchain.ShadowNodeTree()) if err != nil { return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) } diff --git a/les/downloader/downloader_test.go b/les/downloader/downloader_test.go index e270cc0567..69bdb90ed2 100644 --- a/les/downloader/downloader_test.go +++ b/les/downloader/downloader_test.go @@ -229,7 +229,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block { func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error { // For now only check that the state trie is correct if block := dl.GetBlockByHash(hash); block != nil { - _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), false) + _, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb)) return 
err } return fmt.Errorf("non existent block: %x", hash[:4]) diff --git a/light/trie.go b/light/trie.go index 034e49fabc..0cf403de8e 100644 --- a/light/trie.go +++ b/light/trie.go @@ -39,7 +39,7 @@ var ( func NewState(ctx context.Context, config *params.ChainConfig, head *types.Header, odr OdrBackend) *state.StateDB { tree, _ := trie.NewShadowNodeSnapTree(odr.Database()) - state, _ := state.NewWithEpoch(config, head.Number, head.Root, NewStateDatabase(ctx, head, odr), nil, tree) + state, _ := state.NewWithStateEpoch(config, head.Number, head.Root, NewStateDatabase(ctx, head, odr), nil, tree) return state } diff --git a/trie/database.go b/trie/database.go index c4a3b87aec..da2aad5108 100644 --- a/trie/database.go +++ b/trie/database.go @@ -432,10 +432,6 @@ func (db *Database) node(hash common.Hash) node { return mustDecodeNodeUnsafe(hash[:], enc) } -func (db *Database) RootNode(hash common.Hash) *rootNode { - return nil -} - // Node retrieves an encoded cached trie node from memory. If it cannot be found // cached, the method queries the persistent database for the content. 
func (db *Database) Node(hash common.Hash) ([]byte, error) { diff --git a/trie/iterator_test.go b/trie/iterator_test.go index f2e6c7d586..871a5f3fca 100644 --- a/trie/iterator_test.go +++ b/trie/iterator_test.go @@ -508,7 +508,7 @@ func makeLargeTestTrie() (*Database, *SecureTrie, *loggingDb) { // Create an empty trie logDb := &loggingDb{0, memorydb.New()} triedb := NewDatabase(logDb) - trie, _ := NewSecure(common.Hash{}, triedb, false) + trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data for i := 0; i < 10000; i++ { diff --git a/trie/proof_test.go b/trie/proof_test.go index 0b914bcef8..ede4ccea5b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -1273,7 +1273,24 @@ func randBytes(n int) []byte { func nonRandomTrie(n int) (*Trie, map[string]*kv) { trie := new(Trie) - trie.useShadowTree = true + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + elem := &kv{key, value, false} + trie.Update(elem.k, elem.v) + vals[string(elem.k)] = elem + } + return trie, vals +} + +func nonRandomTrieWithShadowNodes(n int) (*Trie, map[string]*kv) { + trie := new(Trie) + trie.withShadowNodes = true trie.currentEpoch = 10 // TODO (asyukii): might need to change this vals := make(map[string]*kv) max := uint64(0xffffffffffffffff) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 6456dae840..8272ae0534 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -52,30 +52,14 @@ type SecureTrie struct { // Loaded nodes are kept around until their 'cache generation' expires. // A new cache generation is created by each call to Commit. // cachelimit sets the number of past cache generations to keep. 
-func NewSecure(root common.Hash, db *Database, isStorageTrie bool) (*SecureTrie, error) { +func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { if db == nil { panic("trie.NewSecure called without a database") } - epoch := types.StateEpoch(0) - shadowHash := common.Hash{} - if isStorageTrie { - if rootNode := db.RootNode(root); rootNode != nil { - root = rootNode.TrieRoot - epoch = rootNode.Epoch - shadowHash = rootNode.ShadowTreeRoot - } - } trie, err := New(root, db) if err != nil { return nil, err } - if isStorageTrie { - if trie.root != nil { - trie.root.setEpoch(epoch) - } - trie.shadowTreeRoot = shadowHash - } - trie.useShadowTree = isStorageTrie return &SecureTrie{trie: *trie}, nil } diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go index f101514675..7bbdf29ef0 100644 --- a/trie/secure_trie_test.go +++ b/trie/secure_trie_test.go @@ -29,7 +29,7 @@ import ( ) func newEmptySecure() *SecureTrie { - trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New()), false) + trie, _ := NewSecure(common.Hash{}, NewDatabase(memorydb.New())) return trie } @@ -37,7 +37,7 @@ func newEmptySecure() *SecureTrie { func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie triedb := NewDatabase(memorydb.New()) - trie, _ := NewSecure(common.Hash{}, triedb, false) + trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data content := make(map[string][]byte) diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 5cde68966c..04b761acc8 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -6,6 +6,8 @@ import ( "math/big" "sync" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/ethdb" @@ -236,11 +238,13 @@ func NewShadowNodeDatabase(tree *ShadowNodeSnapTree, number *big.Int, blockRoot // try using default snap if snap = tree.Snapshot(emptyRoot); snap == nil { // open read only history + 
log.Info("NewShadowNodeDatabase use RO database", "number", number, "root", blockRoot) return &ShadowNodeStorageRO{ diskdb: tree.DB(), number: number, }, nil } + log.Info("NewShadowNodeDatabase use default database", "number", number, "root", blockRoot) } return &ShadowNodeStorageRW{ snap: snap, @@ -252,9 +256,6 @@ func NewShadowNodeDatabase(tree *ShadowNodeSnapTree, number *big.Int, blockRoot func (s *ShadowNodeStorageRW) Get(addr common.Hash, path string) ([]byte, error) { s.lock.RLock() defer s.lock.RUnlock() - if s.stale { - return nil, errors.New("storage has staled") - } sub, exist := s.dirties[addr] if exist { if val, ok := sub[path]; ok { diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go index 615fe1a595..3506b1ab6b 100644 --- a/trie/shadow_node_difflayer.go +++ b/trie/shadow_node_difflayer.go @@ -142,8 +142,9 @@ func (s *ShadowNodeSnapTree) Cap(blockRoot common.Hash) error { } func (s *ShadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error { + // if there are no changes, just skip if blockRoot == parentRoot { - return errors.New("there no update in layers") + return nil } // Generate a new snapshot on top of the parent diff --git a/trie/sync_test.go b/trie/sync_test.go index 05015c5367..970730b671 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -29,7 +29,7 @@ import ( func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) { // Create an empty trie triedb := NewDatabase(memorydb.New()) - trie, _ := NewSecure(common.Hash{}, triedb, false) + trie, _ := NewSecure(common.Hash{}, triedb) // Fill it with some arbitrary data content := make(map[string][]byte) @@ -60,7 +60,7 @@ func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) { // content map. 
func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) { // Check root availability and trie contents - trie, err := NewSecure(common.BytesToHash(root), db, false) + trie, err := NewSecure(common.BytesToHash(root), db) if err != nil { t.Fatalf("failed to create trie at %x: %v", root, err) } @@ -77,7 +77,7 @@ func checkTrieContents(t *testing.T, db *Database, root []byte, content map[stri // checkTrieConsistency checks that all nodes in a trie are indeed present. func checkTrieConsistency(db *Database, root common.Hash) error { // Create and iterate a trie rooted in a subnode - trie, err := NewSecure(root, db, false) + trie, err := NewSecure(root, db) if err != nil { return nil // Consider a non existent state consistent } diff --git a/trie/trie.go b/trie/trie.go index 757abba8d2..b1b41713c2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -69,10 +69,11 @@ type Trie struct { unhashed int // fields for shadow tree and rootNode - useShadowTree bool - currentEpoch types.StateEpoch - shadowTreeRoot common.Hash - rootEpoch types.StateEpoch + withShadowNodes bool + currentEpoch types.StateEpoch + shadowTreeRoot common.Hash + rootEpoch types.StateEpoch + originTrieRoot common.Hash } // newFlag returns the cache flag value for a newly created node. 
@@ -115,15 +116,17 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa // only enable after first state expiry's hard fork if curEpoch > types.StateEpoch0 { useShadowTree = true + log.Info("withShadowNodes trie open", "rootNodeHash", rootNode.cachedHash, "RootHash", rootNode.TrieRoot, "ShadowTreeRoot", rootNode.ShadowTreeRoot) } trie := &Trie{ - db: db, - sndb: sndb, - currentEpoch: curEpoch, - useShadowTree: useShadowTree, - shadowTreeRoot: rootNode.ShadowTreeRoot, - rootEpoch: rootNode.Epoch, + db: db, + sndb: sndb, + currentEpoch: curEpoch, + withShadowNodes: useShadowTree, + shadowTreeRoot: rootNode.ShadowTreeRoot, + rootEpoch: rootNode.Epoch, + originTrieRoot: rootNode.TrieRoot, } if rootNode.TrieRoot != (common.Hash{}) && rootNode.TrieRoot != emptyRoot { root, err := trie.resolveHash(rootNode.TrieRoot[:], nil) @@ -160,7 +163,7 @@ func (t *Trie) Get(key []byte) []byte { func (t *Trie) TryGet(key []byte) (value []byte, err error) { var newroot node var didResolve bool - if t.useShadowTree { + if t.withShadowNodes { var nextEpoch types.StateEpoch if t.root != nil { nextEpoch = t.root.getEpoch() @@ -427,7 +430,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } switch n := n.(type) { case *shortNode: - if t.useShadowTree { + if t.withShadowNodes { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -450,7 +453,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - if t.useShadowTree { + if t.withShadowNodes { branch.setEpoch(t.currentEpoch) branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) } @@ -458,7 +461,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - if t.useShadowTree { + if t.withShadowNodes { branch.setEpoch(t.currentEpoch) branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) } @@ -470,7 
+473,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil case *fullNode: - if t.useShadowTree { + if t.withShadowNodes { // this full node is expired, return err if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err @@ -552,7 +555,7 @@ func (t *Trie) TryDelete(key []byte) error { func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { switch n := n.(type) { case *shortNode: - if t.useShadowTree { + if t.withShadowNodes { if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err } @@ -587,7 +590,7 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, } case *fullNode: - if t.useShadowTree { + if t.withShadowNodes { // this full node is expired, return err if expired, err := t.nodeExpired(n, prefix); expired { return false, n, err @@ -800,7 +803,21 @@ func (t *Trie) nodeExpired(n node, prefix []byte) (bool, error) { func (t *Trie) Hash() common.Hash { hash, cached, _ := t.hashRoot() t.root = cached - return common.BytesToHash(hash.(hashNode)) + newRootHash := common.BytesToHash(hash.(hashNode)) + if t.withShadowNodes { + newShadowTreeRoot := emptyRoot + shadowTreeRoot, err := t.ShadowHash() + if err != nil { + panic(fmt.Sprintf("trie hash err, when ShadowHash, err %v", err)) + } + // replace shadowTreeRoot for rootNode + if shadowTreeRoot != nil { + newShadowTreeRoot = *shadowTreeRoot + } + rn := newRootNode(t.root.getEpoch(), newRootHash, newShadowTreeRoot) + return rn.cachedHash + } + return newRootHash } // Commit writes all nodes to the trie's memory database, tracking the internal @@ -814,9 +831,11 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { } // Derive the hash for all dirty nodes first. We hold the assumption // in the following procedure that all nodes are hashed. 
- newRootHash := t.Hash() + hash, cached, _ := t.hashRoot() + t.root = cached + newRootHash := common.BytesToHash(hash.(hashNode)) newShadowTreeRoot := emptyRoot - if t.useShadowTree { + if t.withShadowNodes { shadowTreeRoot, err := t.ShadowHash() if err != nil { return common.Hash{}, 0, err @@ -825,6 +844,10 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if shadowTreeRoot != nil { newShadowTreeRoot = *shadowTreeRoot } + // commit shadow nodes after ShadowHash in Commit + if err = t.commitShadowNodes(t.root, nil, t.root.getEpoch()); err != nil { + return common.Hash{}, 0, err + } } h := newCommitter() @@ -834,7 +857,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { // up goroutines. This can happen e.g. if we load a trie for reading storage // values, but don't write to it. if _, dirty := t.root.cache(); !dirty { - if t.useShadowTree { + if t.withShadowNodes { rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) if err != nil { return common.Hash{}, 0, err @@ -865,12 +888,13 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if err != nil { return common.Hash{}, 0, err } - if t.useShadowTree { + if t.withShadowNodes { rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) if err != nil { return common.Hash{}, 0, err } t.root = newRoot + log.Info("withShadowNodes trie commit", "rootNodeHash", rootNodeHash, "newRootHash", newRootHash, "newShadowTreeRoot", newShadowTreeRoot) return rootNodeHash, committed, nil } t.root = newRoot @@ -885,9 +909,9 @@ func (t *Trie) hashRoot() (node, node, error) { // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) defer returnHasherToPool(h) - hashed, cached := h.hash(t.root, true) + newRootHash, cached := h.hash(t.root, true) t.unhashed = 0 - return hashed, cached, nil + return newRootHash, cached, nil } // Reset drops the referenced root node and cleans all internal state. 
@@ -1022,7 +1046,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE } func (t *Trie) resolveShadowNode(epoch types.StateEpoch, origin node, prefix []byte) error { - if !t.useShadowTree { + if !t.withShadowNodes { return nil } @@ -1070,9 +1094,6 @@ func (t *Trie) ShadowHash() (*common.Hash, error) { if t.root == nil { return nil, nil } - if t.sndb == nil { - return nil, errors.New("ShadowHash sndb is nil") - } h := newHasher(true) defer returnHasherToPool(h) return t.shadowHash(t.root, h, nil, t.root.getEpoch()) @@ -1110,16 +1131,13 @@ func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.Sta } } n.shadowNode.ShadowHash = h.shadowNodeHashListToHash(hashList) - // TODO(0xbundler): just save shadowNode, will revert in later cryyl version - encBuf := rlp.NewEncoderBuffer(nil) - n.shadowNode.encode(encBuf) - if err := t.sndb.Put(string(hexToSuffixCompact(prefix)), encBuf.ToBytes()); err != nil { - return nil, err - } return h.shadowBranchNodeToHash(&n.shadowNode), nil case valueNode: return nil, nil case hashNode: + if t.db == nil || t.sndb == nil { + return nil, errors.New("ShadowHash db or sndb is nil") + } // resolve temporary, not add to trie rn, err := t.resolveHash(n, prefix) if err != nil { @@ -1134,6 +1152,48 @@ func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.Sta } } +// commitShadowNodes commit node's shadow node, must call after shadowHash +func (t *Trie) commitShadowNodes(origin node, prefix []byte, epoch types.StateEpoch) error { + if t.sndb == nil { + return errors.New("ShadowHash sndb is nil") + } + switch n := origin.(type) { + case *shortNode: + if err := t.commitShadowNodes(n.Val, append(prefix, n.Key...), epoch); err != nil { + return err + } + return nil + case *fullNode: + epochSelf := n.epoch + epochMap := n.shadowNode.EpochMap + for i := byte(0); i < BranchNodeLength-1; i++ { + child := n.Children[i] + if child == nil { + continue + } + // skip expired node. 
+ if epochSelf >= epochMap[i]+2 { + continue + } + + if err := t.commitShadowNodes(child, append(prefix, i), epochMap[i]); err != nil { + return err + } + } + + encBuf := rlp.NewEncoderBuffer(nil) + n.shadowNode.encode(encBuf) + if err := t.sndb.Put(string(hexToSuffixCompact(prefix)), encBuf.ToBytes()); err != nil { + return err + } + return nil + case valueNode, hashNode: + return nil + default: + return errors.New("cannot get shortNode's child shadow node") + } +} + func (t *Trie) storeRootNode(newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { rn := newRootNode(t.root.getEpoch(), newRootHash, newShadowTreeRoot) if err := t.sndb.Put(ShadowTreeRootNodePath, rn.cachedEnc); err != nil { diff --git a/trie/trie_test.go b/trie/trie_test.go index f66f9213be..3b456aa8fb 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -578,7 +578,7 @@ func getFullNodePrefixKeys(t *Trie, key []byte) [][]byte { // TestTryRevive tests that a trie can be revived from a proof func TestTryRevive(t *testing.T) { - trie, vals := nonRandomTrie(500) + trie, vals := nonRandomTrieWithShadowNodes(500) oriRootHash := trie.Hash() @@ -714,7 +714,7 @@ func TestReviveBadProof(t *testing.T) { // TestReviveBadProofAfterUpdate tests that after reviving a path and // then update the value, old proof should be invalid func TestReviveBadProofAfterUpdate(t *testing.T) { - trie, vals := nonRandomTrie(500) + trie, vals := nonRandomTrieWithShadowNodes(500) for _, kv := range vals { key := kv.k prefixKeys := getFullNodePrefixKeys(trie, key) From e8d914307dbf6b47aea32cd1321a08551b10884c Mon Sep 17 00:00:00 2001 From: asyukii Date: Wed, 17 May 2023 11:31:50 +0800 Subject: [PATCH 41/51] fix: add proof nil check --- accounts/abi/bind/backends/simulated.go | 3 +++ internal/ethapi/api.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 32809d2a46..fddbbfdc53 100644 --- 
a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -664,6 +664,9 @@ func (b *SimulatedBackend) EstimateGasAndReviveState(ctx context.Context, call e if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { isExpiredError = true proof, err := b.pendingState.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) // TODO (asyukii): Is this key correct? + if proof == nil { + continue + } if err != nil { return true, nil, isExpiredError, err } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index cb07d39a7d..6cea6dd395 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -1294,6 +1294,9 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { isExpiredError = true proof, err := stateDb.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) + if proof == nil { + continue + } if err != nil { return true, nil, isExpiredError, err } From c797f85b41534e9edd5a5c605a9595375da79463 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Mon, 15 May 2023 23:36:32 +0800 Subject: [PATCH 42/51] state/database: fix reuse trie with shadow nodes bug; state/state_object: opt snap state query, opt insert dup check; state/statedb: fix copy bugs; trie/trie: opt trie root epoch bug, remove node's epoch methods; --- core/state/database.go | 18 ++++--- core/state/errors.go | 2 +- core/state/state_object.go | 23 ++++++--- core/state/statedb.go | 9 +++- trie/database.go | 15 +----- trie/node.go | 11 +---- trie/trie.go | 98 ++++++++++++++++++++++---------------- 7 files changed, 94 insertions(+), 82 deletions(-) diff --git a/core/state/database.go b/core/state/database.go index 744f605e0a..b7b7c6be20 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -260,17 +260,9 @@ func (db *cachingDB) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, c if db.noTries { 
return trie.NewEmptyTrie(), nil } - if db.storageTrieCache != nil { - if tries, exist := db.storageTrieCache.Get(addrHash); exist { - triesPairs := tries.([3]*triePair) - for _, triePair := range triesPairs { - if triePair != nil && triePair.root == root && triePair.trie.Epoch() == curEpoch { - return triePair.trie.(*trie.SecureTrie).Copy(), nil - } - } - } - } + // StorageTrie with ShadowNode cannot use trie cache, + // because of its sndb need update in every block, just reinit it tr, err := trie.NewSecureWithShadowNodes(curEpoch, root, db.db, sndb) if err != nil { return nil, err @@ -290,6 +282,12 @@ func (db *cachingDB) CacheStorage(addrHash common.Hash, root common.Hash, t Trie if db.storageTrieCache == nil { return } + + // do not cache trie with shadow nodes + if t.Epoch() > types.StateEpoch0 { + return + } + tr := t.(*trie.SecureTrie) if tries, exist := db.storageTrieCache.Get(addrHash); exist { triesArray := tries.([3]*triePair) diff --git a/core/state/errors.go b/core/state/errors.go index 3b941cacd6..7a47f3f81c 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -18,7 +18,7 @@ type ExpiredStateError struct { reason string } -func NewPlainExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { +func NewSnapExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { return &ExpiredStateError{ Addr: addr, Key: key, diff --git a/core/state/state_object.go b/core/state/state_object.go index 5b137143ba..d30f3aa810 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -319,13 +319,24 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha if metrics.EnabledExpensive { s.db.SnapshotStorageReads += time.Since(start) } + + // snapshot val encode is different from trie, so handle independent if err == nil { - if snapVal, err := snapshot.ParseSnapValFromBytes(enc); err == nil { - if s.db.enableAccStateEpoch(true, 
s.address) && types.EpochExpired(snapVal.Epoch, s.targetEpoch) { - return common.Hash{}, NewPlainExpiredStateError(s.address, key, snapVal.Epoch) + var value common.Hash + if len(enc) > 0 { + sv, err := snapshot.ParseSnapValFromBytes(enc) + if err != nil { + s.setError(err) + } + if err == nil && s.db.enableAccStateEpoch(true, s.address) && + types.EpochExpired(sv.Epoch, s.targetEpoch) { + return common.Hash{}, NewSnapExpiredStateError(s.address, key, sv.Epoch) } - return snapVal.Val, nil + value.SetBytes(sv.Val.Bytes()) } + + s.setOriginStorage(key, value) + return value, nil } } @@ -378,7 +389,7 @@ func (s *StateObject) SetState(db Database, key, value common.Hash) error { return nil } // when state insert, check if valid to insert new state - if s.db.enableAccStateEpoch(true, s.address) && prev != (common.Hash{}) { + if s.db.enableAccStateEpoch(true, s.address) && prev == (common.Hash{}) { _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { @@ -606,7 +617,7 @@ func (s *StateObject) CommitTrie(db Database) (int, error) { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } root, committed, err := s.trie.Commit(nil) - log.Info("obj CommitTrie", "addr", s.address, "root", root) + log.Info("obj CommitTrie", "addr", s.address, "root", root, "err", err) if err == nil { s.data.Root = root } diff --git a/core/state/statedb.go b/core/state/statedb.go index 86558b2e10..a40325328e 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -886,6 +886,8 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB { preimages: make(map[common.Hash][]byte, len(s.preimages)), journal: newJournal(), hasher: crypto.NewKeccakState(), + targetEpoch: s.targetEpoch, + targetBlk: s.targetBlk, } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { @@ -945,6 +947,11 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB { // know that they 
need to explicitly terminate an active copy). state.prefetcher = state.prefetcher.copy() } + + if s.shadowNodeDB != nil { + state.shadowNodeDB = s.shadowNodeDB + } + if s.snaps != nil { // In order for the miner to be able to use and make additions // to the snapshot tree, we need to copy that aswell. @@ -1645,12 +1652,12 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } + log.Info("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) if s.shadowNodeDB != nil { if err := s.shadowNodeDB.Commit(s.targetBlk, root); err != nil { return common.Hash{}, nil, err } } - log.Info("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) return root, diffLayer, nil } diff --git a/trie/database.go b/trie/database.go index da2aad5108..52469352cf 100644 --- a/trie/database.go +++ b/trie/database.go @@ -28,7 +28,6 @@ import ( "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" @@ -102,10 +101,8 @@ type Database struct { // in the same cache fields). 
type rawNode []byte -func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } -func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawNode) setEpoch(epcoh types.StateEpoch) { panic("this should never end up in a live trie") } -func (n rawNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } +func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } func (n rawNode) EncodeRLP(w io.Writer) error { _, err := w.Write(n) @@ -123,10 +120,6 @@ type rawFullNode [17]node func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawFullNode) setEpoch(epcoh types.StateEpoch) { - panic("this should never end up in a live trie") -} -func (n rawFullNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } func (n rawFullNode) nodeType() int { return rawFullNodeType @@ -148,10 +141,6 @@ type rawShortNode struct { func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } -func (n rawShortNode) setEpoch(epoch types.StateEpoch) { - panic("this should never end up in a live trie") -} -func (n rawShortNode) getEpoch() types.StateEpoch { panic("this should never end up in a live trie") } func (n rawShortNode) nodeType() int { return rawShortNodeType diff --git a/trie/node.go b/trie/node.go index 4ead424510..326b6e9fa7 100644 --- a/trie/node.go +++ b/trie/node.go @@ -47,8 +47,6 @@ type node interface { encode(w rlp.EncoderBuffer) fstring(string) string nodeType() int - setEpoch(epoch types.StateEpoch) - getEpoch() 
types.StateEpoch } type ( @@ -135,13 +133,8 @@ func (n valueNode) String() string { return n.fstring("") } func (n *fullNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } func (n *shortNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } -func (n hashNode) setEpoch(epoch types.StateEpoch) {} -func (n valueNode) setEpoch(epoch types.StateEpoch) {} - -func (n *fullNode) getEpoch() types.StateEpoch { return n.epoch } -func (n *shortNode) getEpoch() types.StateEpoch { return n.epoch } -func (n hashNode) getEpoch() types.StateEpoch { return 0 } -func (n valueNode) getEpoch() types.StateEpoch { return 0 } +func (n *fullNode) getEpoch() types.StateEpoch { return n.epoch } +func (n *shortNode) getEpoch() types.StateEpoch { return n.epoch } func (n *fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) diff --git a/trie/trie.go b/trie/trie.go index b1b41713c2..4d029849be 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -68,12 +68,14 @@ type Trie struct { // actually unhashed nodes unhashed int - // fields for shadow tree and rootNode + // fields for shadow tree & state epoch withShadowNodes bool currentEpoch types.StateEpoch - shadowTreeRoot common.Hash - rootEpoch types.StateEpoch - originTrieRoot common.Hash + + // fields for rootNode + shadowTreeRoot common.Hash + rootEpoch types.StateEpoch + trieRoot common.Hash } // newFlag returns the cache flag value for a newly created node. 
@@ -126,7 +128,7 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa withShadowNodes: useShadowTree, shadowTreeRoot: rootNode.ShadowTreeRoot, rootEpoch: rootNode.Epoch, - originTrieRoot: rootNode.TrieRoot, + trieRoot: rootNode.TrieRoot, } if rootNode.TrieRoot != (common.Hash{}) && rootNode.TrieRoot != emptyRoot { root, err := trie.resolveHash(rootNode.TrieRoot[:], nil) @@ -164,11 +166,7 @@ func (t *Trie) TryGet(key []byte) (value []byte, err error) { var newroot node var didResolve bool if t.withShadowNodes { - var nextEpoch types.StateEpoch - if t.root != nil { - nextEpoch = t.root.getEpoch() - } - value, newroot, didResolve, err = t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, nextEpoch, false) + value, newroot, didResolve, err = t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), false) } else { value, newroot, didResolve, err = t.tryGet(t.root, keybytesToHex(key), 0) } @@ -180,11 +178,7 @@ func (t *Trie) TryGet(key []byte) (value []byte, err error) { } func (t *Trie) TryGetAndUpdateEpoch(key []byte) ([]byte, error) { - var nextEpoch types.StateEpoch - if t.root != nil { - nextEpoch = t.root.getEpoch() - } - value, newroot, didResolve, err := t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, nextEpoch, true) + value, newroot, didResolve, err := t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), true) if err == nil && didResolve { t.root = newroot @@ -240,7 +234,7 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S return nil, n, false, nil } // node is expired - if expired, err := t.nodeExpired(n, key[:pos]); expired { + if expired, err := t.nodeExpired(n, epoch, key[:pos]); expired { return nil, n, false, err } @@ -257,7 +251,7 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S return value, n, didResolve, err case *fullNode: // full node is expired - if expired, err := t.nodeExpired(n, key[:pos]); expired { + if expired, err := 
t.nodeExpired(n, epoch, key[:pos]); expired { return nil, n, false, err } // child node is expired @@ -401,18 +395,15 @@ func (t *Trie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { func (t *Trie) TryUpdate(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) - var nextEpoch types.StateEpoch - if t.root != nil { - nextEpoch = t.root.getEpoch() - } + rootEpoch := t.getRootEpoch() if len(value) != 0 { - _, n, err := t.insert(t.root, nil, k, valueNode(value), nextEpoch) + _, n, err := t.insert(t.root, nil, k, valueNode(value), rootEpoch) if err != nil { return err } t.root = n } else { - _, n, err := t.delete(t.root, nil, k, nextEpoch) + _, n, err := t.delete(t.root, nil, k, rootEpoch) if err != nil { return err } @@ -431,7 +422,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE switch n := n.(type) { case *shortNode: if t.withShadowNodes { - if expired, err := t.nodeExpired(n, prefix); expired { + if expired, err := t.nodeExpired(n, epoch, prefix); expired { return false, n, err } } @@ -475,7 +466,7 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE case *fullNode: if t.withShadowNodes { // this full node is expired, return err - if expired, err := t.nodeExpired(n, prefix); expired { + if expired, err := t.nodeExpired(n, epoch, prefix); expired { return false, n, err } // else, set its epoch to current epoch. 
@@ -537,11 +528,7 @@ func (t *Trie) Delete(key []byte) { func (t *Trie) TryDelete(key []byte) error { t.unhashed++ k := keybytesToHex(key) - var nextEpoch types.StateEpoch - if t.root != nil { - nextEpoch = t.root.getEpoch() - } - _, n, err := t.delete(t.root, nil, k, nextEpoch) + _, n, err := t.delete(t.root, nil, k, t.getRootEpoch()) if err != nil { return err } @@ -556,7 +543,7 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, switch n := n.(type) { case *shortNode: if t.withShadowNodes { - if expired, err := t.nodeExpired(n, prefix); expired { + if expired, err := t.nodeExpired(n, epoch, prefix); expired { return false, n, err } n.setEpoch(t.currentEpoch) @@ -592,7 +579,7 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, case *fullNode: if t.withShadowNodes { // this full node is expired, return err - if expired, err := t.nodeExpired(n, prefix); expired { + if expired, err := t.nodeExpired(n, epoch, prefix); expired { return false, n, err } // else, set its epoch to current epoch. 
@@ -787,12 +774,12 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } -func (t *Trie) nodeExpired(n node, prefix []byte) (bool, error) { - if types.EpochExpired(n.getEpoch(), t.currentEpoch) { +func (t *Trie) nodeExpired(n node, epoch types.StateEpoch, prefix []byte) (bool, error) { + if types.EpochExpired(epoch, t.currentEpoch) { return true, &ExpiredNodeError{ ExpiredNode: n, Path: prefix, - Epoch: n.getEpoch(), + Epoch: epoch, } } return false, nil @@ -814,7 +801,7 @@ func (t *Trie) Hash() common.Hash { if shadowTreeRoot != nil { newShadowTreeRoot = *shadowTreeRoot } - rn := newRootNode(t.root.getEpoch(), newRootHash, newShadowTreeRoot) + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) return rn.cachedHash } return newRootHash @@ -845,7 +832,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { newShadowTreeRoot = *shadowTreeRoot } // commit shadow nodes after ShadowHash in Commit - if err = t.commitShadowNodes(t.root, nil, t.root.getEpoch()); err != nil { + if err = t.commitShadowNodes(t.root, nil, t.getRootEpoch()); err != nil { return common.Hash{}, 0, err } } @@ -893,6 +880,10 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if err != nil { return common.Hash{}, 0, err } + // update root & root node + t.rootEpoch = t.getRootEpoch() + t.shadowTreeRoot = newShadowTreeRoot + t.trieRoot = newRootHash t.root = newRoot log.Info("withShadowNodes trie commit", "rootNodeHash", rootNodeHash, "newRootHash", newRootHash, "newShadowTreeRoot", newShadowTreeRoot) return rootNodeHash, committed, nil @@ -941,7 +932,7 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err // Revive trie with each proof nub for _, nub := range proof { path := []byte{} - rootExpired, _ := t.nodeExpired(t.root, nil) + rootExpired, _ := t.nodeExpired(t.root, t.getRootEpoch(), nil) newNode, didRevive, err := 
t.tryRevive(t.root, nub.n1PrefixKey, *nub, t.currentEpoch, path, rootExpired) if didRevive && err == nil { successNubs = append(successNubs, nub) @@ -980,9 +971,9 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE return nil, false, fmt.Errorf("hash values does not match") } - nub.n1.setEpoch(t.currentEpoch) + tryUpdateNodeEpoch(nub.n1, t.currentEpoch) if nub.n2 != nil { - nub.n2.setEpoch(t.currentEpoch) + tryUpdateNodeEpoch(nub.n2, t.currentEpoch) if n1, ok := nub.n1.(*shortNode); ok { // n2 can only be followed by a short node n1.Val = nub.n2 } else { @@ -1096,7 +1087,7 @@ func (t *Trie) ShadowHash() (*common.Hash, error) { } h := newHasher(true) defer returnHasherToPool(h) - return t.shadowHash(t.root, h, nil, t.root.getEpoch()) + return t.shadowHash(t.root, h, nil, t.getRootEpoch()) } // shadowHash calculate node's shadow node hash, recalculate needn't a copy @@ -1195,13 +1186,27 @@ func (t *Trie) commitShadowNodes(origin node, prefix []byte, epoch types.StateEp } func (t *Trie) storeRootNode(newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { - rn := newRootNode(t.root.getEpoch(), newRootHash, newShadowTreeRoot) + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) if err := t.sndb.Put(ShadowTreeRootNodePath, rn.cachedEnc); err != nil { return common.Hash{}, err } return rn.cachedHash, nil } +// getRootEpoch parse root and resolve its epoch, when root is hashNode, +// the epoch should be t.rootEpoch, not default types.StateEpoch0 +func (t *Trie) getRootEpoch() types.StateEpoch { + ret := t.rootEpoch + switch n := t.root.(type) { + case *shortNode: + ret = n.getEpoch() + case *fullNode: + ret = n.getEpoch() + } + + return ret +} + func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { expectHash := common.BytesToHash(root[:]) @@ -1250,3 +1255,12 @@ func hexToSuffixCompact(hex []byte) []byte { decodeNibbles(hex, buf[:len(buf)-1]) return buf } + +func 
tryUpdateNodeEpoch(origin node, epoch types.StateEpoch) { + switch n := origin.(type) { + case *shortNode: + n.setEpoch(epoch) + case *fullNode: + n.setEpoch(epoch) + } +} From 06b247d25f7b885fbff306d97af39ada7626849d Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Wed, 17 May 2023 14:36:38 +0800 Subject: [PATCH 43/51] signer: fix txpool, codec bug with new ReviveTx & BEP215Signer; state/state_object: recall dirty revive trie for accuracy expired info; vm/evm: fix sstore calculate expired state gas bug; ethapi/api: fix EstimateGasAndReviveState resolve witness issues; trie/trie: opt shadow hash calculation, recalloect branch node epoch map; trie/trie: keep as hashNode when meet expired node, opt expired node checking; trie/trie: fix expired check bug when child is nil; --- accounts/abi/bind/backends/simulated.go | 3 + accounts/external/backend.go | 2 +- consensus/parlia/parlia.go | 2 +- core/state/errors.go | 2 +- core/state/state_object.go | 12 ++- core/state/statedb.go | 4 +- core/tx_pool.go | 5 +- core/types/receipt.go | 7 +- core/types/transaction_signing.go | 5 + core/vm/evm.go | 8 -- core/vm/gas_table.go | 10 ++ core/vm/operations_acl.go | 6 ++ internal/ethapi/api.go | 121 ++++++++++++---------- trie/errors.go | 12 ++- trie/hasher.go | 11 -- trie/node.go | 5 +- trie/shadow_node.go | 4 +- trie/trie.go | 132 ++++++++++++------------ trie/trie_test.go | 62 ++++++++++- 19 files changed, 251 insertions(+), 162 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index fddbbfdc53..0d195851bb 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -670,6 +670,9 @@ func (b *SimulatedBackend) EstimateGasAndReviveState(ctx context.Context, call e if err != nil { return true, nil, isExpiredError, err } + if len(proof) == 0 { + continue + } addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], 
types.MPTProof{ RootKeyHex: stateErr.Path, Proof: proof, diff --git a/accounts/external/backend.go b/accounts/external/backend.go index 07062495de..af280fda8a 100644 --- a/accounts/external/backend.go +++ b/accounts/external/backend.go @@ -216,7 +216,7 @@ func (api *ExternalSigner) SignTx(account accounts.Account, tx *types.Transactio From: common.NewMixedcaseAddress(account.Address), } switch tx.Type() { - case types.LegacyTxType, types.AccessListTxType: + case types.LegacyTxType, types.ReviveStateTxType, types.AccessListTxType: args.GasPrice = (*hexutil.Big)(tx.GasPrice()) case types.DynamicFeeTxType: args.MaxFeePerGas = (*hexutil.Big)(tx.GasFeeCap()) diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index 161141bca5..87bbcf3bf3 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -262,7 +262,7 @@ func New( signatures: signatures, validatorSetABI: vABI, slashABI: sABI, - signer: types.NewEIP155Signer(chainConfig.ChainID), + signer: types.NewBEP215Signer(chainConfig.ChainID), } return c diff --git a/core/state/errors.go b/core/state/errors.go index 7a47f3f81c..37820c0d9e 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -52,5 +52,5 @@ func NewInsertExpiredStateError(addr common.Address, key common.Hash, err *trie. 
} func (e *ExpiredStateError) Error() string { - return fmt.Sprintf("Access expired state, addr: %v, key: %v, epoch: %v, reason: %v", e.Addr, e.Key, e.Epoch, e.reason) + return fmt.Sprintf("Access expired state, addr: %v, path: %v, key: %v, epoch: %v, reason: %v", e.Addr, e.Path, e.Key, e.Epoch, e.reason) } diff --git a/core/state/state_object.go b/core/state/state_object.go index d30f3aa810..f1f29c97c6 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -330,6 +330,11 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha } if err == nil && s.db.enableAccStateEpoch(true, s.address) && types.EpochExpired(sv.Epoch, s.targetEpoch) { + // query from dirty revive trie, got the newest expired info + _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return common.Hash{}, NewExpiredStateError(s.address, key, enErr) + } return common.Hash{}, NewSnapExpiredStateError(s.address, key, sv.Epoch) } value.SetBytes(sv.Val.Bytes()) @@ -351,6 +356,11 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha } if err != nil { if enErr, ok := err.(*trie.ExpiredNodeError); ok { + // query from dirty revive trie, got the newest expired info + _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return common.Hash{}, NewExpiredStateError(s.address, key, enErr) + } return common.Hash{}, NewExpiredStateError(s.address, key, enErr) } s.setError(err) @@ -617,7 +627,7 @@ func (s *StateObject) CommitTrie(db Database) (int, error) { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } root, committed, err := s.trie.Commit(nil) - log.Info("obj CommitTrie", "addr", s.address, "root", root, "err", err) + log.Debug("obj CommitTrie", "addr", s.address, "root", root, "err", err) if err == nil { s.data.Root = root } diff --git a/core/state/statedb.go b/core/state/statedb.go 
index a40325328e..ae4cd360be 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -162,7 +162,7 @@ func NewWithStateEpoch(config *params.ChainConfig, targetBlock *big.Int, root co return nil, err } - log.Info("NewWithStateEpoch", "targetBlock", targetBlock, "targetEpoch", targetEpoch, "root", root) + log.Debug("NewWithStateEpoch", "targetBlock", targetBlock, "targetEpoch", targetEpoch, "root", root) // init target block and shadowNodeRW stateDB.targetBlk = targetBlock stateDB.shadowNodeDB, err = trie.NewShadowNodeDatabase(sntree, targetBlock, root) @@ -1652,7 +1652,7 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } - log.Info("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) + log.Debug("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) if s.shadowNodeDB != nil { if err := s.shadowNodeDB.Commit(s.targetBlk, root); err != nil { return common.Hash{}, nil, err diff --git a/core/tx_pool.go b/core/tx_pool.go index 896eb9ad41..13c6602f8f 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -642,7 +642,10 @@ func (pool *TxPool) local() map[common.Address]types.Transactions { func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { // Accept only legacy transactions until EIP-2718/2930 activates. if !pool.eip2718 && tx.Type() != types.LegacyTxType { - return ErrTxTypeNotSupported + // If isElwood, accept types.ReviveStateTxType + if !(pool.isElwood || tx.Type() == types.ReviveStateTxType) { + return ErrTxTypeNotSupported + } } // Reject dynamic fee transactions until EIP-1559 activates. 
if !pool.eip1559 && tx.Type() == types.DynamicFeeTxType { diff --git a/core/types/receipt.go b/core/types/receipt.go index 90c9dbcd02..27b81aa95e 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -198,7 +198,7 @@ func (r *Receipt) DecodeRLP(s *rlp.Stream) error { return errEmptyTypedReceipt } r.Type = b[0] - if r.Type == AccessListTxType || r.Type == DynamicFeeTxType { + if r.Type == AccessListTxType || r.Type == DynamicFeeTxType || r.Type == ReviveStateTxType { var dec receiptRLP if err := rlp.DecodeBytes(b[1:], &dec); err != nil { return err @@ -234,7 +234,7 @@ func (r *Receipt) decodeTyped(b []byte) error { return errEmptyTypedReceipt } switch b[0] { - case DynamicFeeTxType, AccessListTxType: + case DynamicFeeTxType, AccessListTxType, ReviveStateTxType: var data receiptRLP err := rlp.DecodeBytes(b[1:], &data) if err != nil { @@ -407,6 +407,9 @@ func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) { case DynamicFeeTxType: w.WriteByte(DynamicFeeTxType) rlp.Encode(w, data) + case ReviveStateTxType: + w.WriteByte(ReviveStateTxType) + rlp.Encode(w, data) default: // For unsupported types, write nothing. Since this is for // DeriveSha, the error will be caught matching the derived hash diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 4499c7438f..184b91c063 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -40,6 +40,8 @@ type sigCache struct { func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { var signer Signer switch { + case config.IsClaude(blockNumber): + signer = NewBEP215Signer(config.ChainID) case config.IsLondon(blockNumber): signer = NewLondonSigner(config.ChainID) case config.IsBerlin(blockNumber): @@ -63,6 +65,9 @@ func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { // have the current block number available, use MakeSigner instead. 
func LatestSigner(config *params.ChainConfig) Signer { if config.ChainID != nil { + if config.ClaudeBlock != nil { + return NewBEP215Signer(config.ChainID) + } if config.LondonBlock != nil { return NewLondonSigner(config.ChainID) } diff --git a/core/vm/evm.go b/core/vm/evm.go index f59d23526b..3408ba40ab 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -22,8 +22,6 @@ import ( "sync/atomic" "time" - "github.com/ethereum/go-ethereum/log" - "github.com/holiman/uint256" "github.com/ethereum/go-ethereum/common" @@ -264,13 +262,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // TODO: consider clearing up unused snapshots: //} else { // evm.StateDB.DiscardSnapshot(snapshot) - - errors := evm.Errors() - for _, e := range errors { - log.Error("call err", "addr", addr, "from", caller.Address(), "err", e.Error()) - } } - return ret, gas, err } diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 083f8c7f7b..89c76e5cf7 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -19,6 +19,8 @@ package vm import ( "errors" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/params" @@ -140,6 +142,10 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi } original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } return 0, err } if original == current { @@ -201,6 +207,10 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } return 
0, err } if original == current { diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go index 3907f8b7b3..7570476133 100644 --- a/core/vm/operations_acl.go +++ b/core/vm/operations_acl.go @@ -19,6 +19,8 @@ package vm import ( "errors" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/params" @@ -61,6 +63,10 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } return 0, err } if original == current { diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 6cea6dd395..2b2a9c16c4 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -957,15 +957,15 @@ func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash return result, nil } -func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, []*vm.EVMError, error) { +func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, []*vm.EVMError, *state.StateDB, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) state, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { - return nil, nil, err + return nil, nil, nil, err } if err := overrides.Apply(state); err != nil { - return nil, nil, err + return nil, nil, nil, err } // Setup context so it may be cancelled the call has completed // or, in case of 
unmetered gas, setup a context with a timeout. @@ -982,11 +982,11 @@ func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNr // Get a new instance of the EVM. msg, err := args.ToMessage(globalGasCap, header.BaseFee) if err != nil { - return nil, nil, err + return nil, nil, nil, err } evm, vmError, err := b.GetEVM(ctx, msg, state, header, &vm.Config{NoBaseFee: true}) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // Wait for the context to be done and cancel the evm. Even if the // EVM has finished, cancelling may be done (repeatedly) @@ -999,17 +999,17 @@ func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNr gp := new(core.GasPool).AddGas(math.MaxUint64) result, err := core.ApplyMessage(evm, msg, gp) if err := vmError(); err != nil { - return nil, evm.ErrorCollection, err + return nil, evm.ErrorCollection, nil, err } // If the timer caused an abort, return an appropriate error message if evm.Cancelled() { - return nil, evm.ErrorCollection, fmt.Errorf("execution aborted (timeout = %v)", timeout) + return nil, evm.ErrorCollection, nil, fmt.Errorf("execution aborted (timeout = %v)", timeout) } if err != nil { - return result, evm.ErrorCollection, fmt.Errorf("err: %w (supplied gas %d)", err, msg.Gas()) + return result, evm.ErrorCollection, nil, fmt.Errorf("err: %w (supplied gas %d)", err, msg.Gas()) } - return result, evm.ErrorCollection, nil + return result, evm.ErrorCollection, state, nil } func newRevertError(result *core.ExecutionResult) *revertError { @@ -1196,20 +1196,18 @@ type EstimateGasAndReviveStateResult struct { ReviveWitness []types.ReviveWitness `json:"reviveWitness"` } -// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// DoEstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the // given transaction against the current pending block. 
It also attempts to revive expired // storage trie and returns the revive witness list. func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, gasCap uint64) (*EstimateGasAndReviveStateResult, error) { var result EstimateGasAndReviveStateResult - var stateDb *state.StateDB // Initialize witnessList - var witnessList []types.ReviveWitness - if args.WitnessList != nil { - witnessList = *args.WitnessList + if args.WitnessList == nil { + args.WitnessList = (*types.WitnessList)(&[]types.ReviveWitness{}) } - witLen := len(witnessList) + witLen := len(*args.WitnessList) // Binary search the gas requirement, as it may be higher than the amount used var ( @@ -1281,24 +1279,33 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio cap = hi // Create a helper to check if a gas allowance results in an executable transaction - executable := func(gas uint64, witnessList []types.ReviveWitness) (bool, *core.ExecutionResult, bool, error) { + expiedNodeCache := make(map[common.Address]map[string]bool) + executable := func(gas uint64) (bool, *core.ExecutionResult, bool, error) { args.Gas = (*hexutil.Uint64)(&gas) - result, evmErrors, err := DoCallExpired(ctx, b, args, blockNrOrHash, nil, 0, gasCap) // TODO (asyukii): Use a different call function to return EVM errors + result, evmErrors, callState, err := DoCallExpired(ctx, b, args, blockNrOrHash, nil, 0, gasCap) // TODO (asyukii): Use a different call function to return EVM errors // Create MPTProof - isExpiredError := false + resolveWitness := false if len(evmErrors) > 0 { addressToProofMap := make(map[common.Address][]types.MPTProof) for _, evmErr := range evmErrors { if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { - isExpiredError = true - proof, err := stateDb.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) - if proof == nil { - continue + if _, ec := expiedNodeCache[stateErr.Addr]; !ec { + 
expiedNodeCache[stateErr.Addr] = make(map[string]bool) } + if expiedNodeCache[stateErr.Addr][string(stateErr.Path)] { + // revive not works, just return + return true, nil, resolveWitness, stateErr + } + + expiedNodeCache[stateErr.Addr][string(stateErr.Path)] = true + proof, err := callState.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) if err != nil { - return true, nil, isExpiredError, err + return true, nil, false, err + } + if len(proof) == 0 { + continue } addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], types.MPTProof{ RootKeyHex: stateErr.Path, @@ -1317,7 +1324,7 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio // Encode StorageTrieWitness enc, err := rlp.EncodeToBytes(storageTrieWitness) if err != nil { - return true, nil, isExpiredError, err + return true, nil, resolveWitness, err } // Create a ReviveWitness reviveWitness := types.ReviveWitness{ @@ -1325,29 +1332,30 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio Data: enc, } // Append to witness list - witnessList = append(witnessList, reviveWitness) + *args.WitnessList = append(*args.WitnessList, reviveWitness) + resolveWitness = true } } if err != nil { if errors.Is(err, core.ErrIntrinsicGas) { - return true, nil, isExpiredError, nil // Special case, raise gas limit + return true, nil, resolveWitness, nil // Special case, raise gas limit } - return true, nil, isExpiredError, err // Bail out + return true, nil, resolveWitness, err // Bail out } - return result.Failed(), result, isExpiredError, nil + return result.Failed(), result, resolveWitness, nil } // Execute the binary search and hone in on an executable gas limit for lo+1 < hi { mid := (hi + lo) / 2 - failed, _, isExpiredError, err := executable(mid, witnessList) + failed, _, resolveWitness, err := executable(mid) - if isExpiredError { - if witLen == len(witnessList) { + if resolveWitness { + if witLen == len(*args.WitnessList) { // If 
witnessList is not updated, it means that the proofs are not // sufficient to revive the state. return nil, fmt.Errorf("cannot generate enough proofs to revive the state") } - witLen = len(witnessList) + witLen = len(*args.WitnessList) continue } @@ -1365,7 +1373,7 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio } // Reject the transaction as invalid if it still fails at the highest allowance if hi == cap { - failed, result, _, err := executable(hi, witnessList) + failed, result, _, err := executable(hi) if err != nil { return nil, err } @@ -1382,7 +1390,7 @@ func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args Transactio } result = EstimateGasAndReviveStateResult{ Hex: hexutil.Uint64(hi), - ReviveWitness: witnessList, + ReviveWitness: *args.WitnessList, } return &result, nil } @@ -1751,25 +1759,26 @@ func (s *PublicBlockChainAPI) rpcMarshalBlock(ctx context.Context, b *types.Bloc // RPCTransaction represents a transaction that will serialize to the RPC representation of a transaction type RPCTransaction struct { - BlockHash *common.Hash `json:"blockHash"` - BlockNumber *hexutil.Big `json:"blockNumber"` - From common.Address `json:"from"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice *hexutil.Big `json:"gasPrice"` - GasFeeCap *hexutil.Big `json:"maxFeePerGas,omitempty"` - GasTipCap *hexutil.Big `json:"maxPriorityFeePerGas,omitempty"` - Hash common.Hash `json:"hash"` - Input hexutil.Bytes `json:"input"` - Nonce hexutil.Uint64 `json:"nonce"` - To *common.Address `json:"to"` - TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` - Value *hexutil.Big `json:"value"` - Type hexutil.Uint64 `json:"type"` - Accesses *types.AccessList `json:"accessList,omitempty"` - ChainID *hexutil.Big `json:"chainId,omitempty"` - V *hexutil.Big `json:"v"` - R *hexutil.Big `json:"r"` - S *hexutil.Big `json:"s"` + BlockHash *common.Hash `json:"blockHash"` + BlockNumber *hexutil.Big `json:"blockNumber"` + From common.Address 
`json:"from"` + Gas hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + GasFeeCap *hexutil.Big `json:"maxFeePerGas,omitempty"` + GasTipCap *hexutil.Big `json:"maxPriorityFeePerGas,omitempty"` + Hash common.Hash `json:"hash"` + Input hexutil.Bytes `json:"input"` + Nonce hexutil.Uint64 `json:"nonce"` + To *common.Address `json:"to"` + TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` + Value *hexutil.Big `json:"value"` + Type hexutil.Uint64 `json:"type"` + Accesses *types.AccessList `json:"accessList,omitempty"` + Witness *types.WitnessList `json:"witnessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` } // newRPCTransaction returns a transaction that will serialize to the RPC @@ -1798,6 +1807,10 @@ func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber result.TransactionIndex = (*hexutil.Uint64)(&index) } switch tx.Type() { + case types.ReviveStateTxType: + wl := tx.WitnessList() + result.Witness = &wl + result.ChainID = (*hexutil.Big)(tx.ChainId()) case types.AccessListTxType: al := tx.AccessList() result.Accesses = &al @@ -2132,7 +2145,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceiptsByBlockNumber(ctx conte tx := txs[idx] var signer types.Signer = types.FrontierSigner{} if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) + signer = types.NewBEP215Signer(tx.ChainId()) } from, _ := types.Sender(signer, tx) diff --git a/trie/errors.go b/trie/errors.go index c0121dee6a..e11ba816b8 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -37,9 +37,15 @@ func (err *MissingNodeError) Error() string { } type ExpiredNodeError struct { - ExpiredNode node // node of the expired node - Path []byte // hex-encoded path to the expired node - Epoch types.StateEpoch + Path []byte // hex-encoded path to the expired node + Epoch types.StateEpoch +} + +func NewExpiredNodeError(path []byte, epoch 
types.StateEpoch) error { + return &ExpiredNodeError{ + Path: path, + Epoch: epoch, + } } func (err *ExpiredNodeError) Error() string { diff --git a/trie/hasher.go b/trie/hasher.go index 589f4417a5..edad7f7aa7 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -141,17 +141,6 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached return collapsed, cached } -// shadowExtendNodeToHash hash shadowExtendNode -func (h *hasher) shadowExtendNodeToHash(n *shadowExtensionNode) *common.Hash { - if n.ShadowHash == nil { - return nil - } - - n.encode(h.encbuf) - enc := h.encodedBytes() - return h.hashCommon(enc) -} - // shadowFullNodeToHash hash shadowBranchNode func (h *hasher) shadowBranchNodeToHash(n *shadowBranchNode) *common.Hash { n.encode(h.encbuf) diff --git a/trie/node.go b/trie/node.go index 326b6e9fa7..e1849ea246 100644 --- a/trie/node.go +++ b/trie/node.go @@ -99,9 +99,8 @@ func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.Sta childEpoch := n.GetChildEpoch(index) if types.EpochExpired(childEpoch, currentEpoch) { return true, &ExpiredNodeError{ - ExpiredNode: n.Children[index], - Path: prefix, - Epoch: childEpoch, + Path: prefix, + Epoch: childEpoch, } } return false, nil diff --git a/trie/shadow_node.go b/trie/shadow_node.go index 04b761acc8..ed868b19e8 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -238,13 +238,13 @@ func NewShadowNodeDatabase(tree *ShadowNodeSnapTree, number *big.Int, blockRoot // try using default snap if snap = tree.Snapshot(emptyRoot); snap == nil { // open read only history - log.Info("NewShadowNodeDatabase use RO database", "number", number, "root", blockRoot) + log.Debug("NewShadowNodeDatabase use RO database", "number", number, "root", blockRoot) return &ShadowNodeStorageRO{ diskdb: tree.DB(), number: number, }, nil } - log.Info("NewShadowNodeDatabase use default database", "number", number, "root", blockRoot) + log.Debug("NewShadowNodeDatabase use default database", 
"number", number, "root", blockRoot) } return &ShadowNodeStorageRW{ snap: snap, diff --git a/trie/trie.go b/trie/trie.go index 4d029849be..084b005ede 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -118,7 +118,7 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa // only enable after first state expiry's hard fork if curEpoch > types.StateEpoch0 { useShadowTree = true - log.Info("withShadowNodes trie open", "rootNodeHash", rootNode.cachedHash, "RootHash", rootNode.TrieRoot, "ShadowTreeRoot", rootNode.ShadowTreeRoot) + log.Debug("withShadowNodes trie open", "rootNodeHash", rootNode.cachedHash, "RootHash", rootNode.TrieRoot, "ShadowTreeRoot", rootNode.ShadowTreeRoot) } trie := &Trie{ @@ -130,7 +130,13 @@ func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Databa rootEpoch: rootNode.Epoch, trieRoot: rootNode.TrieRoot, } + if rootNode.TrieRoot != (common.Hash{}) && rootNode.TrieRoot != emptyRoot { + // check root if expired, if expired just set as hashNode + if types.EpochExpired(rootNode.Epoch, curEpoch) { + trie.root = hashNode(rootNode.TrieRoot.Bytes()) + return trie, nil + } root, err := trie.resolveHash(rootNode.TrieRoot[:], nil) if err != nil { return nil, err @@ -223,6 +229,10 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.StateEpoch, updateEpoch bool) (value []byte, newnode node, didResolve bool, err error) { + if t.epochExpired(origNode, epoch) { + return nil, nil, false, NewExpiredNodeError(key[:pos], epoch) + } + switch n := (origNode).(type) { case nil: return nil, nil, false, nil @@ -233,10 +243,6 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S // key not found in trie return nil, n, false, nil } - // node is expired - if expired, err := t.nodeExpired(n, epoch, key[:pos]); expired { - return nil, n, false, err - } if updateEpoch { 
n.setEpoch(t.currentEpoch) @@ -250,16 +256,13 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S } return value, n, didResolve, err case *fullNode: - // full node is expired - if expired, err := t.nodeExpired(n, epoch, key[:pos]); expired { - return nil, n, false, err - } - // child node is expired - if expired, err := n.ChildExpired(key[:pos+1], int(key[pos]), t.currentEpoch); expired { - return nil, n, false, err - } - if updateEpoch { + // child node is expired + if n.Children[key[pos]] != nil { + if expired, err := n.ChildExpired(key[:pos+1], int(key[pos]), t.currentEpoch); expired { + return nil, n, false, err + } + } n.setEpoch(t.currentEpoch) n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, t.currentEpoch, true) @@ -413,6 +416,10 @@ func (t *Trie) TryUpdate(key, value []byte) error { } func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch) + } + if len(key) == 0 { if v, ok := n.(valueNode); ok { return !bytes.Equal(v, value.(valueNode)), value, nil @@ -421,11 +428,6 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE } switch n := n.(type) { case *shortNode: - if t.withShadowNodes { - if expired, err := t.nodeExpired(n, epoch, prefix); expired { - return false, n, err - } - } matchlen := prefixLen(key, n.Key) // If the whole key matches, keep this short node as is // and only update the value. @@ -465,10 +467,6 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE case *fullNode: if t.withShadowNodes { - // this full node is expired, return err - if expired, err := t.nodeExpired(n, epoch, prefix); expired { - return false, n, err - } // else, set its epoch to current epoch. 
n.setEpoch(t.currentEpoch) // if inserting a new node to this full node, there is no need to check whether this child is expired. @@ -540,14 +538,12 @@ func (t *Trie) TryDelete(key []byte) error { // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch) + } + switch n := n.(type) { case *shortNode: - if t.withShadowNodes { - if expired, err := t.nodeExpired(n, epoch, prefix); expired { - return false, n, err - } - n.setEpoch(t.currentEpoch) - } matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { return false, n, nil // don't replace n on mismatch @@ -578,10 +574,6 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, case *fullNode: if t.withShadowNodes { - // this full node is expired, return err - if expired, err := t.nodeExpired(n, epoch, prefix); expired { - return false, n, err - } // else, set its epoch to current epoch. n.setEpoch(t.currentEpoch) // if child is expired, return err @@ -774,17 +766,6 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { return nil, &MissingNodeError{NodeHash: hash, Path: prefix} } -func (t *Trie) nodeExpired(n node, epoch types.StateEpoch, prefix []byte) (bool, error) { - if types.EpochExpired(epoch, t.currentEpoch) { - return true, &ExpiredNodeError{ - ExpiredNode: n, - Path: prefix, - Epoch: epoch, - } - } - return false, nil -} - // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. 
func (t *Trie) Hash() common.Hash { @@ -885,7 +866,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { t.shadowTreeRoot = newShadowTreeRoot t.trieRoot = newRootHash t.root = newRoot - log.Info("withShadowNodes trie commit", "rootNodeHash", rootNodeHash, "newRootHash", newRootHash, "newShadowTreeRoot", newShadowTreeRoot) + log.Debug("withShadowNodes trie commit", "rootNodeHash", rootNodeHash, "newRootHash", newRootHash, "newShadowTreeRoot", newShadowTreeRoot) return rootNodeHash, committed, nil } t.root = newRoot @@ -932,7 +913,7 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err // Revive trie with each proof nub for _, nub := range proof { path := []byte{} - rootExpired, _ := t.nodeExpired(t.root, t.getRootEpoch(), nil) + rootExpired := types.EpochExpired(t.getRootEpoch(), t.currentEpoch) newNode, didRevive, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub, t.currentEpoch, path, rootExpired) if didRevive && err == nil { successNubs = append(successNubs, nub) @@ -955,7 +936,6 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE // 2. the node must be expired // 3. the node must be a hash node // 4. 
the node hash must match the hash value of nub - if len(key) == 0 { if !isExpired { return nil, false, fmt.Errorf("key %v not found", key) @@ -984,11 +964,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE } if isExpired { // the node is expired but targeted node is not reached - return nil, false, &ExpiredNodeError{ - ExpiredNode: n, - Path: path, - Epoch: 0, // Set default value, will change later - } + return nil, false, NewExpiredNodeError(path, 0) // Set default value, will change later } switch n := n.(type) { @@ -1000,6 +976,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE if didRevive && err == nil { n = n.copy() n.Val = newNode + n.setEpoch(t.currentEpoch) } return n, didRevive, err case *fullNode: @@ -1009,6 +986,7 @@ func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateE if didRevive && err == nil { n = n.copy() n.Children[childIndex] = newNode + n.setEpoch(t.currentEpoch) n.UpdateChildEpoch(childIndex, t.currentEpoch) } @@ -1049,7 +1027,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, origin node, prefix []b case *shortNode: n.setEpoch(epoch) n.shadowNode.ShadowHash = nil - return t.resolveShadowNode(epoch, n.Val, append(prefix, n.Key...)) + return t.resolveShadowNode(epoch, n.Val, safeAppendBytes(prefix, n.Key...)) case *fullNode: n.setEpoch(epoch) val, err := t.sndb.Get(string(hexToSuffixCompact(prefix))) @@ -1068,7 +1046,7 @@ func (t *Trie) resolveShadowNode(epoch types.StateEpoch, origin node, prefix []b n.shadowNode = *tmp } for i := byte(0); i < BranchNodeLength-1; i++ { - if err := t.resolveShadowNode(n.shadowNode.EpochMap[i], n.Children[i], append(prefix, i)); err != nil { + if err := t.resolveShadowNode(n.shadowNode.EpochMap[i], n.Children[i], safeAppendBytes(prefix, i)); err != nil { return err } } @@ -1087,18 +1065,20 @@ func (t *Trie) ShadowHash() (*common.Hash, error) { } h := newHasher(true) defer returnHasherToPool(h) - return 
t.shadowHash(t.root, h, nil, t.getRootEpoch()) + sh, _, err := t.shadowHash(t.root, h, nil, t.getRootEpoch()) + return sh, err } // shadowHash calculate node's shadow node hash, recalculate needn't a copy -func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.StateEpoch) (*common.Hash, error) { +// epoch param only use in resolve hashNode +func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.StateEpoch) (*common.Hash, types.StateEpoch, error) { switch n := origin.(type) { case *shortNode: var err error - if n.shadowNode.ShadowHash, err = t.shadowHash(n.Val, h, append(prefix, n.Key...), epoch); err != nil { - return nil, err + if n.shadowNode.ShadowHash, _, err = t.shadowHash(n.Val, h, append(prefix, n.Key...), n.epoch); err != nil { + return nil, types.StateEpoch0, err } - return h.shadowExtendNodeToHash(&n.shadowNode), nil + return n.shadowNode.ShadowHash, n.epoch, nil case *fullNode: epochSelf := n.epoch epochMap := n.shadowNode.EpochMap @@ -1110,36 +1090,38 @@ func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.Sta } // skip expired node. 
if epochSelf >= epochMap[i]+2 { + n.shadowNode.EpochMap[i] = 0 continue } - subHash, err := t.shadowHash(child, h, append(prefix, i), epochMap[i]) + subHash, subEpoch, err := t.shadowHash(child, h, append(prefix, i), epochMap[i]) if err != nil { - return nil, err + return nil, types.StateEpoch0, err } if subHash != nil { hashList = append(hashList, subHash) } + n.shadowNode.EpochMap[i] = subEpoch } n.shadowNode.ShadowHash = h.shadowNodeHashListToHash(hashList) - return h.shadowBranchNodeToHash(&n.shadowNode), nil + return h.shadowBranchNodeToHash(&n.shadowNode), n.epoch, nil case valueNode: - return nil, nil + return nil, epoch, nil case hashNode: if t.db == nil || t.sndb == nil { - return nil, errors.New("ShadowHash db or sndb is nil") + return nil, types.StateEpoch0, errors.New("ShadowHash db or sndb is nil") } // resolve temporary, not add to trie rn, err := t.resolveHash(n, prefix) if err != nil { - return nil, err + return nil, types.StateEpoch0, err } if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { - return nil, err + return nil, types.StateEpoch0, err } return t.shadowHash(rn, h, prefix, epoch) default: - return nil, errors.New("cannot get shortNode's child shadow node") + return nil, types.StateEpoch0, errors.New("cannot get shortNode's child shadow node") } } @@ -1207,8 +1189,15 @@ func (t *Trie) getRootEpoch() types.StateEpoch { return ret } -func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { +func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool { + // when node is nil, skip epoch check + if !t.withShadowNodes || n == nil { + return false + } + return types.EpochExpired(epoch, t.currentEpoch) +} +func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { expectHash := common.BytesToHash(root[:]) val, err := sndb.Get(ShadowTreeRootNodePath) if err != nil { @@ -1264,3 +1253,10 @@ func tryUpdateNodeEpoch(origin node, epoch types.StateEpoch) { n.setEpoch(epoch) } } + +func 
safeAppendBytes(src []byte, added ...byte) []byte { + ret := make([]byte, len(src)+len(added)) + copy(ret, src) + copy(ret[len(src):], added) + return ret +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 3b456aa8fb..f3be7ce9d9 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -570,7 +570,9 @@ func getFullNodePrefixKeys(t *Trie, key []byte) [][]byte { } // Remove the first item in prefixKeys, which is the empty key - prefixKeys = prefixKeys[1:] + if len(prefixKeys) > 0 { + prefixKeys = prefixKeys[1:] + } return prefixKeys } @@ -850,6 +852,14 @@ func TestReviveValueAtFullNode(t *testing.T) { } } +// TODO add testing trie epoch update & expired +// case1: when meet err, update do not update epoch +// case2: when child is nil, do not check its epoch, default is 0 +// case3: when access/update/delete, update epoch correct, when node is not expired, then update it later +// case3: when access/update/delete, check expired node correct +// case4: root node epoch update and check correct +// case5: when node is expired, do not resolve it, safe for revive + func TestTrie_ShadowNodeRW_expired(t *testing.T) { database, tree := makeStorageTrieDatabase(t) storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) @@ -942,7 +952,7 @@ func TestTrie_ShadowHash(t *testing.T) { batchUpdateTrie(t, tr, []string{"a711355", "450", "a77d337", "100", "a7f9365", "110", "a77d397", "012"}) sh1, err := tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, common.HexToHash("0x73325476298d27129c8b8d64e8d0abd66d6cc26601c9a012304170432ad3a00d"), *sh1) + assert.Equal(t, common.HexToHash("0xc752578873185d8b97bdf9e59c8178719e30a03515c7a791e779d4823bbb3fa4"), *sh1) // commit and shadow hash again newRoot, _, err := tr.Commit(nil) @@ -956,13 +966,49 @@ func TestTrie_ShadowHash(t *testing.T) { assert.NoError(t, err) sh1, err = tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, common.HexToHash("0x73325476298d27129c8b8d64e8d0abd66d6cc26601c9a012304170432ad3a00d"), 
*sh1) + assert.Equal(t, common.HexToHash("0xc752578873185d8b97bdf9e59c8178719e30a03515c7a791e779d4823bbb3fa4"), *sh1) err = tr.TryUpdate(common.Hex2Bytes("a711355"), common.Hex2Bytes("800")) assert.NoError(t, err) sh1, err = tr.ShadowHash() assert.NoError(t, err) - assert.Equal(t, common.HexToHash("0x3e17653305330721f2a9792b510247a411efedf8d687fa3fb1ae32d2c4325511"), *sh1) + assert.Equal(t, common.HexToHash("0xa88d96fa4e1b7b4421198f965230b85e153a6453d1f43b97c0ad89feafa73dd6"), *sh1) +} + +func TestTrie_ShadowHash_case2(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, emptyRoot) + assert.NoError(t, err) + + tr, err := NewWithShadowNode(10, newRootNode(10, emptyRoot, emptyRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + + batchUpdateTrie(t, tr, []string{"223dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "450", "224dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "100", "233dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "110", "253dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "012"}) + + // commit and shadow hash again + newRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + err = storageDB.Commit(common.Big1, newRoot) + assert.NoError(t, err) + + storageDB, err = NewShadowNodeDatabase(tree, common.Big2, newRoot) + assert.NoError(t, err) + + sndb := storageDB.OpenStorage(contract1) + enc, err := sndb.Get(ShadowTreeRootNodePath) + assert.NoError(t, err) + r1, err := decodeRootNode(enc) + assert.NoError(t, err) + tr, err = NewWithShadowNode(11, r1, database, sndb) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("223dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("224dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = 
tr.TryGet(common.Hex2Bytes("233dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("253dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) } func batchUpdateTrie(t *testing.T, tr *Trie, kvs []string) { @@ -1403,6 +1449,14 @@ func TestCommitSequenceSmallRoot(t *testing.T) { } } +func TestSafeAppendBytes(t *testing.T) { + assert.Equal(t, safeAppendBytes(nil, []byte{1}...), append([]byte(nil), 1)) + assert.Equal(t, safeAppendBytes(nil, []byte{1, 2, 3, 5}...), append([]byte(nil), []byte{1, 2, 3, 5}...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}, []byte{5, 4, 3, 2, 1}...), append([]byte{1, 2, 3, 5}, []byte{5, 4, 3, 2, 1}...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}, []byte(nil)...), append([]byte{1, 2, 3, 5}, []byte(nil)...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}), append([]byte{1, 2, 3, 5})) +} + // BenchmarkCommitAfterHashFixedSize benchmarks the Commit (after Hash) of a fixed number of updates to a trie. // This benchmark is meant to capture the difference on efficiency of small versus large changes. 
Typically, // storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple From a328169521fdf4dcc4c3804bf0a7690c0592c04d Mon Sep 17 00:00:00 2001 From: asyukii Date: Tue, 16 May 2023 23:42:05 +0800 Subject: [PATCH 44/51] feat(prune): inline prune expired nodes feat(prune): inline prune expired nodes fix dereference codes fix: raw node fix: add BEP-206 prune rules --- core/blockchain.go | 6 ++- core/blockchain_sethead_test.go | 6 ++- core/blockchain_test.go | 6 ++- core/state/snapshot/conversion.go | 2 +- eth/state_accessor.go | 3 +- eth/tracers/api.go | 6 ++- trie/database.go | 68 +++++++++++++++++++++++-------- trie/node_enc.go | 4 +- trie/stacktrie.go | 6 +-- trie/trie_test.go | 2 +- 10 files changed, 75 insertions(+), 34 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index f2e69e6e32..b76da2e8a8 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1115,8 +1115,9 @@ func (bc *BlockChain) Stop() { rawdb.WriteSafePointBlockNumber(bc.db, bc.CurrentBlock().NumberU64()) } } + currentEpoch := types.GetStateEpoch(bc.chainConfig, bc.CurrentBlock().Number()) for !bc.triegc.Empty() { - go triedb.Dereference(bc.triegc.PopItem().(common.Hash)) + go triedb.Dereference(bc.triegc.PopItem().(common.Hash), currentEpoch) } if size, _ := triedb.Size(); size != 0 { log.Error("Dangling trie nodes after full cleanup") @@ -1558,13 +1559,14 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. 
} } // Garbage collect anything below our required write retention + currentEpoch := types.GetStateEpoch(bc.chainConfig, bc.CurrentBlock().Number()) for !bc.triegc.Empty() { root, number := bc.triegc.Pop() if uint64(-number) > chosen { bc.triegc.Push(root, number) break } - go triedb.Dereference(root.(common.Hash)) + go triedb.Dereference(root.(common.Hash), currentEpoch) } } } diff --git a/core/blockchain_sethead_test.go b/core/blockchain_sethead_test.go index 7ee213b726..7df56f2a78 100644 --- a/core/blockchain_sethead_test.go +++ b/core/blockchain_sethead_test.go @@ -2017,10 +2017,12 @@ func testSetHead(t *testing.T, tt *rewindTest, snapshots bool) { } // Manually dereference anything not committed to not have to work with 128+ tries for _, block := range sideblocks { - chain.stateCache.TrieDB().Dereference(block.Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, block.Number()) + chain.stateCache.TrieDB().Dereference(block.Root(), blockEpoch) } for _, block := range canonblocks { - chain.stateCache.TrieDB().Dereference(block.Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, block.Number()) + chain.stateCache.TrieDB().Dereference(block.Root(), blockEpoch) } chain.stateCache.Purge() // Force run a freeze cycle diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 55e39133a4..5580c05acf 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -1785,8 +1785,10 @@ func TestTrieForkGC(t *testing.T) { } // Dereference all the recent tries and ensure no past trie is left in for i := 0; i < TestTriesInMemory; i++ { - chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root()) - chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, blocks[len(blocks)-1-i].Number()) + chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), blockEpoch) + forkEpoch := types.GetStateEpoch(params.TestChainConfig, 
forks[len(blocks)-1-i].Number()) + chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), forkEpoch) } if len(chain.stateCache.TrieDB().Nodes()) > 0 { t.Fatalf("stale tries still alive after garbase collection") diff --git a/core/state/snapshot/conversion.go b/core/state/snapshot/conversion.go index 250692422d..f7a79503a2 100644 --- a/core/state/snapshot/conversion.go +++ b/core/state/snapshot/conversion.go @@ -89,7 +89,7 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd } defer storageIt.Release() - hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) + hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) // This is where storage trie is generated. if err != nil { return common.Hash{}, err } diff --git a/eth/state_accessor.go b/eth/state_accessor.go index b7cf4c0db2..b9bd32a92e 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -53,6 +53,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state database state.Database report = true origin = block.NumberU64() + epoch = types.GetStateEpoch(eth.blockchain.Config(), block.Number()) ) // Check the live database first if we have the state fully available, use that. 
if checkLive { @@ -154,7 +155,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state } database.TrieDB().Reference(root, common.Hash{}) if parent != (common.Hash{}) { - database.TrieDB().Dereference(parent) + database.TrieDB().Dereference(parent, epoch) } parent = root } diff --git a/eth/tracers/api.go b/eth/tracers/api.go index 40aec6b3be..09b38fb53a 100644 --- a/eth/tracers/api.go +++ b/eth/tracers/api.go @@ -345,8 +345,10 @@ func (api *API) traceChain(ctx context.Context, start, end *types.Block, config } // clean out any derefs derefsMu.Lock() + numberBigInt := new(big.Int).SetUint64(number) + epoch := types.GetStateEpoch(api.backend.ChainConfig(), numberBigInt) for _, h := range derefTodo { - statedb.Database().TrieDB().Dereference(h) + statedb.Database().TrieDB().Dereference(h, epoch) } derefTodo = derefTodo[:0] derefsMu.Unlock() @@ -375,7 +377,7 @@ func (api *API) traceChain(ctx context.Context, start, end *types.Block, config // Release the parent state because it's already held by the tracer if parent != (common.Hash{}) { - trieDb.Dereference(parent) + trieDb.Dereference(parent, epoch) } // Prefer disk if the trie db memory grows too much s1, s2 := trieDb.Size() diff --git a/trie/database.go b/trie/database.go index 52469352cf..abdb6093a2 100644 --- a/trie/database.go +++ b/trie/database.go @@ -28,6 +28,7 @@ import ( "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" @@ -116,7 +117,10 @@ func (n rawNode) nodeType() int { // rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. 
-type rawFullNode [17]node +type rawFullNode struct { + children [BranchNodeLength]node + epoch types.StateEpoch `rlp:"-" json:"-"` +} func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } @@ -135,8 +139,9 @@ func (n rawFullNode) EncodeRLP(w io.Writer) error { // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. type rawShortNode struct { - Key []byte - Val node + Key []byte + Val node + epoch types.StateEpoch `rlp:"-" json:"-"` } func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } @@ -201,15 +206,26 @@ func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { } } +func (n *cachedNode) getEpoch() (types.StateEpoch, error) { + switch n := n.node.(type) { + case *rawShortNode: + return n.epoch, nil + case *rawFullNode: + return n.epoch, nil + default: + return 0, fmt.Errorf("unknown node type: %T", n) // TODO(asyukii): may never reach this case, consider panic + } +} + // forGatherChildren traverses the node hierarchy of a collapsed storage node and // invokes the callback for all the hashnode children. 
func forGatherChildren(n node, onChild func(hash common.Hash)) { switch n := n.(type) { case *rawShortNode: forGatherChildren(n.Val, onChild) - case rawFullNode: + case *rawFullNode: for i := 0; i < 16; i++ { - forGatherChildren(n[i], onChild) + forGatherChildren(n.children[i], onChild) } case hashNode: onChild(common.BytesToHash(n)) @@ -225,14 +241,17 @@ func simplifyNode(n node) node { switch n := n.(type) { case *shortNode: // Short nodes discard the flags and cascade - return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)} + return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val), epoch: n.epoch} case *fullNode: // Full nodes discard the flags and cascade - node := rawFullNode(n.Children) - for i := 0; i < len(node); i++ { - if node[i] != nil { - node[i] = simplifyNode(node[i]) + node := &rawFullNode{ + children: n.Children, + epoch: n.epoch, + } + for i := 0; i < len(node.children); i++ { + if node.children[i] != nil { + node.children[i] = simplifyNode(node.children[i]) } } return node @@ -259,7 +278,7 @@ func expandNode(hash hashNode, n node) node { }, } - case rawFullNode: + case *rawFullNode: // Full nodes need child expansion node := &fullNode{ flags: nodeFlag{ @@ -267,8 +286,8 @@ func expandNode(hash hashNode, n node) node { }, } for i := 0; i < len(node.Children); i++ { - if n[i] != nil { - node.Children[i] = expandNode(nil, n[i]) + if n.children[i] != nil { + node.Children[i] = expandNode(nil, n.children[i]) } } return node @@ -534,7 +553,7 @@ func (db *Database) reference(child common.Hash, parent common.Hash) { } // Dereference removes an existing reference from a root node. 
-func (db *Database) Dereference(root common.Hash) { +func (db *Database) Dereference(root common.Hash, epoch types.StateEpoch) { // Sanity check to ensure that the meta-root is not removed if root == (common.Hash{}) { log.Error("Attempted to dereference the trie cache meta root") @@ -544,7 +563,7 @@ func (db *Database) Dereference(root common.Hash) { defer db.lock.Unlock() nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now() - db.dereference(root, common.Hash{}) + db.dereference(root, common.Hash{}, epoch) db.gcnodes += uint64(nodes - len(db.dirties)) db.gcsize += storage - db.dirtiesSize @@ -558,8 +577,12 @@ func (db *Database) Dereference(root common.Hash) { "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize) } +func checkBEP206PruneRule(childEpoch, parentEpoch, currEpoch types.StateEpoch) bool { + return types.EpochExpired(childEpoch, currEpoch) && (parentEpoch >= childEpoch+2) +} + // dereference is the private locked version of Dereference. -func (db *Database) dereference(child common.Hash, parent common.Hash) { +func (db *Database) dereference(child common.Hash, parent common.Hash, epoch types.StateEpoch) { // Dereference the parent-child node := db.dirties[parent] @@ -570,6 +593,7 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { db.childrenSize -= (common.HashLength + 2) // uint16 counter } } + parentEpoch, getParentEpochErr := node.getEpoch() // If the child does not exist, it's a previously committed node. node, ok := db.dirties[child] if !ok { @@ -583,7 +607,15 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { // no problem in itself, but don't make maxint parents out of it. 
node.parents-- } - if node.parents == 0 { + childEpoch, getChildEpochErr := node.getEpoch() + canPruneExpired := false + if getParentEpochErr == nil && getChildEpochErr == nil { + canPruneExpired = checkBEP206PruneRule(childEpoch, parentEpoch, epoch) + } + if canPruneExpired || node.parents == 0 { // Delete the node if it is expired or no parent nodes reference it anymore + if canPruneExpired { + log.Info("Dereferencing expired trie node") + } // Remove the node from the flush-list switch child { case db.oldest: @@ -598,7 +630,7 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { } // Dereference all children and delete the node node.forChilds(func(hash common.Hash) { - db.dereference(hash, child) + db.dereference(hash, child, epoch) }) delete(db.dirties, child) db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size)) diff --git a/trie/node_enc.go b/trie/node_enc.go index cade35b707..c8990b17d1 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -59,9 +59,9 @@ func (n valueNode) encode(w rlp.EncoderBuffer) { w.WriteBytes(n) } -func (n rawFullNode) encode(w rlp.EncoderBuffer) { +func (n *rawFullNode) encode(w rlp.EncoderBuffer) { offset := w.List() - for _, c := range n { + for _, c := range n.children { if c != nil { c.encode(w) } else { diff --git a/trie/stacktrie.go b/trie/stacktrie.go index ec278d390c..0cbc02aa47 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -392,15 +392,15 @@ func (st *StackTrie) hashRec(hasher *hasher) { var nodes rawFullNode for i, child := range st.children { if child == nil { - nodes[i] = nilValueNode + nodes.children[i] = nilValueNode continue } child.hashRec(hasher) if len(child.val) < 32 { - nodes[i] = rawNode(child.val) + nodes.children[i] = rawNode(child.val) } else { - nodes[i] = hashNode(child.val) + nodes.children[i] = hashNode(child.val) } // Release child back to pool. 
diff --git a/trie/trie_test.go b/trie/trie_test.go index f3be7ce9d9..0c8a0d2fa6 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1612,7 +1612,7 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [] h := trie.Hash() trie.Commit(nil) b.StartTimer() - trie.db.Dereference(h) + trie.db.Dereference(h, 0) // TODO(asyukii): set epoch to 0 temporary, might need to fix b.StopTimer() } From a98d129a42e8e6c815ebd2d4d7de2e5ce0ee7045 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Wed, 17 May 2023 23:54:28 +0800 Subject: [PATCH 45/51] state/pruner: support expired snap kv prune; trie/shadow_node_history: remove history when in pruning mode; --- cmd/evm/internal/t8ntool/execution.go | 2 +- core/blockchain.go | 3 +- core/chain_makers.go | 2 +- core/rawdb/schema.go | 7 ++ core/state/pruner/pruner.go | 50 +++++++++--- eth/backend.go | 1 + light/trie.go | 2 +- trie/shadow_node_difflayer.go | 107 +++++++++++++++++--------- trie/shadow_node_difflayer_test.go | 10 ++- trie/shadow_node_history_test.go | 32 +++++++- trie/shadow_node_test.go | 8 +- trie/trie_test.go | 2 +- 12 files changed, 164 insertions(+), 62 deletions(-) diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index ba656da33d..a10e95dec0 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -272,7 +272,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) *state.StateDB { sdb := state.NewDatabase(db) - tree, _ := trie.NewShadowNodeSnapTree(db) + tree, _ := trie.NewShadowNodeSnapTree(db, false) statedb, _ := state.NewWithStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1), common.Hash{}, sdb, nil, tree) for addr, a := range pre.Pre { statedb.SetCode(addr, a.Code) diff --git a/core/blockchain.go b/core/blockchain.go index 
b76da2e8a8..75a964b7c6 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -148,6 +148,7 @@ type CacheConfig struct { NoTries bool // Insecure settings. Do not have any tries in databases if enabled. SnapshotWait bool // Wait for snapshot construction on startup. TODO(karalabe): This is a dirty hack for testing, nuke it + NoPruning bool } // To avoid cycle import @@ -365,7 +366,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par } // load shadow node tree to R&W - if bc.shadowNodeTree, err = trie.NewShadowNodeSnapTree(db); err != nil { + if bc.shadowNodeTree, err = trie.NewShadowNodeSnapTree(db, cacheConfig.NoPruning); err != nil { return nil, err } diff --git a/core/chain_makers.go b/core/chain_makers.go index 1e86f9a245..f51bf9feb8 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -276,7 +276,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse } return nil, nil } - tree, err := trie.NewShadowNodeSnapTree(db) + tree, err := trie.NewShadowNodeSnapTree(db, false) if err != nil { panic(err) } diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index 9eeddc5efc..a1c56a79a0 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -262,6 +262,13 @@ func IsCodeKey(key []byte) (bool, []byte) { return false, nil } +func IsSnapStorageKey(key []byte) (bool, []byte, []byte) { + if bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == common.HashLength+common.HashLength+len(SnapshotStoragePrefix) { + return true, key[len(SnapshotStoragePrefix):], key[len(SnapshotStoragePrefix) : len(SnapshotStoragePrefix)+common.HashLength] + } + return false, nil, nil +} + // configKey = configPrefix + hash func configKey(hash common.Hash) []byte { return append(configPrefix, hash.Bytes()...) 
diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 8032d7294c..87f145f377 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -88,6 +88,7 @@ type Pruner struct { headHeader *types.Header snaptree *snapshot.Tree triesInMemory uint64 + latestEpoch types.StateEpoch } type BlockPruner struct { @@ -104,6 +105,13 @@ func NewPruner(db ethdb.Database, datadir, trieCachePath string, bloomSize, trie if headBlock == nil { return nil, errors.New("Failed to load head block") } + chainConfig := rawdb.ReadChainConfig(db, headBlock.Hash()) + if chainConfig == nil { + return nil, errors.New("cannot find chainConfig") + } + latestEpoch := types.GetStateEpoch(chainConfig, headBlock.Number()) + log.Info("NewPruner with", "number", headBlock.Number(), "latestEpoch", latestEpoch) + snaptree, err := snapshot.New(db, trie.NewDatabase(db), 256, int(triesInMemory), headBlock.Root(), false, false, false, false) if err != nil { return nil, err // The relevant snapshot(s) might not exist @@ -126,6 +134,7 @@ func NewPruner(db ethdb.Database, datadir, trieCachePath string, bloomSize, trie triesInMemory: triesInMemory, headHeader: headBlock.Header(), snaptree: snaptree, + latestEpoch: latestEpoch, }, nil } @@ -238,7 +247,7 @@ func pruneAll(maindb ethdb.Database, g *core.Genesis) error { return nil } -func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, stateBloom *stateBloom, bloomPath string, middleStateRoots map[common.Hash]struct{}, start time.Time) error { +func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, stateBloom *stateBloom, bloomPath string, middleStateRoots map[common.Hash]struct{}, start time.Time, latestEpoch types.StateEpoch) error { // Delete all stale trie nodes in the disk. With the help of state bloom // the trie nodes(and codes) belong to the active state will be filtered // out. 
A very small part of stale tries will also be filtered because of @@ -247,16 +256,31 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta // dangling node is the state root is super low. So the dangling nodes in // theory will never ever be visited again. var ( - count int - size common.StorageSize - pstart = time.Now() - logged = time.Now() - batch = maindb.NewBatch() - iter = maindb.NewIterator(nil, nil) + count int + snapCount int + size common.StorageSize + snapSize common.StorageSize + pstart = time.Now() + logged = time.Now() + batch = maindb.NewBatch() + iter = maindb.NewIterator(nil, nil) ) for iter.Next() { key := iter.Key() + // if it is snap kv, check if expired, do not follow parent>=child+2 prune rule, cover by trie node + isSnapKey, _, snapAddr := rawdb.IsSnapStorageKey(key) + if isSnapKey { + snapVal, err := snapshot.ParseSnapValFromBytes(iter.Value()) + if err == nil && types.EpochExpired(snapVal.Epoch, latestEpoch) { + batch.Delete(key) + snapCount += 1 + snapSize += common.StorageSize(len(key) + len(iter.Value())) + log.Info("delete expired snap kv", "addrHash", snapAddr, "kvEpoch", snapVal.Epoch, "epoch", latestEpoch) + } + continue + } + // All state entries don't belong to specific state and genesis are deleted here // - trie node // - legacy contract code @@ -310,6 +334,7 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta } iter.Release() log.Info("Pruned state data", "nodes", count, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart))) + log.Info("Pruned snap data", "kvs", snapCount, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart))) // Pruning is done, now drop the "useless" layers from the snapshot. // Firstly, flushing the target layer into the disk. 
After that all @@ -658,7 +683,7 @@ func (p *Pruner) Prune(root common.Hash) error { return err } log.Info("State bloom filter committed", "name", filterName) - return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start) + return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start, p.latestEpoch) } // RecoverPruning will resume the pruning procedure during the system restart. @@ -680,6 +705,13 @@ func RecoverPruning(datadir string, db ethdb.Database, trieCachePath string, tri if headBlock == nil { return errors.New("Failed to load head block") } + chainConfig := rawdb.ReadChainConfig(db, headBlock.Hash()) + if chainConfig == nil { + return errors.New("cannot find chainConfig") + } + latestEpoch := types.GetStateEpoch(chainConfig, headBlock.Number()) + log.Info("RecoverPruning with", "number", headBlock.Number(), "latestEpoch", latestEpoch) + // Initialize the snapshot tree in recovery mode to handle this special case: // - Users run the `prune-state` command multiple times // - Neither these `prune-state` running is finished(e.g. 
interrupted manually) @@ -722,7 +754,7 @@ func RecoverPruning(datadir string, db ethdb.Database, trieCachePath string, tri log.Error("Pruning target state is not existent") return errors.New("non-existent target state") } - return prune(snaptree, stateBloomRoot, db, stateBloom, stateBloomPath, middleRoots, time.Now()) + return prune(snaptree, stateBloomRoot, db, stateBloom, stateBloomPath, middleRoots, time.Now(), latestEpoch) } // extractGenesis loads the genesis state and commits all the state entries diff --git a/eth/backend.go b/eth/backend.go index f7fe65e285..855939dd56 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -204,6 +204,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { TrieCleanRejournal: config.TrieCleanCacheRejournal, TrieDirtyLimit: config.TrieDirtyCache, TrieDirtyDisabled: config.NoPruning, + NoPruning: config.NoPruning, TrieTimeLimit: config.TrieTimeout, NoTries: config.TriesVerifyMode != core.LocalVerify, SnapshotLimit: config.SnapshotCache, diff --git a/light/trie.go b/light/trie.go index 0cf403de8e..2de90fc698 100644 --- a/light/trie.go +++ b/light/trie.go @@ -38,7 +38,7 @@ var ( ) func NewState(ctx context.Context, config *params.ChainConfig, head *types.Header, odr OdrBackend) *state.StateDB { - tree, _ := trie.NewShadowNodeSnapTree(odr.Database()) + tree, _ := trie.NewShadowNodeSnapTree(odr.Database(), true) state, _ := state.NewWithStateEpoch(config, head.Number, head.Root, NewStateDatabase(ctx, head, odr), nil, tree) return state } diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go index 3506b1ab6b..dc400e8886 100644 --- a/trie/shadow_node_difflayer.go +++ b/trie/shadow_node_difflayer.go @@ -8,6 +8,8 @@ import ( "math/big" "sync" + "github.com/ethereum/go-ethereum/log" + "github.com/RoaringBitmap/roaring/roaring64" "github.com/ethereum/go-ethereum/common/math" lru "github.com/hashicorp/golang-lru" @@ -55,8 +57,8 @@ type ShadowNodeSnapTree struct { lock sync.RWMutex } -func 
NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore) (*ShadowNodeSnapTree, error) { - diskLayer, err := loadDiskLayer(diskdb) +func NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore, archiveMode bool) (*ShadowNodeSnapTree, error) { + diskLayer, err := loadDiskLayer(diskdb, archiveMode) if err != nil { return nil, err } @@ -225,11 +227,11 @@ func (s *ShadowNodeSnapTree) flattenDiffs2Disk(flatten []shadowNodeSnapshot, dis } // loadDiskLayer load from db, could be nil when none in db -func loadDiskLayer(db ethdb.KeyValueStore) (*shadowNodeDiskLayer, error) { +func loadDiskLayer(db ethdb.KeyValueStore, archiveMode bool) (*shadowNodeDiskLayer, error) { val := rawdb.ReadShadowNodePlainStateMeta(db) // if there is no disk layer, will construct a fake disk layer if len(val) == 0 { - diskLayer, err := newShadowNodeDiskLayer(db, common.Big0, emptyRoot) + diskLayer, err := newShadowNodeDiskLayer(db, common.Big0, emptyRoot, archiveMode) if err != nil { return nil, err } @@ -240,7 +242,7 @@ func loadDiskLayer(db ethdb.KeyValueStore) (*shadowNodeDiskLayer, error) { return nil, err } - layer, err := newShadowNodeDiskLayer(db, meta.BlockNumber, meta.BlockRoot) + layer, err := newShadowNodeDiskLayer(db, meta.BlockNumber, meta.BlockRoot, archiveMode) if err != nil { return nil, err } @@ -443,11 +445,12 @@ type shadowNodeDiskLayer struct { blockNumber *big.Int blockRoot common.Hash cache *lru.Cache + archiveMode bool // archiveMode: if true, keep all history; if false, just flatten all changeSets into plainState lock sync.RWMutex } -func newShadowNodeDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, blockRoot common.Hash) (*shadowNodeDiskLayer, error) { +func newShadowNodeDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, blockRoot common.Hash, archiveMode bool) (*shadowNodeDiskLayer, error) { cache, err := lru.New(defaultDiskLayerCacheSize) if err != nil { return nil, err @@ -457,6 +460,7 @@ func newShadowNodeDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, bl
blockNumber: blockNumber, blockRoot: blockRoot, cache: cache, + archiveMode: archiveMode, }, nil } @@ -512,36 +516,8 @@ func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDi } batch := s.diskdb.NewBatch() nodeSet := diff.getNodeSet() - for addr, subSet := range nodeSet { - changeSet := make([]nodeChgRecord, 0, len(subSet)) - for path, val := range subSet { - if err := refreshShadowNodeHistory(s.diskdb, batch, addr, path, number.Uint64()); err != nil { - return nil, err - } - prev := rawdb.ReadShadowNodePlainState(s.diskdb, addr, path) - // refresh plain state - if len(val) == 0 { - if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { - return nil, err - } - } else { - if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { - return nil, err - } - } - - changeSet = append(changeSet, nodeChgRecord{ - Path: path, - Prev: prev, - }) - } - enc, err := rlp.EncodeToBytes(changeSet) - if err != nil { - return nil, err - } - if err = rawdb.WriteShadowNodeChangeSet(batch, addr, number.Uint64(), enc); err != nil { - return nil, err - } + if err := s.writeHistory(number, batch, diff.getNodeSet()); err != nil { + return nil, err } // update meta @@ -565,10 +541,11 @@ func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDi blockNumber: number, blockRoot: diff.blockRoot, cache: s.cache, + archiveMode: s.archiveMode, } // reuse cache - for addr, nodes := range diff.nodeSet { + for addr, nodes := range nodeSet { for path, val := range nodes { diskLayer.cache.Add(shadowNodeCacheKey(addr, path), val) } @@ -576,6 +553,62 @@ func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDi return diskLayer, nil } +func (s *shadowNodeDiskLayer) writeHistory(number *big.Int, batch ethdb.Batch, nodeSet map[common.Hash]map[string][]byte) error { + // if not in archiveMode, just flatten to plainState + if !s.archiveMode { + for addr, subSet := range nodeSet { + for path, val := 
range subSet { + // refresh plain state + if len(val) == 0 { + if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { + return err + } + } else { + if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { + return err + } + } + } + } + log.Info("shadow node history pruned, only keep plainState", "number", number, "count", len(nodeSet)) + return nil + } + + for addr, subSet := range nodeSet { + changeSet := make([]nodeChgRecord, 0, len(subSet)) + for path, val := range subSet { + if err := refreshShadowNodeHistory(s.diskdb, batch, addr, path, number.Uint64()); err != nil { + return err + } + prev := rawdb.ReadShadowNodePlainState(s.diskdb, addr, path) + // refresh plain state + if len(val) == 0 { + if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { + return err + } + } else { + if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { + return err + } + } + + changeSet = append(changeSet, nodeChgRecord{ + Path: path, + Prev: prev, + }) + } + enc, err := rlp.EncodeToBytes(changeSet) + if err != nil { + return err + } + if err = rawdb.WriteShadowNodeChangeSet(batch, addr, number.Uint64(), enc); err != nil { + return err + } + } + + return nil +} + func shadowNodeCacheKey(addr common.Hash, path string) string { key := make([]byte, len(addr)+len(path)) copy(key[:], addr.Bytes()) diff --git a/trie/shadow_node_difflayer_test.go b/trie/shadow_node_difflayer_test.go index a31ee1ff30..e91f103a00 100644 --- a/trie/shadow_node_difflayer_test.go +++ b/trie/shadow_node_difflayer_test.go @@ -14,9 +14,11 @@ var ( blockRoot0 = makeHash("b0") blockRoot1 = makeHash("b1") blockRoot2 = makeHash("b2") + blockRoot3 = makeHash("b3") storageRoot0 = makeHash("s0") storageRoot1 = makeHash("s1") storageRoot2 = makeHash("s2") + storageRoot3 = makeHash("s3") contract1 = makeHash("c1") contract2 = makeHash("c2") contract3 = makeHash("c3") @@ -25,7 +27,7 @@ var ( func TestShadowNodeDiffLayer_whenGenesis(t 
*testing.T) { diskdb := memorydb.New() // create empty tree - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) snap := tree.Snapshot(blockRoot0) assert.Nil(t, snap) @@ -41,7 +43,7 @@ func TestShadowNodeDiffLayer_whenGenesis(t *testing.T) { assert.NoError(t, err) // reload - tree, err = NewShadowNodeSnapTree(diskdb) + tree, err = NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) diskLayer := tree.Snapshot(emptyRoot) assert.NotNil(t, diskLayer) @@ -69,7 +71,7 @@ func TestShadowNodeDiffLayer_whenGenesis(t *testing.T) { func TestShadowNodeDiffLayer_crud(t *testing.T) { diskdb := memorydb.New() // create empty tree - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) set1 := makeNodeSet(contract1, []string{"hello", "world", "h1", "w1"}) appendNodeSet(set1, contract3, []string{"h3", "w3"}) @@ -113,7 +115,7 @@ func TestShadowNodeDiffLayer_crud(t *testing.T) { func TestShadowNodeDiffLayer_capDiffLayers(t *testing.T) { diskdb := memorydb.New() // create empty tree - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) // push 200 diff layers diff --git a/trie/shadow_node_history_test.go b/trie/shadow_node_history_test.go index d502cbcc8d..fcd1ecb48c 100644 --- a/trie/shadow_node_history_test.go +++ b/trie/shadow_node_history_test.go @@ -10,7 +10,7 @@ import ( func TestShadowNodeHistory_Diff2Disk(t *testing.T) { diskdb := memorydb.New() - diskLayer, err := loadDiskLayer(diskdb) + diskLayer, err := loadDiskLayer(diskdb, true) assert.NoError(t, err) diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) _, err = diskLayer.PushDiff(diff) @@ -25,7 +25,7 @@ func TestShadowNodeHistory_Diff2Disk(t *testing.T) { assert.Equal(t, []byte{}, val) // reload disk layer - diskLayer, err = loadDiskLayer(diskdb) + diskLayer, err = 
loadDiskLayer(diskdb, true) assert.NoError(t, err) val, err = diskLayer.ShadowNode(contract1, "hello") assert.NoError(t, err) @@ -34,7 +34,7 @@ func TestShadowNodeHistory_Diff2Disk(t *testing.T) { func TestShadowNodeHistory_case2(t *testing.T) { diskdb := memorydb.New() - diskLayer, err := loadDiskLayer(diskdb) + diskLayer, err := loadDiskLayer(diskdb, true) assert.NoError(t, err) diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) @@ -53,3 +53,29 @@ func TestShadowNodeHistory_case2(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []byte("world1"), val) } + +func TestShadowNodeHistory_disableArchive(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb, false) + assert.NoError(t, err) + + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + diff = newShadowNodeDiffLayer(common.Big2, blockRoot2, nil, makeNodeSet(contract1, []string{"hello", "world1"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + val, err := FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world1"), val) + + diff = newShadowNodeDiffLayer(common.Big3, blockRoot3, nil, makeNodeSet(contract1, []string{"hello", "world2"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + val, err = FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world2"), val) +} diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go index 1626635f94..c307a72a2d 100644 --- a/trie/shadow_node_test.go +++ b/trie/shadow_node_test.go @@ -17,7 +17,7 @@ import ( func TestShadowNodeRW_CRUD(t *testing.T) { diskdb := memorydb.New() - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) 
assert.NoError(t, err) storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) @@ -40,7 +40,7 @@ func TestShadowNodeRO_Get(t *testing.T) { diskdb := memorydb.New() makeDiskLayer(diskdb, common.Big2, blockRoot2, contract1, []string{"k1", "v1"}) - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) storageRO, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) @@ -78,7 +78,7 @@ func makeDiskLayer(diskdb *memorydb.Database, number *big.Int, root common.Hash, func TestShadowNodeRW_Commit(t *testing.T) { diskdb := memorydb.New() - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) @@ -98,7 +98,7 @@ func TestShadowNodeRW_Commit(t *testing.T) { func TestNewShadowNodeStorage4Trie(t *testing.T) { diskdb := memorydb.New() - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) assert.NoError(t, err) diff --git a/trie/trie_test.go b/trie/trie_test.go index 0c8a0d2fa6..dd0ad9b3f2 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1024,7 +1024,7 @@ func batchUpdateTrie(t *testing.T, tr *Trie, kvs []string) { func makeStorageTrieDatabase(t *testing.T) (*Database, *ShadowNodeSnapTree) { diskdb := memorydb.New() database := NewDatabase(diskdb) - tree, err := NewShadowNodeSnapTree(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) assert.NoError(t, err) return database, tree } From ce354d0413a257a5ec8cb62cf4d72624ba94fdea Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 18 May 2023 12:52:26 +0800 Subject: [PATCH 46/51] fix: add parent expired check for inline prune --- trie/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/trie/database.go b/trie/database.go index abdb6093a2..c27d9a2fec 100644 --- a/trie/database.go +++ b/trie/database.go @@ -578,7 +578,7 @@ func (db *Database) Dereference(root common.Hash, epoch types.StateEpoch) { } func checkBEP206PruneRule(childEpoch, parentEpoch, currEpoch types.StateEpoch) bool { - return types.EpochExpired(childEpoch, currEpoch) && (parentEpoch >= childEpoch+2) + return types.EpochExpired(childEpoch, currEpoch) && (types.EpochExpired(parentEpoch, currEpoch) || parentEpoch >= childEpoch+2) } // dereference is the private locked version of Dereference. From ec4fcf4e39a8f65105747528f4d55eb170a40977 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Thu, 18 May 2023 16:58:51 +0800 Subject: [PATCH 47/51] logs: add some error logs; --- core/state/state_object.go | 3 +++ core/vm/evm.go | 12 ++++++++++++ trie/trie.go | 3 +++ 3 files changed, 18 insertions(+) diff --git a/core/state/state_object.go b/core/state/state_object.go index f1f29c97c6..e7da13328f 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -627,6 +627,9 @@ func (s *StateObject) CommitTrie(db Database) (int, error) { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } root, committed, err := s.trie.Commit(nil) + if err != nil { + log.Error("obj CommitTrie", "addr", s.address, "root", root, "err", err) + } log.Debug("obj CommitTrie", "addr", s.address, "root", root, "err", err) if err == nil { s.data.Root = root diff --git a/core/vm/evm.go b/core/vm/evm.go index 3408ba40ab..1cf6978412 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -22,6 +22,8 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/log" + "github.com/holiman/uint256" "github.com/ethereum/go-ethereum/common" @@ -263,6 +265,16 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas //} else { // evm.StateDB.DiscardSnapshot(snapshot) } + + errors := evm.Errors() + if err != 
nil || len(errors) > 0 { + log.Error("execution got err", "from", caller.Address(), "to", addr) + for _, e := range errors { + log.Error("op err", "err", e) + } + log.Error("return err", "err", err) + } + return ret, gas, err } diff --git a/trie/trie.go b/trie/trie.go index 084b005ede..36c8c0711e 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -915,6 +915,9 @@ func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err path := []byte{} rootExpired := types.EpochExpired(t.getRootEpoch(), t.currentEpoch) newNode, didRevive, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub, t.currentEpoch, path, rootExpired) + if err != nil { + log.Error("tryRevive err", "prefix", nub.n1PrefixKey, "didRevive", didRevive, "err", err) + } if didRevive && err == nil { successNubs = append(successNubs, nub) t.root = newNode From cba33ff0eaed89ce79c04144459eaa04ec548f1d Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Fri, 26 May 2023 13:52:08 +0800 Subject: [PATCH 48/51] db/inspect: add inspect shadow nodes info; pruner: opt init with chain config; state/statedb: generate witness fron pendingReviveTrie; log: add some logs; --- core/blockchain.go | 3 ++- core/rawdb/database.go | 21 +++++++++++++++- core/state/errors.go | 49 +++++++++++++++++++------------------ core/state/pruner/pruner.go | 6 +++-- core/state/state_object.go | 6 +++-- core/state/statedb.go | 12 ++++++++- miner/worker.go | 4 +-- 7 files changed, 68 insertions(+), 33 deletions(-) diff --git a/core/blockchain.go b/core/blockchain.go index 75a964b7c6..929c971ff7 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1841,6 +1841,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) } for ; block != nil && err == nil || errors.Is(err, ErrKnownBlock); block, err = it.next() { + log.Info("Try import new chain segment", "number", block.NumberU64(), "hash", block.Hash(), "from", block.Coinbase()) // If the chain is terminating, 
stop processing blocks if bc.insertStopped() { log.Debug("Abort during block processing") @@ -1947,7 +1948,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) substart = time.Now() if !statedb.IsLightProcessed() { if err := bc.validator.ValidateState(block, statedb, receipts, usedGas); err != nil { - log.Error("validate state failed", "error", err) + log.Error("validate state failed", "number", block.NumberU64(), "hash", block.Hash(), "proposer", block.Coinbase(), "error", err) bc.reportBlock(block, receipts, err) statedb.StopPrefetcher() return it.index, err diff --git a/core/rawdb/database.go b/core/rawdb/database.go index f63112f313..ab056784a9 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -450,6 +450,12 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { chtTrieNodes stat bloomTrieNodes stat + // shadow nodes statistics + shadowNodeMetadataSize stat + shadowNodeHistorySize stat + shadowNodeChangeSetSize stat + shadowNodePlainStateSize stat + // Meta- and unaccounted data metadata stat unaccounted stat @@ -508,13 +514,22 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { bytes.HasPrefix(key, []byte("bltIndex-")) || bytes.HasPrefix(key, []byte("bltRoot-")): // Bloomtrie sub bloomTrieNodes.Add(size) + + case bytes.Equal(key, shadowNodePlainStateMeta): + shadowNodeMetadataSize.Add(size) + case bytes.HasPrefix(key, ShadowNodeHistoryPrefix) && len(key) >= (len(ShadowNodeHistoryPrefix)+common.HashLength+8): + shadowNodeHistorySize.Add(size) + case bytes.HasPrefix(key, ShadowNodeChangeSetPrefix) && len(key) == (len(ShadowNodeChangeSetPrefix)+common.HashLength+8): + shadowNodeChangeSetSize.Add(size) + case bytes.HasPrefix(key, ShadowNodePlainStatePrefix) && len(key) >= (len(ShadowNodePlainStatePrefix)+common.HashLength): + shadowNodePlainStateSize.Add(size) default: var accounted bool for _, meta := range [][]byte{ databaseVersionKey, headHeaderKey, 
headBlockKey, headFastBlockKey, lastPivotKey, fastTrieProgressKey, snapshotDisabledKey, SnapshotRootKey, snapshotJournalKey, snapshotGeneratorKey, snapshotRecoveryKey, txIndexTailKey, fastTxLookupLimitKey, - uncleanShutdownKey, badBlockKey, transitionStatusKey, + uncleanShutdownKey, badBlockKey, transitionStatusKey, LastSafePointBlockKey, pruneAncientKey, shadowNodeSnapshotJournalKey, } { if bytes.Equal(key, meta) { metadata.Add(size) @@ -571,6 +586,10 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, + {"Shadow Node", "Metadata", shadowNodeMetadataSize.Size(), shadowNodeMetadataSize.Count()}, + {"Shadow Node", "History", shadowNodeHistorySize.Size(), shadowNodeHistorySize.Count()}, + {"Shadow Node", "ChangeSet", shadowNodeChangeSetSize.Size(), shadowNodeChangeSetSize.Count()}, + {"Shadow Node", "PlainState", shadowNodePlainStateSize.Size(), shadowNodePlainStateSize.Count()}, } table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"Database", "Category", "Size", "Items"}) diff --git a/core/state/errors.go b/core/state/errors.go index 37820c0d9e..ebfd57c22f 100644 --- a/core/state/errors.go +++ b/core/state/errors.go @@ -10,47 +10,48 @@ import ( // ExpiredStateError Access State error, must revert the execution type ExpiredStateError struct { - Addr common.Address - Key common.Hash - Path []byte - Epoch types.StateEpoch - isInsert bool // when true it through expired path, must recovery the expired path - reason string + Addr common.Address + Key common.Hash + Path []byte + Epoch types.StateEpoch + reason string } func NewSnapExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { return &ExpiredStateError{ - Addr: addr, - Key: key, 
- Path: []byte{}, - Epoch: epoch, - isInsert: false, - reason: "snap query", + Addr: addr, + Key: key, + Path: []byte{}, + Epoch: epoch, + reason: "snap query", } } func NewExpiredStateError(addr common.Address, key common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { return &ExpiredStateError{ - Addr: addr, - Key: key, - Path: err.Path, - Epoch: err.Epoch, - isInsert: false, - reason: "query", + Addr: addr, + Key: key, + Path: err.Path, + Epoch: err.Epoch, + reason: "query", } } func NewInsertExpiredStateError(addr common.Address, key common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { return &ExpiredStateError{ - Addr: addr, - Key: key, - Path: err.Path, - Epoch: err.Epoch, - isInsert: true, - reason: "insert", + Addr: addr, + Key: key, + Path: err.Path, + Epoch: err.Epoch, + reason: "insert", } } func (e *ExpiredStateError) Error() string { return fmt.Sprintf("Access expired state, addr: %v, path: %v, key: %v, epoch: %v, reason: %v", e.Addr, e.Path, e.Key, e.Epoch, e.reason) } + +func (e *ExpiredStateError) Reason(r string) *ExpiredStateError { + e.reason = r + return e +} diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 87f145f377..eba9898a06 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -105,7 +105,8 @@ func NewPruner(db ethdb.Database, datadir, trieCachePath string, bloomSize, trie if headBlock == nil { return nil, errors.New("Failed to load head block") } - chainConfig := rawdb.ReadChainConfig(db, headBlock.Hash()) + stored := rawdb.ReadCanonicalHash(db, 0) + chainConfig := rawdb.ReadChainConfig(db, stored) if chainConfig == nil { return nil, errors.New("cannot find chainConfig") } @@ -705,7 +706,8 @@ func RecoverPruning(datadir string, db ethdb.Database, trieCachePath string, tri if headBlock == nil { return errors.New("Failed to load head block") } - chainConfig := rawdb.ReadChainConfig(db, headBlock.Hash()) + stored := rawdb.ReadCanonicalHash(db, 0) + chainConfig := 
rawdb.ReadChainConfig(db, stored) if chainConfig == nil { return errors.New("cannot find chainConfig") } diff --git a/core/state/state_object.go b/core/state/state_object.go index e7da13328f..025b7a2303 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -202,6 +202,7 @@ func (s *StateObject) getTrie(db Database) Trie { var err error // check if enable state epoch if s.db.enableAccStateEpoch(false, s.address) { + log.Debug("Open StorageTrie with shadow nodes", "addr", s.address, "targetEpoch", s.targetEpoch) s.trie, err = db.OpenStorageTrieWithShadowNode(s.addrHash, s.data.Root, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) if err != nil { log.Error("OpenStorageTrieWithShadowNode err", "targetEpoch", s.targetEpoch, "err", err) @@ -211,6 +212,7 @@ func (s *StateObject) getTrie(db Database) Trie { return s.trie } + log.Debug("Open StorageTrie normal", "addr", s.address, "targetEpoch", s.targetEpoch, "addr", s.address) s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) if err != nil { s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) @@ -333,7 +335,7 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Ha // query from dirty revive trie, got the newest expired info _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) if enErr, ok := err.(*trie.ExpiredNodeError); ok { - return common.Hash{}, NewExpiredStateError(s.address, key, enErr) + return common.Hash{}, NewExpiredStateError(s.address, key, enErr).Reason("snap query") } return common.Hash{}, NewSnapExpiredStateError(s.address, key, sv.Epoch) } @@ -388,7 +390,7 @@ func (s *StateObject) SetState(db Database, key, value common.Hash) error { // If the new value is the same as old, don't set prev, err := s.GetState(db, key) if exErr, ok := err.(*ExpiredStateError); ok { - exErr.isInsert = true + exErr.Reason("query from insert") return exErr } if err != nil { diff --git a/core/state/statedb.go b/core/state/statedb.go index ae4cd360be..3a2fb772c0 
100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -526,7 +526,7 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { // GetStorageWitness returns only the Merkle proof for given storage slot. func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, key common.Hash) ([][]byte, error) { var proof proofList - trie := s.StorageTrie(a) + trie := s.StorageReviveTrie(a) if trie == nil { return proof, errors.New("storage trie for requested address does not exist") } @@ -571,6 +571,16 @@ func (s *StateDB) StorageTrie(addr common.Address) Trie { return cpy.getTrie(s.db) } +func (s *StateDB) StorageReviveTrie(addr common.Address) Trie { + stateObject := s.getStateObject(addr) + if stateObject == nil { + return nil + } + cpy := stateObject.deepCopy(s) + cpy.updateTrie(s.db) + return cpy.getPendingReviveTrie(s.db) +} + func (s *StateDB) HasSuicided(addr common.Address) bool { stateObject := s.getStateObject(addr) if stateObject != nil { diff --git a/miner/worker.go b/miner/worker.go index e8a67514e6..d6f51aedb7 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -679,8 +679,8 @@ func (w *worker) resultLoop() { continue } writeBlockTimer.UpdateSince(start) - log.Info("Successfully sealed new block", "number", block.Number(), "sealhash", sealhash, "hash", hash, - "elapsed", common.PrettyDuration(time.Since(task.createdAt))) + log.Info("Successfully sealed new block", "number", block.Number(), "hash", hash, "from", block.Coinbase(), + "sealhash", sealhash, "elapsed", common.PrettyDuration(time.Since(task.createdAt))) // Broadcast the block and announce chain insertion event w.mux.Post(core.NewMinedBlockEvent{Block: block}) From 17d93b371f36efc61fce72df551895a83fc50263 Mon Sep 17 00:00:00 2001 From: 0xbundler <124862913+0xbundler@users.noreply.github.com> Date: Mon, 29 May 2023 15:09:13 +0800 Subject: [PATCH 49/51] trie/trie: opt trie update epoch, using copy to avoid node sharing; --- trie/trie.go | 101 
++++++++++++++++++++++++--------------------------- 1 file changed, 48 insertions(+), 53 deletions(-) diff --git a/trie/trie.go b/trie/trie.go index 36c8c0711e..1c7e71969f 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -244,35 +244,30 @@ func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.S return nil, n, false, nil } - if updateEpoch { - n.setEpoch(t.currentEpoch) - value, newnode, didResolve, err = t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), t.currentEpoch, true) - } else { - value, newnode, didResolve, err = t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), epoch, false) - } - if err == nil && didResolve { + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), t.currentEpoch, updateEpoch) + if err == nil && t.renewNode(epoch, didResolve, updateEpoch) { n = n.copy() n.Val = newnode + if updateEpoch { + n.setEpoch(t.currentEpoch) + } + didResolve = true } return value, n, didResolve, err case *fullNode: - if updateEpoch { - // child node is expired - if n.Children[key[pos]] != nil { - if expired, err := n.ChildExpired(key[:pos+1], int(key[pos]), t.currentEpoch); expired { - return nil, n, false, err - } - } - n.setEpoch(t.currentEpoch) - n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) - value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, t.currentEpoch, true) - } else { - value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])), false) - } - if err == nil && didResolve { + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])), updateEpoch) + if err == nil && t.renewNode(epoch, didResolve, updateEpoch) { n = n.copy() n.Children[key[pos]] = newnode + if updateEpoch { + n.setEpoch(t.currentEpoch) + } + if updateEpoch && newnode != nil { + n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) + } + didResolve = true } + return value, n, didResolve, err case 
hashNode: child, err := t.resolveHash(n, key[:pos]) @@ -432,9 +427,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - n.setEpoch(t.currentEpoch) dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value, n.epoch) - if !dirty || err != nil { + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag(), epoch: t.currentEpoch}, nil @@ -446,16 +440,13 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE if err != nil { return false, nil, err } - if t.withShadowNodes { - branch.setEpoch(t.currentEpoch) - branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) - } _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value, t.currentEpoch) if err != nil { return false, nil, err } if t.withShadowNodes { branch.setEpoch(t.currentEpoch) + branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) } // Replace this shortNode with the branch if it occurs at index 0. @@ -466,26 +457,17 @@ func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateE return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil case *fullNode: - if t.withShadowNodes { - // else, set its epoch to current epoch. - n.setEpoch(t.currentEpoch) - // if inserting a new node to this full node, there is no need to check whether this child is expired. 
- if len(key) > 0 && n.Children[key[0]] != nil { - // if child is expired, return err - if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { - return false, n.Children[key[0]], err - } - } - // else, set child node's epoch to current epoch - n.UpdateChildEpoch(int(key[0]), t.currentEpoch) - } dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value, n.GetChildEpoch(int(key[0]))) - if !dirty || err != nil { + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn + if t.withShadowNodes { + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } return true, n, nil case nil: @@ -556,7 +538,7 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, // subtrie must contain at least two other values with keys // longer than n.Key. dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):], n.epoch) - if !dirty || err != nil { + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } switch child := child.(type) { @@ -573,23 +555,19 @@ func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, } case *fullNode: - if t.withShadowNodes { - // else, set its epoch to current epoch. 
- n.setEpoch(t.currentEpoch) - // if child is expired, return err - if expired, err := n.ChildExpired(append(prefix, key[0]), int(key[0]), t.currentEpoch); expired { - return false, n.Children[key[0]], err - } - // else, set child node's epoch to current epoch - n.UpdateChildEpoch(int(key[0]), t.currentEpoch) - } dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:], n.GetChildEpoch(int(key[0]))) - if !dirty || err != nil { + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn + if t.withShadowNodes { + n.setEpoch(t.currentEpoch) + } + if t.withShadowNodes && nn != nil { + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } // Because n is a full node, it must've contained at least two children // before the delete operation. If the new child value is non-nil, n still @@ -1200,6 +1178,23 @@ func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool { return types.EpochExpired(epoch, t.currentEpoch) } +// renewNode check if renew node, according to trie node epoch and childDirty, +// childDirty or updateEpoch need copy for prevent reuse trie cache +func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bool) bool { + // when !updateEpoch, it same as !t.withShadowNodes + if !t.withShadowNodes || !updateEpoch { + return childDirty + } + + // when no epoch update, same as before + if epoch == t.currentEpoch { + return childDirty + } + + // node need update epoch, just renew + return true +} + func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { expectHash := common.BytesToHash(root[:]) val, err := sndb.Get(ShadowTreeRootNodePath) From 6fb0a68691aac59fa47aba3303cf2969215e9b2c Mon Sep 17 00:00:00 2001 From: asyukii Date: Thu, 25 May 2023 20:04:28 +0800 Subject: [PATCH 50/51] refactor(trie): rootNode implements node interface --- trie/committer.go | 14 ++++++++++++++ trie/database.go | 6 ++++-- trie/node.go 
| 4 ++++ trie/secure_trie.go | 2 +- trie/shadow_node.go | 20 ++++++++++++++++++++ trie/trie.go | 27 +++++++++++++++++++++++++-- trie/trie_test.go | 15 ++++++++------- 7 files changed, 76 insertions(+), 12 deletions(-) diff --git a/trie/committer.go b/trie/committer.go index 5d9a27a503..553d94d511 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -119,6 +119,18 @@ func (c *committer) commit(n node, db *Database) (node, int, error) { return hn, childCommitted + 1, nil } return collapsed, childCommitted, nil + case *rootNode: + log.Info("Committing root node") + hash, _ := cn.cache() + log.Info("Root node cache", "hash", hash) + collapsed := cn.copy() + hashedNode := c.store(collapsed, db) + if hn, ok := hashedNode.(hashNode); ok { + return hn, 1, nil + } else { + log.Info("Root node is not a hash node.") + } + return collapsed, 0, nil case hashNode: return cn, 0, nil default: @@ -250,6 +262,8 @@ func estimateSize(n node) int { return 1 + len(n) case hashNode: return 1 + len(n) + case *rootNode: + return 1 + len(n.cachedHash) + len(n.TrieRoot) + len(n.ShadowTreeRoot) default: panic(fmt.Sprintf("node type %T", n)) } diff --git a/trie/database.go b/trie/database.go index c27d9a2fec..650b8ad9b7 100644 --- a/trie/database.go +++ b/trie/database.go @@ -229,6 +229,8 @@ func forGatherChildren(n node, onChild func(hash common.Hash)) { } case hashNode: onChild(common.BytesToHash(n)) + case *rootNode: + onChild(n.TrieRoot) case valueNode, nil, rawNode: default: panic(fmt.Sprintf("unknown node type: %T", n)) @@ -256,7 +258,7 @@ func simplifyNode(n node) node { } return node - case valueNode, hashNode, rawNode: + case valueNode, hashNode, rawNode, *rootNode: return n default: @@ -292,7 +294,7 @@ func expandNode(hash hashNode, n node) node { } return node - case valueNode, hashNode: + case valueNode, hashNode, *rootNode: return n default: diff --git a/trie/node.go b/trie/node.go index e1849ea246..a9ee38835f 100644 --- a/trie/node.go +++ b/trie/node.go @@ -38,6 +38,7 @@ 
const ( rawNodeType rawShortNodeType rawFullNodeType + rootNodeType ) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} @@ -218,6 +219,9 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { case 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") + case 3: + n, err := decodeRootNode(buf) + return n, wrapError(err, "root") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index 8272ae0534..99bdc9ea31 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -70,7 +70,7 @@ func NewSecureWithShadowNodes(curEpoch types.StateEpoch, root common.Hash, db *D panic("trie.NewSecure called without a database") } - rn, err := resolveRootNode(sndb, root) + rn, err := resolveRootNodeTrieDb(db, root) if err != nil { return nil, err } diff --git a/trie/shadow_node.go b/trie/shadow_node.go index ed868b19e8..802eb33795 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "errors" + "fmt" "math/big" "sync" @@ -41,9 +42,28 @@ func newRootNode(epoch types.StateEpoch, trieRoot, shadowTreeRoot common.Hash) * n.resolveCache() return n } + +func (n *rootNode) copy() *rootNode { copy := *n; return © } + func (n *rootNode) encode(w rlp.EncoderBuffer) { rlp.Encode(w, n) } + +func (n *rootNode) cache() (hashNode, bool) { return n.cachedHash[:], true } + +func (n *rootNode) setEpoch(epoch types.StateEpoch) { n.Epoch = epoch } +func (n *rootNode) getEpoch() types.StateEpoch { return n.Epoch } + +func (n *rootNode) String() string { return n.fstring("") } + +func (n *rootNode) fstring(ind string) string { + return fmt.Sprintf("rootNode{epoch:%d, trieRoot:%s, shadowTreeRoot:%s}", n.Epoch, n.TrieRoot, n.ShadowTreeRoot) +} + +func (n *rootNode) nodeType() int { + return rootNodeType +} + func (n *rootNode) resolveCache() { buf := rlp.NewEncoderBuffer(nil) n.encode(buf) diff --git 
a/trie/trie.go b/trie/trie.go index 1c7e71969f..bcd0d19958 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -223,6 +223,8 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } value, newnode, _, err := t.tryGet(child, key, pos) return value, newnode, true, err + case *rootNode: + return nil, n, false, nil // TODO(asyukii): temporary fix default: panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } @@ -750,6 +752,7 @@ func (t *Trie) Hash() common.Hash { hash, cached, _ := t.hashRoot() t.root = cached newRootHash := common.BytesToHash(hash.(hashNode)) + t.trieRoot = newRootHash if t.withShadowNodes { newShadowTreeRoot := emptyRoot shadowTreeRoot, err := t.ShadowHash() @@ -804,7 +807,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { // values, but don't write to it. if _, dirty := t.root.cache(); !dirty { if t.withShadowNodes { - rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) + rootNodeHash, err := t.storeRootNodeTrieDb(h, newRootHash, newShadowTreeRoot) if err != nil { return common.Hash{}, 0, err } @@ -835,7 +838,7 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { return common.Hash{}, 0, err } if t.withShadowNodes { - rootNodeHash, err := t.storeRootNode(newRootHash, newShadowTreeRoot) + rootNodeHash, err := t.storeRootNodeTrieDb(h, newRootHash, newShadowTreeRoot) if err != nil { return common.Hash{}, 0, err } @@ -1148,6 +1151,12 @@ func (t *Trie) commitShadowNodes(origin node, prefix []byte, epoch types.StateEp } } +func (t *Trie) storeRootNodeTrieDb(c *committer, newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) + hn, _, err := c.Commit(rn, t.db) + return common.BytesToHash(hn), err +} + func (t *Trie) storeRootNode(newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) if err := 
t.sndb.Put(ShadowTreeRootNodePath, rn.cachedEnc); err != nil { @@ -1195,6 +1204,20 @@ func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bo return true } +func resolveRootNodeTrieDb(db *Database, root common.Hash) (*rootNode, error) { + expectHash := common.BytesToHash(root[:]) + rn := db.node(root) + n, ok := rn.(*rootNode) + if !ok { + return newEpoch0RootNode(root), nil + } + + if n.cachedHash != expectHash { + return nil, errors.New("found the wrong rootNode") + } + return n, nil +} + func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { expectHash := common.BytesToHash(root[:]) val, err := sndb.Get(ShadowTreeRootNodePath) diff --git a/trie/trie_test.go b/trie/trie_test.go index dd0ad9b3f2..656cae81a9 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -617,7 +617,7 @@ func TestTryRevive(t *testing.T) { assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) // Reset trie - trie, _ = nonRandomTrie(500) + trie, _ = nonRandomTrieWithShadowNodes(500) } } } @@ -664,7 +664,7 @@ func TestTryReviveCustomData(t *testing.T) { // Verify root hash currRootHash := trie.Hash() - assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash, key, prefixKey) + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x, key %x, prefixKey %x", currRootHash, oriRootHash, key, prefixKey) // Reset trie trie = createCustomTrie(data, 10) @@ -995,11 +995,12 @@ func TestTrie_ShadowHash_case2(t *testing.T) { assert.NoError(t, err) sndb := storageDB.OpenStorage(contract1) - enc, err := sndb.Get(ShadowTreeRootNodePath) - assert.NoError(t, err) - r1, err := decodeRootNode(enc) - assert.NoError(t, err) - tr, err = NewWithShadowNode(11, r1, database, sndb) + rn := tr.db.node(newRoot) + // enc, err := sndb.Get(ShadowTreeRootNodePath) + // assert.NoError(t, err) + // r1, err := decodeRootNode(enc) + 
// assert.NoError(t, err) + tr, err = NewWithShadowNode(11, rn.(*rootNode), database, sndb) assert.NoError(t, err) _, err = tr.TryGet(common.Hex2Bytes("223dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) assert.NoError(t, err) From b9d8408a7b3bd694b14a51442ee64ae9d6e0ea49 Mon Sep 17 00:00:00 2001 From: asyukii Date: Fri, 26 May 2023 14:44:59 +0800 Subject: [PATCH 51/51] fix(prune): storage trie now can be inline pruned fix: temporarily do not prune account trie node fix: estimate size for root node --- trie/committer.go | 7 +------ trie/database.go | 26 +++++++++++++++++++++++++- trie/node.go | 2 +- trie/shadow_node.go | 2 +- trie/shadow_node_test.go | 2 +- trie/trie.go | 2 +- 6 files changed, 30 insertions(+), 11 deletions(-) diff --git a/trie/committer.go b/trie/committer.go index 553d94d511..a7df16d438 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -120,15 +120,10 @@ func (c *committer) commit(n node, db *Database) (node, int, error) { } return collapsed, childCommitted, nil case *rootNode: - log.Info("Committing root node") - hash, _ := cn.cache() - log.Info("Root node cache", "hash", hash) collapsed := cn.copy() hashedNode := c.store(collapsed, db) if hn, ok := hashedNode.(hashNode); ok { return hn, 1, nil - } else { - log.Info("Root node is not a hash node.") } return collapsed, 0, nil case hashNode: @@ -263,7 +258,7 @@ func estimateSize(n node) int { case hashNode: return 1 + len(n) case *rootNode: - return 1 + len(n.cachedHash) + len(n.TrieRoot) + len(n.ShadowTreeRoot) + return 5 + len(n.TrieRoot) + len(n.ShadowTreeRoot) default: panic(fmt.Sprintf("node type %T", n)) } diff --git a/trie/database.go b/trie/database.go index 650b8ad9b7..6c07654683 100644 --- a/trie/database.go +++ b/trie/database.go @@ -212,11 +212,26 @@ func (n *cachedNode) getEpoch() (types.StateEpoch, error) { return n.epoch, nil case *rawFullNode: return n.epoch, nil + case *rootNode: + return n.Epoch, nil default: return 0, fmt.Errorf("unknown node type: %T", 
n) // TODO(asyukii): may never reach this case, consider panic } } +func (n *cachedNode) updateEpoch(epoch types.StateEpoch) { + switch n := n.node.(type) { + case *rawShortNode: + n.epoch = epoch + case *rawFullNode: + n.epoch = epoch + case *rootNode: + n.Epoch = epoch + default: + return + } +} + // forGatherChildren traverses the node hierarchy of a collapsed storage node and // invokes the callback for all the hashnode children. func forGatherChildren(n node, onChild func(hash common.Hash)) { @@ -357,6 +372,15 @@ func (db *Database) insert(hash common.Hash, size int, node node) { // If the node's already cached, skip if _, ok := db.dirties[hash]; ok { + // update the epoch + switch n := node.(type) { + case *shortNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + case *fullNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + case *rootNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + } return } memcacheDirtyWriteMeter.Mark(int64(size)) @@ -611,7 +635,7 @@ func (db *Database) dereference(child common.Hash, parent common.Hash, epoch typ } childEpoch, getChildEpochErr := node.getEpoch() canPruneExpired := false - if getParentEpochErr == nil && getChildEpochErr == nil { + if childEpoch != 0 && getParentEpochErr == nil && getChildEpochErr == nil { // TODO(asyukii): temporary fix, will not prune epoch 0 nodes because of account trie nodes canPruneExpired = checkBEP206PruneRule(childEpoch, parentEpoch, epoch) } if canPruneExpired || node.parents == 0 { // Delete nodes if expired or no more parents node referencing this node diff --git a/trie/node.go b/trie/node.go index a9ee38835f..fdf08be1ec 100644 --- a/trie/node.go +++ b/trie/node.go @@ -220,7 +220,7 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { n, err := decodeFull(hash, elems) return n, wrapError(err, "full") case 3: - n, err := decodeRootNode(buf) + n, err := DecodeRootNode(buf) return n, wrapError(err, "root") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) diff 
--git a/trie/shadow_node.go b/trie/shadow_node.go index 802eb33795..67c3b8baa9 100644 --- a/trie/shadow_node.go +++ b/trie/shadow_node.go @@ -77,7 +77,7 @@ func (n *rootNode) resolveCache() { returnHasherToPool(h) } -func decodeRootNode(enc []byte) (*rootNode, error) { +func DecodeRootNode(enc []byte) (*rootNode, error) { n := &rootNode{} if err := rlp.DecodeBytes(enc, n); err != nil { return nil, err diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go index c307a72a2d..667cc291ee 100644 --- a/trie/shadow_node_test.go +++ b/trie/shadow_node_test.go @@ -234,7 +234,7 @@ func TestRootNode_encodeDecode(t *testing.T) { item.n.encode(buf) enc := buf.ToBytes() - rn, err := decodeRootNode(enc) + rn, err := DecodeRootNode(enc) assert.NoError(t, err) if !item.isEqual { assert.NotEqual(t, item.n, rn) diff --git a/trie/trie.go b/trie/trie.go index bcd0d19958..5614eed7f2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -1227,7 +1227,7 @@ func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error if len(val) == 0 { return newEpoch0RootNode(root), nil } - n, err := decodeRootNode(val) + n, err := DecodeRootNode(val) if err != nil { return nil, err }