diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 43d2de9356..6fa80373cf 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: unit-test: diff --git a/.github/workflows/commit-lint.yml b/.github/workflows/commit-lint.yml index 7df13ebcac..0680a52f79 100644 --- a/.github/workflows/commit-lint.yml +++ b/.github/workflows/commit-lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: commitlint: diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 234cbfda5b..fb66659f44 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: truffle-test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ebe646dfe6..476665a660 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: golang-lint: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index d692c2f0d8..809bc12f82 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -5,11 +5,13 @@ on: branches: - master - develop + - state_expiry_develop pull_request: branches: - master - develop + - state_expiry_develop jobs: unit-test: diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index f8ceec8838..0d195851bb 100644 --- a/accounts/abi/bind/backends/simulated.go +++ 
b/accounts/abi/bind/backends/simulated.go @@ -42,6 +42,7 @@ import ( "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" ) @@ -130,7 +131,8 @@ func (b *SimulatedBackend) rollback(parent *types.Block) { blocks, _ := core.GenerateChain(b.config, parent, ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {}) b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), b.blockchain.StateCache(), nil) + blockNum := new(big.Int).Add(parent.Number(), common.Big1) + b.pendingState, _ = state.NewWithStateEpoch(b.config, blockNum, b.pendingBlock.Root(), b.blockchain.StateCache(), nil, b.blockchain.ShadowNodeTree()) } // Fork creates a side-chain that can be used to simulate reorgs. @@ -169,7 +171,7 @@ func (b *SimulatedBackend) stateByBlockNumber(ctx context.Context, blockNumber * if err != nil { return nil, err } - return b.blockchain.StateAt(block.Root()) + return b.blockchain.StateAt(block.Root(), block.Number()) } // CodeAt returns the code associated with a certain account in the blockchain. @@ -221,7 +223,10 @@ func (b *SimulatedBackend) StorageAt(ctx context.Context, contract common.Addres return nil, err } - val := stateDB.GetState(contract, key) + val, err := stateDB.GetState(contract, key) + if err != nil { + return nil, err + } return val[:], nil } @@ -585,6 +590,178 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMs return hi, nil } +// EstimateGasAndReviveState executes the requested code against the currently pending block/state and +// returns the used amount of gas. If the execution fails due to expired state, it will revive the state +// and try again. 
+func (b *SimulatedBackend) EstimateGasAndReviveState(ctx context.Context, call ethereum.CallMsg) (uint64, []types.ReviveWitness, error) { + b.mu.Lock() + defer b.mu.Unlock() + + // Initialize witnessList + var witnessList []types.ReviveWitness + if call.WitnessList != nil { + witnessList = call.WitnessList + } + witLen := len(witnessList) + + // Determine the lowest and highest possible gas limits to binary search in between + var ( + lo uint64 = params.TxGas - 1 + hi uint64 + cap uint64 + ) + if call.Gas >= params.TxGas { + hi = call.Gas + } else { + hi = b.pendingBlock.GasLimit() + } + // Normalize the max fee per gas the call is willing to spend. + var feeCap *big.Int + if call.GasPrice != nil && (call.GasFeeCap != nil || call.GasTipCap != nil) { + return 0, witnessList, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } else if call.GasPrice != nil { + feeCap = call.GasPrice + } else if call.GasFeeCap != nil { + feeCap = call.GasFeeCap + } else { + feeCap = common.Big0 + } + // Recap the highest gas allowance with account's balance. 
+ if feeCap.BitLen() != 0 { + balance := b.pendingState.GetBalance(call.From) // from can't be nil + available := new(big.Int).Set(balance) + if call.Value != nil { + if call.Value.Cmp(available) >= 0 { + return 0, witnessList, errors.New("insufficient funds for transfer") + } + available.Sub(available, call.Value) + } + allowance := new(big.Int).Div(available, feeCap) + if allowance.IsUint64() && hi > allowance.Uint64() { + transfer := call.Value + if transfer == nil { + transfer = new(big.Int) + } + log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance, + "sent", transfer, "feecap", feeCap, "fundable", allowance) + hi = allowance.Uint64() + } + } + cap = hi + + // Create a helper to check if a gas allowance results in an executable transaction + executable := func(gas uint64, witnessList []types.ReviveWitness) (bool, *core.ExecutionResult, bool, error) { + call.Gas = gas + + snapshot := b.pendingState.Snapshot() + res, evmErrors, err := b.callContractExpired(ctx, call, b.pendingBlock, b.pendingState) + + // Create MPTProof + isExpiredError := false + if len(evmErrors) > 0 { + addressToProofMap := make(map[common.Address][]types.MPTProof) + for _, evmErr := range evmErrors { + if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { + isExpiredError = true + proof, err := b.pendingState.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) // TODO (asyukii): Is this key correct? 
+ if proof == nil { + continue + } + if err != nil { + return true, nil, isExpiredError, err + } + if len(proof) == 0 { + continue + } + addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], types.MPTProof{ + RootKeyHex: stateErr.Path, + Proof: proof, + }) + } + } + + // TODO (asyukii): existing witnessList might already has ReviveWitness with the same address + // Might want to consider decoding it and merge them together + + // Create a ReviveWitness object for each address and add it to witnessList + for addr, proofs := range addressToProofMap { + // Build a storageTrieWitness + storageTrieWitness := types.StorageTrieWitness{ + Address: addr, + ProofList: proofs, + } + // Encode StorageTrieWitness + enc, err := rlp.EncodeToBytes(storageTrieWitness) + if err != nil { + return true, nil, isExpiredError, err + } + // Create a ReviveWitness + reviveWitness := types.ReviveWitness{ + WitnessType: types.StorageTrieWitnessType, + Data: enc, + } + // Append to witness list + witnessList = append(witnessList, reviveWitness) + } + } + + b.pendingState.RevertToSnapshot(snapshot) + + if err != nil { + if _, ok := err.(*state.ExpiredStateError); ok || errors.Is(err, core.ErrIntrinsicGas) { + return true, nil, isExpiredError, nil // Special case, raise gas limit + } + return true, nil, isExpiredError, err // Bail out + } + return res.Failed(), res, isExpiredError, nil + } + // Execute the binary search and hone in on an executable gas limit + for lo+1 < hi { + mid := (hi + lo) / 2 + failed, _, isExpiredError, err := executable(mid, witnessList) + + if isExpiredError { + if witLen == len(witnessList) { + // If witnessList is not updated, it means that the proofs are not + // sufficient to revive the state. 
+ return 0, witnessList, fmt.Errorf("cannot generate enough proofs to revive the state") + } + witLen = len(witnessList) + continue + } + + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. Return the error directly, don't struggle any more + if err != nil { + return 0, witnessList, err + } + if failed { + lo = mid + } else { + hi = mid + } + } + // Reject the transaction as invalid if it still fails at the highest allowance + if hi == cap { + failed, result, _, err := executable(hi, witnessList) + if err != nil { + return 0, witnessList, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + if len(result.Revert()) > 0 { + return 0, witnessList, newRevertError(result) + } + return 0, witnessList, result.Err + } + // Otherwise, the specified gas cap is too low + return 0, witnessList, fmt.Errorf("gas required exceeds allowance (%d)", cap) + } + } + return hi, witnessList, nil +} + // callContract implements common code between normal and pending contract calls. // state is modified during execution, make sure to copy it if necessary. func (b *SimulatedBackend) callContract(ctx context.Context, call ethereum.CallMsg, block *types.Block, stateDB *state.StateDB) (*core.ExecutionResult, error) { @@ -642,6 +819,64 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call ethereum.CallM return core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb() } +// callContractExpired implements common code between normal and pending contract calls. +// state is modified during execution, make sure to copy it if necessary. +// It will also return the EVM errors if any. 
+func (b *SimulatedBackend) callContractExpired(ctx context.Context, call ethereum.CallMsg, block *types.Block, stateDB *state.StateDB) (*core.ExecutionResult, []*vm.EVMError, error) { + // Gas prices post 1559 need to be initialized + if call.GasPrice != nil && (call.GasFeeCap != nil || call.GasTipCap != nil) { + return nil, nil, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } + head := b.blockchain.CurrentHeader() + if !b.blockchain.Config().IsLondon(head.Number) { + // If there's no basefee, then it must be a non-1559 execution + if call.GasPrice == nil { + call.GasPrice = new(big.Int) + } + call.GasFeeCap, call.GasTipCap = call.GasPrice, call.GasPrice + } else { + // A basefee is provided, necessitating 1559-type execution + if call.GasPrice != nil { + // User specified the legacy gas field, convert to 1559 gas typing + call.GasFeeCap, call.GasTipCap = call.GasPrice, call.GasPrice + } else { + // User specified 1559 gas feilds (or none), use those + if call.GasFeeCap == nil { + call.GasFeeCap = new(big.Int) + } + if call.GasTipCap == nil { + call.GasTipCap = new(big.Int) + } + // Backfill the legacy gasPrice for EVM execution, unless we're all zeroes + call.GasPrice = new(big.Int) + if call.GasFeeCap.BitLen() > 0 || call.GasTipCap.BitLen() > 0 { + call.GasPrice = math.BigMin(new(big.Int).Add(call.GasTipCap, head.BaseFee), call.GasFeeCap) + } + } + } + // Ensure message is initialized properly. + if call.Gas == 0 { + call.Gas = 50000000 + } + if call.Value == nil { + call.Value = new(big.Int) + } + // Set infinite balance to the fake caller account. + from := stateDB.GetOrNewStateObject(call.From) + from.SetBalance(math.MaxBig256) + // Execute the call. + msg := callMsg{call} + + txContext := core.NewEVMTxContext(msg) + evmContext := core.NewEVMBlockContext(block.Header(), b.blockchain, nil) + // Create a new environment which holds all relevant information + // about the transaction and calling mechanisms. 
+ vmEnv := vm.NewEVM(evmContext, txContext, stateDB, b.config, vm.Config{NoBaseFee: true}) + gasPool := new(core.GasPool).AddGas(math.MaxUint64) + res, err := core.NewStateTransition(vmEnv, msg, gasPool).TransitionDb() + return res, vmEnv.ErrorCollection, err +} + // SendTransaction updates the pending block to include the given transaction. func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transaction) error { b.mu.Lock() @@ -672,7 +907,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil) + b.pendingState, _ = state.NewWithStateEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) b.pendingReceipts = receipts[0] return nil } @@ -788,8 +1023,7 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error { stateDB, _ := b.blockchain.State() b.pendingBlock = blocks[0] - b.pendingState, _ = state.New(b.pendingBlock.Root(), stateDB.Database(), nil) - + b.pendingState, _ = state.NewWithStateEpoch(b.config, b.pendingBlock.Number(), b.pendingBlock.Root(), stateDB.Database(), nil, b.blockchain.ShadowNodeTree()) return nil } @@ -803,17 +1037,18 @@ type callMsg struct { ethereum.CallMsg } -func (m callMsg) From() common.Address { return m.CallMsg.From } -func (m callMsg) Nonce() uint64 { return 0 } -func (m callMsg) IsFake() bool { return true } -func (m callMsg) To() *common.Address { return m.CallMsg.To } -func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callMsg) GasFeeCap() *big.Int { return m.CallMsg.GasFeeCap } -func (m callMsg) GasTipCap() *big.Int { return m.CallMsg.GasTipCap } -func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callMsg) Value() *big.Int { return m.CallMsg.Value } -func (m callMsg) Data() []byte { return m.CallMsg.Data } -func (m callMsg) 
AccessList() types.AccessList { return m.CallMsg.AccessList } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) IsFake() bool { return true } +func (m callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) GasFeeCap() *big.Int { return m.CallMsg.GasFeeCap } +func (m callMsg) GasTipCap() *big.Int { return m.CallMsg.GasTipCap } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } +func (m callMsg) WitnessList() types.WitnessList { return m.CallMsg.WitnessList } // filterBackend implements filters.Backend to support filtering for logs without // taking bloom-bits acceleration structures into account. diff --git a/accounts/abi/bind/backends/simulated_test.go b/accounts/abi/bind/backends/simulated_test.go index f857c399f7..97ed6e8061 100644 --- a/accounts/abi/bind/backends/simulated_test.go +++ b/accounts/abi/bind/backends/simulated_test.go @@ -531,6 +531,100 @@ func TestEstimateGas(t *testing.T) { } } +// TODO (asyukii) +// Need some smart way to modify state such that they can be expired +// Add more tests: +// [] - Test for gas estimation of fully expired contract +// [] - Test for gas estimation of contract with partially expired storage +func TestEstimateGasAndReviveState(t *testing.T) { + /* + pragma solidity ^0.8.0; + contract HardcodedStorage { + uint256 private _value1 = 10; + uint256 private _value2 = 20; + uint256 private _value3 = 30; + uint256 private _value5 = 50; + uint256 private _value6 = 60; + uint256 private _value7 = 70; + uint256 private _value8 = 80; + + function getValue1() public view returns (uint256) { + return _value1; + } + } + */ + const contractAbi = 
"[{\"inputs\":[],\"name\":\"getValue1\",\"outputs\":[{\"internalType\":\"uint256\",\"name\":\"\",\"type\":\"uint256\"}],\"stateMutability\":\"view\",\"type\":\"function\"}]" + const contractBin = "0x6080604052600a6000556014600155601e6002556032600355603c6004556046600555605060065534801561003357600080fd5b5060b6806100426000396000f3fe6080604052348015600f57600080fd5b506004361060285760003560e01c806360d586f814602d575b600080fd5b60336047565b604051603e91906067565b60405180910390f35b60008054905090565b6000819050919050565b6061816050565b82525050565b6000602082019050607a6000830184605a565b9291505056fea26469706673582212208ea04f86f5d169c7e3fde0277a16120a9cc6db59035dee70356b3b1e615fe6a664736f6c63430008120033" + + key, _ := crypto.GenerateKey() + addr := crypto.PubkeyToAddress(key.PublicKey) + opts, _ := bind.NewKeyedTransactorWithChainID(key, big.NewInt(1337)) + + sim := NewSimulatedBackend(core.GenesisAlloc{addr: {Balance: big.NewInt(params.Ether)}}, 10000000) + defer sim.Close() + + parsed, _ := abi.JSON(strings.NewReader(contractAbi)) + contractAddr, _, _, _ := bind.DeployContract(opts, parsed, common.FromHex(contractBin), sim) + sim.Commit() + + var cases = []struct { + name string + message ethereum.CallMsg + expect uint64 + expectError error + expectData interface{} + }{ + {"plain transfer(valid)", ethereum.CallMsg{ + From: addr, + To: &addr, + Gas: 0, + GasPrice: big.NewInt(0), + Value: big.NewInt(1), + Data: nil, + }, params.TxGas, nil, nil}, + + {"plain transfer(invalid)", ethereum.CallMsg{ + From: addr, + To: &contractAddr, + Gas: 0, + GasPrice: big.NewInt(0), + Value: big.NewInt(1), + Data: nil, + }, 0, errors.New("execution reverted"), nil}, + {"call getValue1(valid)", ethereum.CallMsg{ + From: addr, + To: &contractAddr, + Gas: 0, + GasPrice: big.NewInt(100000000000), + Value: big.NewInt(0), + Data: common.Hex2Bytes("60d586f8"), + }, 23479, nil, nil}, + } + + for _, c := range cases { + got, _, err := sim.EstimateGasAndReviveState(context.Background(), c.message) + if 
c.expectError != nil { + if err == nil { + t.Fatalf("Expect error, got nil") + } + if c.expectError.Error() != err.Error() { + t.Fatalf("Expect error, want %v, got %v", c.expectError, err) + } + if c.expectData != nil { + if err, ok := err.(*revertError); !ok { + t.Fatalf("Expect revert error, got %T", err) + } else if !reflect.DeepEqual(err.ErrorData(), c.expectData) { + t.Fatalf("Error data mismatch, want %v, got %v", c.expectData, err.ErrorData()) + } + } + continue + } + if got != c.expect { + t.Fatalf("Gas estimation mismatch, want %d, got %d", c.expect, got) + } + } +} + func TestEstimateGasWithPrice(t *testing.T) { key, _ := crypto.GenerateKey() addr := crypto.PubkeyToAddress(key.PublicKey) diff --git a/accounts/external/backend.go b/accounts/external/backend.go index 07062495de..af280fda8a 100644 --- a/accounts/external/backend.go +++ b/accounts/external/backend.go @@ -216,7 +216,7 @@ func (api *ExternalSigner) SignTx(account accounts.Account, tx *types.Transactio From: common.NewMixedcaseAddress(account.Address), } switch tx.Type() { - case types.LegacyTxType, types.AccessListTxType: + case types.LegacyTxType, types.ReviveStateTxType, types.AccessListTxType: args.GasPrice = (*hexutil.Big)(tx.GasPrice()) case types.DynamicFeeTxType: args.MaxFeePerGas = (*hexutil.Big)(tx.GasFeeCap()) diff --git a/cmd/evm/internal/t8ntool/execution.go b/cmd/evm/internal/t8ntool/execution.go index 56d6a9b5ff..a10e95dec0 100644 --- a/cmd/evm/internal/t8ntool/execution.go +++ b/cmd/evm/internal/t8ntool/execution.go @@ -116,7 +116,7 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig *params.ChainConfig, return h } var ( - statedb = MakePreState(rawdb.NewMemoryDatabase(), pre.Pre) + statedb = MakePreState(rawdb.NewMemoryDatabase(), pre, chainConfig) signer = types.MakeSigner(chainConfig, new(big.Int).SetUint64(pre.Env.Number)) gaspool = new(core.GasPool) blockHash = common.Hash{0x13, 0x37} @@ -270,10 +270,11 @@ func (pre *Prestate) Apply(vmConfig vm.Config, chainConfig 
*params.ChainConfig, return statedb, execRs, nil } -func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB { +func MakePreState(db ethdb.Database, pre *Prestate, config *params.ChainConfig) *state.StateDB { sdb := state.NewDatabase(db) - statedb, _ := state.New(common.Hash{}, sdb, nil) - for addr, a := range accounts { + tree, _ := trie.NewShadowNodeSnapTree(db, false) + statedb, _ := state.NewWithStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number-1), common.Hash{}, sdb, nil, tree) + for addr, a := range pre.Pre { statedb.SetCode(addr, a.Code) statedb.SetNonce(addr, a.Nonce) statedb.SetBalance(addr, a.Balance) @@ -285,7 +286,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB statedb.Finalise(false) statedb.AccountsIntermediateRoot() root, _, _ := statedb.Commit(nil) - statedb, _ = state.New(root, sdb, nil) + statedb, _ = state.NewWithStateEpoch(config, new(big.Int).SetUint64(pre.Env.Number), root, sdb, nil, tree) return statedb } diff --git a/cmd/evm/internal/t8ntool/transaction.go b/cmd/evm/internal/t8ntool/transaction.go index 6f1c964ada..57057089cb 100644 --- a/cmd/evm/internal/t8ntool/transaction.go +++ b/cmd/evm/internal/t8ntool/transaction.go @@ -139,7 +139,7 @@ func Transaction(ctx *cli.Context) error { r.Address = sender } // Check intrinsic gas - if gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, + if gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, chainConfig.IsHomestead(new(big.Int)), chainConfig.IsIstanbul(new(big.Int))); err != nil { r.Error = err results = append(results, r) diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go index d57602f8d5..1baffb5db9 100644 --- a/cmd/evm/runner.go +++ b/cmd/evm/runner.go @@ -138,8 +138,8 @@ func runCmd(ctx *cli.Context) error { genesisConfig = gen db := rawdb.NewMemoryDatabase() genesis := gen.ToBlock(db) - statedb, _ = state.New(genesis.Root(), state.NewDatabase(db), nil) 
chainConfig = gen.Config + statedb, _ = state.New(genesis.Root(), state.NewDatabase(db), nil) } else { statedb, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()), nil) genesisConfig = new(core.Genesis) diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index d2b44a4565..be3a9fe750 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "math/big" "net" "os" "path" @@ -495,11 +496,11 @@ func exportPreimages(ctx *cli.Context) error { return nil } -func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, ethdb.Database, common.Hash, error) { +func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, ethdb.Database, common.Hash, *big.Int, error) { db := utils.MakeChainDatabase(ctx, stack, true, false) var header *types.Header if ctx.NArg() > 1 { - return nil, nil, common.Hash{}, fmt.Errorf("expected 1 argument (number or hash), got %d", ctx.NArg()) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("expected 1 argument (number or hash), got %d", ctx.NArg()) } if ctx.NArg() == 1 { arg := ctx.Args().First() @@ -508,17 +509,17 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth if number := rawdb.ReadHeaderNumber(db, hash); number != nil { header = rawdb.ReadHeader(db, hash, *number) } else { - return nil, nil, common.Hash{}, fmt.Errorf("block %x not found", hash) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("block %x not found", hash) } } else { number, err := strconv.Atoi(arg) if err != nil { - return nil, nil, common.Hash{}, err + return nil, nil, common.Hash{}, common.Big0, err } if hash := rawdb.ReadCanonicalHash(db, uint64(number)); hash != (common.Hash{}) { header = rawdb.ReadHeader(db, hash, uint64(number)) } else { - return nil, nil, common.Hash{}, fmt.Errorf("header for block %d not found", number) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("header for block %d not 
found", number) } } } else { @@ -526,7 +527,7 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth header = rawdb.ReadHeadHeader(db) } if header == nil { - return nil, nil, common.Hash{}, errors.New("no head block found") + return nil, nil, common.Hash{}, common.Big0, errors.New("no head block found") } startArg := common.FromHex(ctx.String(utils.StartKeyFlag.Name)) var start common.Hash @@ -538,7 +539,7 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth start = crypto.Keccak256Hash(startArg) log.Info("Converting start-address to hash", "address", common.BytesToAddress(startArg), "hash", start.Hex()) default: - return nil, nil, common.Hash{}, fmt.Errorf("invalid start argument: %x. 20 or 32 hex-encoded bytes required", startArg) + return nil, nil, common.Hash{}, common.Big0, fmt.Errorf("invalid start argument: %x. 20 or 32 hex-encoded bytes required", startArg) } var conf = &state.DumpConfig{ SkipCode: ctx.Bool(utils.ExcludeCodeFlag.Name), @@ -550,17 +551,18 @@ func parseDumpConfig(ctx *cli.Context, stack *node.Node) (*state.DumpConfig, eth log.Info("State dump configured", "block", header.Number, "hash", header.Hash().Hex(), "skipcode", conf.SkipCode, "skipstorage", conf.SkipStorage, "start", hexutil.Encode(conf.Start), "limit", conf.Max) - return conf, db, header.Root, nil + return conf, db, header.Root, header.Number, nil } func dump(ctx *cli.Context) error { stack, _ := makeConfigNode(ctx) defer stack.Close() - conf, db, root, err := parseDumpConfig(ctx, stack) + conf, db, root, _, err := parseDumpConfig(ctx, stack) if err != nil { return err } + // TODO unknown chainconfig, default using epoch0 state, err := state.New(root, state.NewDatabase(db), nil) if err != nil { return err diff --git a/cmd/geth/snapshot.go b/cmd/geth/snapshot.go index 1c92da95f5..f2243ff390 100644 --- a/cmd/geth/snapshot.go +++ b/cmd/geth/snapshot.go @@ -563,6 +563,7 @@ func traverseState(ctx *cli.Context) error { return err 
} if acc.Root != emptyRoot { + // TODO default using epoch0 trie, but need sndb to query shadow nodes storageTrie, err := trie.NewSecure(acc.Root, triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) @@ -667,6 +668,7 @@ func traverseRawState(ctx *cli.Context) error { return errors.New("invalid account") } if acc.Root != emptyRoot { + // TODO default using epoch0 trie, but need sndb to query shadow nodes storageTrie, err := trie.NewSecure(acc.Root, triedb) if err != nil { log.Error("Failed to open storage trie", "root", acc.Root, "err", err) @@ -728,7 +730,7 @@ func dumpState(ctx *cli.Context) error { stack, _ := makeConfigNode(ctx) defer stack.Close() - conf, db, root, err := parseDumpConfig(ctx, stack) + conf, db, root, _, err := parseDumpConfig(ctx, stack) if err != nil { return err } diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index b1abc1221f..87bbcf3bf3 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -148,7 +148,7 @@ var ( type SignerFn func(accounts.Account, string, []byte) ([]byte, error) type SignerTxFn func(accounts.Account, *types.Transaction, *big.Int) (*types.Transaction, error) -func isToSystemContract(to common.Address) bool { +func IsToSystemContract(to common.Address) bool { return systemContracts[to] } @@ -230,6 +230,10 @@ func New( if parliaConfig != nil && parliaConfig.Epoch == 0 { parliaConfig.Epoch = defaultEpochLength } + if parliaConfig != nil && parliaConfig.StateEpochPeriod == 0 { + parliaConfig.StateEpochPeriod = types.DefaultStateEpochPeriod + } + log.Info("instance parlia with config", "period", parliaConfig.Period, "epoch", parliaConfig.Epoch, "stateEpochPeriod", parliaConfig.StateEpochPeriod) // Allocate the snapshot caches and create the engine recentSnaps, err := lru.NewARC(inMemorySnapshots) @@ -258,7 +262,7 @@ func New( signatures: signatures, validatorSetABI: vABI, slashABI: sABI, - signer: types.NewEIP155Signer(chainConfig.ChainID), + 
signer: types.NewBEP215Signer(chainConfig.ChainID), } return c @@ -273,7 +277,7 @@ func (p *Parlia) IsSystemTransaction(tx *types.Transaction, header *types.Header if err != nil { return false, errors.New("UnAuthorized transaction") } - if sender == header.Coinbase && isToSystemContract(*tx.To()) && tx.GasPrice().Cmp(big.NewInt(0)) == 0 { + if sender == header.Coinbase && IsToSystemContract(*tx.To()) && tx.GasPrice().Cmp(big.NewInt(0)) == 0 { return true, nil } return false, nil @@ -283,7 +287,7 @@ func (p *Parlia) IsSystemContract(to *common.Address) bool { if to == nil { return false } - return isToSystemContract(*to) + return IsToSystemContract(*to) } // Author implements consensus.Engine, returning the SystemAddress diff --git a/core/bench_test.go b/core/bench_test.go index 06333033c4..e78a926836 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -85,7 +85,7 @@ func genValueTx(nbytes int) func(int, *BlockGen) { return func(i int, gen *BlockGen) { toaddr := common.Address{} data := make([]byte, nbytes) - gas, _ := IntrinsicGas(data, nil, false, false, false) + gas, _ := IntrinsicGas(data, nil, nil, false, false, false) signer := types.MakeSigner(gen.config, big.NewInt(int64(i))) gasPrice := big.NewInt(0) if gen.header.BaseFee != nil { diff --git a/core/blockchain.go b/core/blockchain.go index b119270b99..929c971ff7 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -148,6 +148,7 @@ type CacheConfig struct { NoTries bool // Insecure settings. Do not have any tries in databases if enabled. SnapshotWait bool // Wait for snapshot construction on startup. 
TODO(karalabe): This is a dirty hack for testing, nuke it + NoPruning bool } // To avoid cycle import @@ -192,6 +193,8 @@ type BlockChain struct { gcproc time.Duration // Accumulates canonical block processing for trie dumping commitLock sync.Mutex // CommitLock is used to protect above field from being modified concurrently + shadowNodeTree *trie.ShadowNodeSnapTree + // txLookupLimit is the maximum number of blocks from head whose tx indices // are reserved: // * 0: means no limit and regenerate any missing indexes @@ -362,9 +365,14 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par return nil, err } + // load shadow node tree to R&W + if bc.shadowNodeTree, err = trie.NewShadowNodeSnapTree(db, cacheConfig.NoPruning); err != nil { + return nil, err + } + // Make sure the state associated with the block is available head := bc.CurrentBlock() - if _, err := state.New(head.Root(), bc.stateCache, bc.snaps); err != nil { + if _, err := state.NewWithStateEpoch(chainConfig, head.Number(), head.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { // Head state is missing, before the state recovery, find out the // disk layer point of snapshot(if it's enabled). Make sure the // rewound point is lower than disk layer. 
@@ -715,7 +723,7 @@ func (bc *BlockChain) setHeadBeyondRoot(head uint64, root common.Hash, repair bo enoughBeyondCount = beyondCount > maxBeyondBlocks - if _, err := state.New(newHeadBlock.Root(), bc.stateCache, bc.snaps); err != nil { + if _, err := state.NewWithStateEpoch(bc.chainConfig, newHeadBlock.Number(), newHeadBlock.Root(), bc.stateCache, bc.snaps, bc.shadowNodeTree); err != nil { log.Trace("Block state missing, rewinding further", "number", newHeadBlock.NumberU64(), "hash", newHeadBlock.Hash()) if pivot == nil || newHeadBlock.NumberU64() > *pivot { parent := bc.GetBlock(newHeadBlock.ParentHash(), newHeadBlock.NumberU64()-1) @@ -1074,6 +1082,11 @@ func (bc *BlockChain) Stop() { log.Error("Failed to journal state snapshot", "err", err) } } + if bc.shadowNodeTree != nil { + if err := bc.shadowNodeTree.Journal(); err != nil { + log.Error("Failed to journal shadow node snapshot", "err", err) + } + } // Ensure the state of a recent block is also stored to disk before exiting. // We're writing three different states to catch different restart scenarios: @@ -1103,8 +1116,9 @@ func (bc *BlockChain) Stop() { rawdb.WriteSafePointBlockNumber(bc.db, bc.CurrentBlock().NumberU64()) } } + currentEpoch := types.GetStateEpoch(bc.chainConfig, bc.CurrentBlock().Number()) for !bc.triegc.Empty() { - go triedb.Dereference(bc.triegc.PopItem().(common.Hash)) + go triedb.Dereference(bc.triegc.PopItem().(common.Hash), currentEpoch) } if size, _ := triedb.Size(); size != 0 { log.Error("Dangling trie nodes after full cleanup") @@ -1546,13 +1560,14 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types. 
} } // Garbage collect anything below our required write retention + currentEpoch := types.GetStateEpoch(bc.chainConfig, bc.CurrentBlock().Number()) for !bc.triegc.Empty() { root, number := bc.triegc.Pop() if uint64(-number) > chosen { bc.triegc.Push(root, number) break } - go triedb.Dereference(root.(common.Hash)) + go triedb.Dereference(root.(common.Hash), currentEpoch) } } } @@ -1826,6 +1841,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) } for ; block != nil && err == nil || errors.Is(err, ErrKnownBlock); block, err = it.next() { + log.Info("Try import new chain segment", "number", block.NumberU64(), "hash", block.Hash(), "from", block.Coinbase()) // If the chain is terminating, stop processing blocks if bc.insertStopped() { log.Debug("Abort during block processing") @@ -1882,7 +1898,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) if parent == nil { parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1) } - statedb, err := state.NewWithSharedPool(parent.Root, bc.stateCache, bc.snaps) + statedb, err := state.NewWithStateEpoch(bc.chainConfig, block.Number(), parent.Root, bc.stateCache, bc.snaps, bc.ShadowNodeTree()) if err != nil { return it.index, err } @@ -1932,7 +1948,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals, setHead bool) substart = time.Now() if !statedb.IsLightProcessed() { if err := bc.validator.ValidateState(block, statedb, receipts, usedGas); err != nil { - log.Error("validate state failed", "error", err) + log.Error("validate state failed", "number", block.NumberU64(), "hash", block.Hash(), "proposer", block.Coinbase(), "error", err) bc.reportBlock(block, receipts, err) statedb.StopPrefetcher() return it.index, err diff --git a/core/blockchain_reader.go b/core/blockchain_reader.go index b28661da3e..203f2473f6 100644 --- a/core/blockchain_reader.go +++ b/core/blockchain_reader.go @@ -19,6 +19,8 @@ package core import ( "math/big" + 
"github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/core/rawdb" @@ -308,12 +310,13 @@ func (bc *BlockChain) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) { // State returns a new mutable state based on the current HEAD block. func (bc *BlockChain) State() (*state.StateDB, error) { - return bc.StateAt(bc.CurrentBlock().Root()) + block := bc.CurrentBlock() + return bc.StateAt(block.Root(), block.Number()) } // StateAt returns a new mutable state based on a particular point in time. -func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { - return state.New(root, bc.stateCache, bc.snaps) +func (bc *BlockChain) StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) { + return state.NewWithStateEpoch(bc.chainConfig, number, root, bc.stateCache, bc.snaps, bc.shadowNodeTree) } // Config retrieves the chain's fork configuration. @@ -327,6 +330,11 @@ func (bc *BlockChain) Snapshots() *snapshot.Tree { return bc.snaps } +// ShadowNodeTree returns the blockchain shadow node tree. +func (bc *BlockChain) ShadowNodeTree() *trie.ShadowNodeSnapTree { + return bc.shadowNodeTree +} + // Validator returns the current validator. 
func (bc *BlockChain) Validator() Validator { return bc.validator diff --git a/core/blockchain_sethead_test.go b/core/blockchain_sethead_test.go index 7ee213b726..7df56f2a78 100644 --- a/core/blockchain_sethead_test.go +++ b/core/blockchain_sethead_test.go @@ -2017,10 +2017,12 @@ func testSetHead(t *testing.T, tt *rewindTest, snapshots bool) { } // Manually dereference anything not committed to not have to work with 128+ tries for _, block := range sideblocks { - chain.stateCache.TrieDB().Dereference(block.Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, block.Number()) + chain.stateCache.TrieDB().Dereference(block.Root(), blockEpoch) } for _, block := range canonblocks { - chain.stateCache.TrieDB().Dereference(block.Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, block.Number()) + chain.stateCache.TrieDB().Dereference(block.Root(), blockEpoch) } chain.stateCache.Purge() // Force run a freeze cycle diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 20615eef8c..5580c05acf 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -1785,8 +1785,10 @@ func TestTrieForkGC(t *testing.T) { } // Dereference all the recent tries and ensure no past trie is left in for i := 0; i < TestTriesInMemory; i++ { - chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root()) - chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root()) + blockEpoch := types.GetStateEpoch(params.TestChainConfig, blocks[len(blocks)-1-i].Number()) + chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), blockEpoch) + forkEpoch := types.GetStateEpoch(params.TestChainConfig, forks[len(blocks)-1-i].Number()) + chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), forkEpoch) } if len(chain.stateCache.TrieDB().Nodes()) > 0 { t.Fatalf("stale tries still alive after garbase collection") @@ -3178,21 +3180,26 @@ func TestDeleteRecreateSlots(t *testing.T) { statedb, _ := chain.State() // If all is 
correct, then slot 1 and 2 are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } // Also, 3 and 4 should be set - if got, exp := statedb.GetState(aa, common.HexToHash("03")), common.HexToHash("03"); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("03")), common.HexToHash("03"); got != exp { t.Fatalf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("04")), common.HexToHash("04"); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("04")), common.HexToHash("04"); got != exp { t.Fatalf("got %x exp %x", got, exp) } } +func getStateIgnoreErr(statedb *state.StateDB, addr common.Address, hash common.Hash) common.Hash { + val, _ := statedb.GetState(addr, hash) + return val +} + // TestDeleteRecreateAccount tests a state-transition that contains deletion of a // contract with storage, and a recreate of the same contract via a // regular value-transfer @@ -3258,10 +3265,10 @@ func TestDeleteRecreateAccount(t *testing.T) { statedb, _ := chain.State() // If all is correct, then both slots are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("got %x exp %x", got, exp) } } @@ -3435,10 
+3442,10 @@ func TestDeleteRecreateSlotsAcrossManyBlocks(t *testing.T) { } statedb, _ := chain.State() // If all is correct, then slot 1 and 2 are zero - if got, exp := statedb.GetState(aa, common.HexToHash("01")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("01")), (common.Hash{}); got != exp { t.Errorf("block %d, got %x exp %x", blockNum, got, exp) } - if got, exp := statedb.GetState(aa, common.HexToHash("02")), (common.Hash{}); got != exp { + if got, exp := getStateIgnoreErr(statedb, aa, common.HexToHash("02")), (common.Hash{}); got != exp { t.Errorf("block %d, got %x exp %x", blockNum, got, exp) } exp := expectations[i] @@ -3447,7 +3454,7 @@ func TestDeleteRecreateSlotsAcrossManyBlocks(t *testing.T) { t.Fatalf("block %d, expected %v to exist, it did not", blockNum, aa) } for slot, val := range exp.values { - if gotValue, expValue := statedb.GetState(aa, asHash(slot)), asHash(val); gotValue != expValue { + if gotValue, expValue := getStateIgnoreErr(statedb, aa, asHash(slot)), asHash(val); gotValue != expValue { t.Fatalf("block %d, slot %d, got %x exp %x", blockNum, slot, gotValue, expValue) } } diff --git a/core/chain_makers.go b/core/chain_makers.go index c53269ed6e..f51bf9feb8 100644 --- a/core/chain_makers.go +++ b/core/chain_makers.go @@ -20,6 +20,8 @@ import ( "fmt" "math/big" + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/consensus" "github.com/ethereum/go-ethereum/consensus/misc" @@ -274,8 +276,13 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse } return nil, nil } + tree, err := trie.NewShadowNodeSnapTree(db, false) + if err != nil { + panic(err) + } for i := 0; i < n; i++ { - statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil) + number := new(big.Int).Add(parent.Number(), common.Big1) + statedb, err := state.NewWithStateEpoch(config, number, parent.Root(), state.NewDatabase(db), nil, 
tree) if err != nil { panic(err) } @@ -284,6 +291,9 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse receipts[i] = receipt parent = block } + if err = tree.Journal(); err != nil { + panic(err) + } return blocks, receipts } diff --git a/core/rawdb/accessors_shadow_node.go b/core/rawdb/accessors_shadow_node.go new file mode 100644 index 0000000000..9b0322bbb2 --- /dev/null +++ b/core/rawdb/accessors_shadow_node.go @@ -0,0 +1,91 @@ +package rawdb + +import ( + "encoding/binary" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/log" +) + +func DeleteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter) { + if err := db.Delete(shadowNodeSnapshotJournalKey); err != nil { + log.Crit("Failed to remove snapshot journal", "err", err) + } +} + +func ReadShadowNodeSnapshotJournal(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(shadowNodeSnapshotJournalKey) + return data +} + +func WriteShadowNodeSnapshotJournal(db ethdb.KeyValueWriter, journal []byte) { + if err := db.Put(shadowNodeSnapshotJournalKey, journal); err != nil { + log.Crit("Failed to store snapshot journal", "err", err) + } +} + +func ReadShadowNodePlainStateMeta(db ethdb.KeyValueReader) []byte { + data, _ := db.Get(shadowNodePlainStateMeta) + return data +} + +func WriteShadowNodePlainStateMeta(db ethdb.KeyValueWriter, val []byte) error { + return db.Put(shadowNodePlainStateMeta, val) +} + +func ReadShadowNodeHistory(db ethdb.KeyValueReader, addr common.Hash, path string, number uint64) []byte { + val, _ := db.Get(shadowNodeHistoryKey(addr, path, number)) + return val +} + +func WriteShadowNodeHistory(db ethdb.KeyValueWriter, addr common.Hash, path string, number uint64, val []byte) error { + return db.Put(shadowNodeHistoryKey(addr, path, number), val) +} + +func ReadShadowNodeChangeSet(db ethdb.KeyValueReader, addr common.Hash, number uint64) []byte { + val, _ := db.Get(shadowNodeChangeSetKey(addr, 
number)) + return val +} + +func WriteShadowNodeChangeSet(db ethdb.KeyValueWriter, addr common.Hash, number uint64, val []byte) error { + return db.Put(shadowNodeChangeSetKey(addr, number), val) +} + +func ReadShadowNodePlainState(db ethdb.KeyValueReader, addr common.Hash, path string) []byte { + val, _ := db.Get(shadowNodePlainStateKey(addr, path)) + return val +} + +func WriteShadowNodePlainState(db ethdb.KeyValueWriter, addr common.Hash, path string, val []byte) error { + return db.Put(shadowNodePlainStateKey(addr, path), val) +} + +func DeleteShadowNodePlainState(db ethdb.KeyValueWriter, addr common.Hash, path string) error { + return db.Delete(shadowNodePlainStateKey(addr, path)) +} + +func shadowNodeChangeSetKey(addr common.Hash, number uint64) []byte { + key := make([]byte, len(ShadowNodeChangeSetPrefix)+8+len(addr)) + copy(key[:], ShadowNodeChangeSetPrefix) + binary.BigEndian.PutUint64(key[len(ShadowNodeChangeSetPrefix):], number) + copy(key[len(ShadowNodeChangeSetPrefix)+8:], addr.Bytes()) + return key +} + +func shadowNodeHistoryKey(addr common.Hash, path string, number uint64) []byte { + key := make([]byte, len(ShadowNodeHistoryPrefix)+len(addr)+len(path)+8) + copy(key[:], ShadowNodeHistoryPrefix) + copy(key[len(ShadowNodeHistoryPrefix):], addr.Bytes()) + copy(key[len(ShadowNodeHistoryPrefix)+len(addr):], path) + binary.BigEndian.PutUint64(key[len(key)-8:], number) + return key +} + +func shadowNodePlainStateKey(addr common.Hash, path string) []byte { + key := make([]byte, len(ShadowNodePlainStatePrefix)+len(addr)+len(path)) + copy(key[:], ShadowNodePlainStatePrefix) + copy(key[len(ShadowNodePlainStatePrefix):], addr.Bytes()) + copy(key[len(ShadowNodePlainStatePrefix)+len(addr):], path) + return key +} diff --git a/core/rawdb/database.go b/core/rawdb/database.go index f63112f313..ab056784a9 100644 --- a/core/rawdb/database.go +++ b/core/rawdb/database.go @@ -450,6 +450,12 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { chtTrieNodes stat
bloomTrieNodes stat + // shadow nodes statistics + shadowNodeMetadataSize stat + shadowNodeHistorySize stat + shadowNodeChangeSetSize stat + shadowNodePlainStateSize stat + // Meta- and unaccounted data metadata stat unaccounted stat @@ -508,13 +514,22 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { bytes.HasPrefix(key, []byte("bltIndex-")) || bytes.HasPrefix(key, []byte("bltRoot-")): // Bloomtrie sub bloomTrieNodes.Add(size) + + case bytes.Equal(key, shadowNodePlainStateMeta): + shadowNodeMetadataSize.Add(size) + case bytes.HasPrefix(key, ShadowNodeHistoryPrefix) && len(key) >= (len(ShadowNodeHistoryPrefix)+common.HashLength+8): + shadowNodeHistorySize.Add(size) + case bytes.HasPrefix(key, ShadowNodeChangeSetPrefix) && len(key) == (len(ShadowNodeChangeSetPrefix)+common.HashLength+8): + shadowNodeChangeSetSize.Add(size) + case bytes.HasPrefix(key, ShadowNodePlainStatePrefix) && len(key) >= (len(ShadowNodePlainStatePrefix)+common.HashLength): + shadowNodePlainStateSize.Add(size) default: var accounted bool for _, meta := range [][]byte{ databaseVersionKey, headHeaderKey, headBlockKey, headFastBlockKey, lastPivotKey, fastTrieProgressKey, snapshotDisabledKey, SnapshotRootKey, snapshotJournalKey, snapshotGeneratorKey, snapshotRecoveryKey, txIndexTailKey, fastTxLookupLimitKey, - uncleanShutdownKey, badBlockKey, transitionStatusKey, + uncleanShutdownKey, badBlockKey, transitionStatusKey, LastSafePointBlockKey, pruneAncientKey, shadowNodeSnapshotJournalKey, } { if bytes.Equal(key, meta) { metadata.Add(size) @@ -571,6 +586,10 @@ func InspectDatabase(db ethdb.Database, keyPrefix, keyStart []byte) error { {"Ancient store", "Block number->hash", ancientHashesSize.String(), ancients.String()}, {"Light client", "CHT trie nodes", chtTrieNodes.Size(), chtTrieNodes.Count()}, {"Light client", "Bloom trie nodes", bloomTrieNodes.Size(), bloomTrieNodes.Count()}, + {"Shadow Node", "Metadata", shadowNodeMetadataSize.Size(), shadowNodeMetadataSize.Count()}, 
+ {"Shadow Node", "History", shadowNodeHistorySize.Size(), shadowNodeHistorySize.Count()}, + {"Shadow Node", "ChangeSet", shadowNodeChangeSetSize.Size(), shadowNodeChangeSetSize.Count()}, + {"Shadow Node", "PlainState", shadowNodePlainStateSize.Size(), shadowNodePlainStateSize.Count()}, } table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"Database", "Category", "Size", "Items"}) diff --git a/core/rawdb/schema.go b/core/rawdb/schema.go index e04f94e7d4..a1c56a79a0 100644 --- a/core/rawdb/schema.go +++ b/core/rawdb/schema.go @@ -63,6 +63,9 @@ var ( // snapshotSyncStatusKey tracks the snapshot sync status across restarts. snapshotSyncStatusKey = []byte("SnapshotSyncStatus") + // shadowNodeSnapshotJournalKey tracks the in-memory diff layers across restarts. + shadowNodeSnapshotJournalKey = []byte("ShadowNodeSnapshotJournalKey") + // txIndexTailKey tracks the oldest block whose transactions have been indexed. txIndexTailKey = []byte("TransactionIndexTail") @@ -93,6 +96,9 @@ var ( // transitionStatusKey tracks the eth2 transition status. transitionStatusKey = []byte("eth2-transition") + // shadowNodePlainStateMeta save disk layer meta data + shadowNodePlainStateMeta = []byte("shadowNodePlainStateMeta") + // Data item prefixes (use single byte to avoid mixing data types, avoid `i`, used for indexes). 
headerPrefix = []byte("h") // headerPrefix + num (uint64 big endian) + hash -> header headerTDSuffix = []byte("t") // headerPrefix + num (uint64 big endian) + hash + headerTDSuffix -> td @@ -108,6 +114,10 @@ var ( SnapshotStoragePrefix = []byte("o") // SnapshotStoragePrefix + account hash + storage hash -> storage trie value CodePrefix = []byte("c") // CodePrefix + code hash -> account code + ShadowNodeHistoryPrefix = []byte("sh") // ShadowNodeHistoryPrefix + addr hash + path + blockNr -> bitmap, default blockNr = math.MaxUint64 + ShadowNodeChangeSetPrefix = []byte("sc") // ShadowNodeChangeSetPrefix + addr hash + blockNr -> changeSet/prev val + ShadowNodePlainStatePrefix = []byte("sp") // ShadowNodePlainStatePrefix + addr hash + path -> val + // difflayer database diffLayerPrefix = []byte("d") // diffLayerPrefix + hash -> diffLayer @@ -252,6 +262,13 @@ func IsCodeKey(key []byte) (bool, []byte) { return false, nil } +func IsSnapStorageKey(key []byte) (bool, []byte, []byte) { + if bytes.HasPrefix(key, SnapshotStoragePrefix) && len(key) == common.HashLength+common.HashLength+len(SnapshotStoragePrefix) { + return true, key[len(SnapshotStoragePrefix):], key[len(SnapshotStoragePrefix) : len(SnapshotStoragePrefix)+common.HashLength] + } + return false, nil, nil +} + // configKey = configPrefix + hash func configKey(hash common.Hash) []byte { return append(configPrefix, hash.Bytes()...) diff --git a/core/state/database.go b/core/state/database.go index 0f31bc9139..b7b7c6be20 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -55,6 +55,9 @@ type Database interface { // OpenStorageTrie opens the storage trie of an account. OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + // OpenStorageTrieWithShadowNode opens the storage trie of an account and allow rw shadow nodes. 
+ OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (Trie, error) + // CopyTrie returns an independent copy of the given trie. CopyTrie(Trie) Trie @@ -93,6 +96,9 @@ type Trie interface { // trie.MissingNodeError is returned. TryGet(key []byte) ([]byte, error) + // TryUpdateEpoch just update key's epoch, only using in storage trie + TryUpdateEpoch(key []byte) error + // TryUpdateAccount abstract an account write in the trie. TryUpdateAccount(key []byte, account *types.StateAccount) error @@ -110,6 +116,9 @@ type Trie interface { // can be used even if the trie doesn't have one. Hash() common.Hash + // HashKey return trie key hash result + HashKey(key []byte) []byte + // Commit writes all nodes to the trie's memory database, tracking the internal // and external (for account tries) references. Commit(onleaf trie.LeafCallback) (common.Hash, int, error) @@ -126,6 +135,12 @@ type Trie interface { // nodes of the longest existing prefix of the key (at least the root), ending // with the node that proves the absence of the key. Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error + + ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error + + ReviveTrie(proof []*trie.MPTProofNub) []*trie.MPTProofNub + + Epoch() types.StateEpoch } // NewDatabase creates a backing store for state. The returned database is safe for @@ -232,6 +247,7 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { } } + // TODO default using epoch0 trie, but need sndb to query shadow nodes tr, err := trie.NewSecure(root, db.db) if err != nil { return nil, err @@ -239,6 +255,21 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { return tr, nil } +// OpenStorageTrieWithShadowNode opens the storage trie of an account. 
+func (db *cachingDB) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (Trie, error) { + if db.noTries { + return trie.NewEmptyTrie(), nil + } + + // StorageTrie with ShadowNode cannot use trie cache, + // because of its sndb need update in every block, just reinit it + tr, err := trie.NewSecureWithShadowNodes(curEpoch, root, db.db, sndb) + if err != nil { + return nil, err + } + return tr, nil +} + func (db *cachingDB) CacheAccount(root common.Hash, t Trie) { if db.accountTrieCache == nil { return @@ -251,6 +282,12 @@ func (db *cachingDB) CacheStorage(addrHash common.Hash, root common.Hash, t Trie if db.storageTrieCache == nil { return } + + // do not cache trie with shadow nodes + if t.Epoch() > types.StateEpoch0 { + return + } + tr := t.(*trie.SecureTrie) if tries, exist := db.storageTrieCache.Get(addrHash); exist { triesArray := tries.([3]*triePair) diff --git a/core/state/errors.go b/core/state/errors.go new file mode 100644 index 0000000000..ebfd57c22f --- /dev/null +++ b/core/state/errors.go @@ -0,0 +1,57 @@ +package state + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/trie" +) + +// ExpiredStateError Access State error, must revert the execution +type ExpiredStateError struct { + Addr common.Address + Key common.Hash + Path []byte + Epoch types.StateEpoch + reason string +} + +func NewSnapExpiredStateError(addr common.Address, key common.Hash, epoch types.StateEpoch) *ExpiredStateError { + return &ExpiredStateError{ + Addr: addr, + Key: key, + Path: []byte{}, + Epoch: epoch, + reason: "snap query", + } +} + +func NewExpiredStateError(addr common.Address, key common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { + return &ExpiredStateError{ + Addr: addr, + Key: key, + Path: err.Path, + Epoch: err.Epoch, + reason: "query", + } +} + +func NewInsertExpiredStateError(addr common.Address, key 
common.Hash, err *trie.ExpiredNodeError) *ExpiredStateError { + return &ExpiredStateError{ + Addr: addr, + Key: key, + Path: err.Path, + Epoch: err.Epoch, + reason: "insert", + } +} + +func (e *ExpiredStateError) Error() string { + return fmt.Sprintf("Access expired state, addr: %v, path: %v, key: %v, epoch: %v, reason: %v", e.Addr, e.Path, e.Key, e.Epoch, e.reason) +} + +func (e *ExpiredStateError) Reason(r string) *ExpiredStateError { + e.reason = r + return e +} diff --git a/core/state/iterator.go b/core/state/iterator.go index 611df52431..bbb4329dab 100644 --- a/core/state/iterator.go +++ b/core/state/iterator.go @@ -109,6 +109,7 @@ func (it *NodeIterator) step() error { if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { return err } + // TODO(0xbundler): fix iterator later dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root) if err != nil { return err diff --git a/core/state/journal.go b/core/state/journal.go index 4f1fe2bf48..bf594585b2 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -139,6 +139,13 @@ type ( address *common.Address slot *common.Hash } + reviveStorageTrieNodeChange struct { + address *common.Address + } + accessedStorageStateChange struct { + address *common.Address + slot *common.Hash + } ) func (ch createObjectChange) revert(s *StateDB) { @@ -272,3 +279,27 @@ func (ch accessListAddSlotChange) revert(s *StateDB) { func (ch accessListAddSlotChange) dirtied() *common.Address { return nil } + +func (ch reviveStorageTrieNodeChange) revert(s *StateDB) { + s.getStateObject(*ch.address).dirtyReviveState = make(map[string]common.Hash) + s.getStateObject(*ch.address).dirtyReviveTrie = nil +} + +func (ch reviveStorageTrieNodeChange) dirtied() *common.Address { + return ch.address +} + +func (ch accessedStorageStateChange) revert(s *StateDB) { + if count, ok := s.getStateObject(*ch.address).dirtyAccessedState[*ch.slot]; ok { + if count > 1 { + 
s.getStateObject(*ch.address).dirtyAccessedState[*ch.slot] = count - 1 + } else { + delete(s.getStateObject(*ch.address).dirtyAccessedState, *ch.slot) + } + } + +} + +func (ch accessedStorageStateChange) dirtied() *common.Address { + return ch.address +} diff --git a/core/state/pruner/pruner.go b/core/state/pruner/pruner.go index 83c56d5785..eba9898a06 100644 --- a/core/state/pruner/pruner.go +++ b/core/state/pruner/pruner.go @@ -88,6 +88,7 @@ type Pruner struct { headHeader *types.Header snaptree *snapshot.Tree triesInMemory uint64 + latestEpoch types.StateEpoch } type BlockPruner struct { @@ -104,6 +105,14 @@ func NewPruner(db ethdb.Database, datadir, trieCachePath string, bloomSize, trie if headBlock == nil { return nil, errors.New("Failed to load head block") } + stored := rawdb.ReadCanonicalHash(db, 0) + chainConfig := rawdb.ReadChainConfig(db, stored) + if chainConfig == nil { + return nil, errors.New("cannot find chainConfig") + } + latestEpoch := types.GetStateEpoch(chainConfig, headBlock.Number()) + log.Info("NewPruner with", "number", headBlock.Number(), "latestEpoch", latestEpoch) + snaptree, err := snapshot.New(db, trie.NewDatabase(db), 256, int(triesInMemory), headBlock.Root(), false, false, false, false) if err != nil { return nil, err // The relevant snapshot(s) might not exist @@ -126,6 +135,7 @@ func NewPruner(db ethdb.Database, datadir, trieCachePath string, bloomSize, trie triesInMemory: triesInMemory, headHeader: headBlock.Header(), snaptree: snaptree, + latestEpoch: latestEpoch, }, nil } @@ -238,7 +248,7 @@ func pruneAll(maindb ethdb.Database, g *core.Genesis) error { return nil } -func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, stateBloom *stateBloom, bloomPath string, middleStateRoots map[common.Hash]struct{}, start time.Time) error { +func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, stateBloom *stateBloom, bloomPath string, middleStateRoots map[common.Hash]struct{}, start time.Time, 
latestEpoch types.StateEpoch) error { // Delete all stale trie nodes in the disk. With the help of state bloom // the trie nodes(and codes) belong to the active state will be filtered // out. A very small part of stale tries will also be filtered because of @@ -247,16 +257,31 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta // dangling node is the state root is super low. So the dangling nodes in // theory will never ever be visited again. var ( - count int - size common.StorageSize - pstart = time.Now() - logged = time.Now() - batch = maindb.NewBatch() - iter = maindb.NewIterator(nil, nil) + count int + snapCount int + size common.StorageSize + snapSize common.StorageSize + pstart = time.Now() + logged = time.Now() + batch = maindb.NewBatch() + iter = maindb.NewIterator(nil, nil) ) for iter.Next() { key := iter.Key() + // if it is snap kv, check if expired, do not follow parent>=child+2 prune rule, cover by trie node + isSnapKey, _, snapAddr := rawdb.IsSnapStorageKey(key) + if isSnapKey { + snapVal, err := snapshot.ParseSnapValFromBytes(iter.Value()) + if err == nil && types.EpochExpired(snapVal.Epoch, latestEpoch) { + batch.Delete(key) + snapCount += 1 + snapSize += common.StorageSize(len(key) + len(iter.Value())) + log.Info("delete expired snap kv", "addrHash", snapAddr, "kvEpoch", snapVal.Epoch, "epoch", latestEpoch) + } + continue + } + // All state entries don't belong to specific state and genesis are deleted here // - trie node // - legacy contract code @@ -310,6 +335,7 @@ func prune(snaptree *snapshot.Tree, root common.Hash, maindb ethdb.Database, sta } iter.Release() log.Info("Pruned state data", "nodes", count, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart))) + log.Info("Pruned snap data", "kvs", snapCount, "size", size, "elapsed", common.PrettyDuration(time.Since(pstart))) // Pruning is done, now drop the "useless" layers from the snapshot. // Firstly, flushing the target layer into the disk. 
After that all @@ -658,7 +684,7 @@ func (p *Pruner) Prune(root common.Hash) error { return err } log.Info("State bloom filter committed", "name", filterName) - return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start) + return prune(p.snaptree, root, p.db, p.stateBloom, filterName, middleRoots, start, p.latestEpoch) } // RecoverPruning will resume the pruning procedure during the system restart. @@ -680,6 +706,14 @@ func RecoverPruning(datadir string, db ethdb.Database, trieCachePath string, tri if headBlock == nil { return errors.New("Failed to load head block") } + stored := rawdb.ReadCanonicalHash(db, 0) + chainConfig := rawdb.ReadChainConfig(db, stored) + if chainConfig == nil { + return errors.New("cannot find chainConfig") + } + latestEpoch := types.GetStateEpoch(chainConfig, headBlock.Number()) + log.Info("RecoverPruning with", "number", headBlock.Number(), "latestEpoch", latestEpoch) + // Initialize the snapshot tree in recovery mode to handle this special case: // - Users run the `prune-state` command multiple times // - Neither these `prune-state` running is finished(e.g. 
interrupted manually) @@ -722,7 +756,7 @@ func RecoverPruning(datadir string, db ethdb.Database, trieCachePath string, tri log.Error("Pruning target state is not existent") return errors.New("non-existent target state") } - return prune(snaptree, stateBloomRoot, db, stateBloom, stateBloomPath, middleRoots, time.Now()) + return prune(snaptree, stateBloomRoot, db, stateBloom, stateBloomPath, middleRoots, time.Now(), latestEpoch) } // extractGenesis loads the genesis state and commits all the state entries @@ -756,6 +790,7 @@ func extractGenesis(db ethdb.Database, stateBloom *stateBloom) error { return err } if acc.Root != emptyRoot { + // TODO default using epoch0 trie, but need sndb to query shadow nodes storageTrie, err := trie.NewSecure(acc.Root, trie.NewDatabase(db)) if err != nil { return err diff --git a/core/state/snapshot/conversion.go b/core/state/snapshot/conversion.go index 250692422d..f7a79503a2 100644 --- a/core/state/snapshot/conversion.go +++ b/core/state/snapshot/conversion.go @@ -89,7 +89,7 @@ func GenerateTrie(snaptree *Tree, root common.Hash, src ethdb.Database, dst ethd } defer storageIt.Release() - hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) + hash, err := generateTrieRoot(dst, storageIt, accountHash, stackTrieGenerate, nil, stat, false) // This is where storage trie is generated. 
if err != nil { return common.Hash{}, err } diff --git a/core/state/snapshot/generate_test.go b/core/state/snapshot/generate_test.go index 9eb812f83b..1f45fe72a6 100644 --- a/core/state/snapshot/generate_test.go +++ b/core/state/snapshot/generate_test.go @@ -42,6 +42,7 @@ func TestGeneration(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 @@ -99,6 +100,7 @@ func TestGenerateExistentState(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 @@ -211,6 +213,7 @@ func (t *testHelper) addSnapStorage(accKey string, keys []string, vals []string) } func (t *testHelper) makeStorageTrie(keys []string, vals []string) []byte { + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, t.triedb) for i, k := range keys { stTrie.Update([]byte(k), []byte(vals[i])) @@ -428,6 +431,7 @@ func TestGenerateMissingStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), 
[]byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 @@ -487,6 +491,8 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { diskdb = memorydb.New() triedb = trie.NewDatabase(diskdb) ) + + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, triedb) stTrie.Update([]byte("key-1"), []byte("val-1")) // 0x1314700b81afc49f94db3623ef1df38f3ed18b73a1b7ea2f6c095118cf6118a0 stTrie.Update([]byte("key-2"), []byte("val-2")) // 0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371 @@ -537,6 +543,7 @@ func TestGenerateCorruptStorageTrie(t *testing.T) { } func getStorageTrie(n int, triedb *trie.Database) *trie.SecureTrie { + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, _ := trie.NewSecure(common.Hash{}, triedb) for i := 0; i < n; i++ { k := fmt.Sprintf("key-%d", i) diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index 2f13631607..592564ce27 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -24,6 +24,8 @@ import ( "sync" "sync/atomic" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" @@ -182,6 +184,44 @@ type Tree struct { onFlatten func() // Hook invoked when the bottom most diff layers are flattened } +// SnapValue store snap with Epoch after BEP-216 +type SnapValue struct { + Epoch types.StateEpoch + Val common.Hash +} + +func ParseSnapValFromBytes(enc []byte) (*SnapValue, error) { + k, content, _, err := rlp.Split(enc) + if err != nil { + return nil, err + } + if k != rlp.List { + val := common.Hash{} + val.SetBytes(content) + return &SnapValue{ + Epoch: 0, + Val: val, + }, nil + } + var val SnapValue + if err = rlp.DecodeBytes(enc, &val); err != nil { + return nil, err + } + return &val, nil +} + +func NewSnapValBytes(epoch types.StateEpoch, val 
common.Hash) ([]byte, error) { + snapVal := SnapValue{ + Epoch: epoch, + Val: val, + } + enc, err := rlp.EncodeToBytes(&snapVal) + if err != nil { + return nil, err + } + return enc, nil +} + // New attempts to load an already existing snapshot from a persistent key-value // store (with a number of memory layers from a journal), ensuring that the head // of the snapshot matches the expected one. diff --git a/core/state/state_object.go b/core/state/state_object.go index 1ede96ec63..025b7a2303 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -24,6 +24,12 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/log" + + "github.com/ethereum/go-ethereum/core/state/snapshot" + + "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" @@ -79,9 +85,10 @@ type StateObject struct { dbErr error // Write caches. - trie Trie // storage trie, which becomes non-nil on first access + trie Trie // storage trie, which becomes non-nil on first access, it's committed trie code Code // contract bytecode, which gets set when code is loaded + // TODO(0xbundler) Attention: it's a shared storage between stateDBs, will expired, just disable sharedOriginStorage now sharedOriginStorage *sync.Map // Point to the entry of the stateObject in sharedPool originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction @@ -89,6 +96,18 @@ type StateObject struct { dirtyStorage Storage // Storage entries that have been modified in the current transaction execution fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. 
+ // revive state + pendingReviveTrie Trie // pendingReviveTrie it contains pending revive trie nodes, could update & commit later + dirtyReviveTrie Trie // dirtyReviveTrie for tx + + // when R&W, access revive state first + pendingReviveState map[string]common.Hash // pendingReviveState for block, it cannot flush to trie, just cache + dirtyReviveState map[string]common.Hash // dirtyReviveState for tx, for cache dirtyReviveTrie + + // accessed state, don't record revive state, don't record nonexist state or any err access + pendingAccessedState map[common.Hash]int // pendingAccessedState record which state is accessed, it will update epoch index late + dirtyAccessedState map[common.Hash]int // dirtyAccessedState record which state is accessed, it will update epoch index later + // Cache flags. // When an object is marked suicided it will be delete from the trie // during the "update" phase of the state transition. @@ -97,7 +116,8 @@ type StateObject struct { deleted bool //encode - encodeData []byte + encodeData []byte + targetEpoch types.StateEpoch } // empty returns whether the account is considered empty. 
@@ -123,14 +143,19 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *St } return &StateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - sharedOriginStorage: storageMap, - originStorage: make(Storage), - pendingStorage: make(Storage), - dirtyStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + sharedOriginStorage: storageMap, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), + dirtyReviveState: make(map[string]common.Hash), + pendingReviveState: make(map[string]common.Hash), + dirtyAccessedState: make(map[common.Hash]int), + pendingAccessedState: make(map[common.Hash]int), + targetEpoch: db.targetEpoch, } } @@ -171,31 +196,70 @@ func (s *StateObject) getTrie(db Database) Trie { // prefetcher s.trie = prefetcher.trie(s.data.Root) } - if s.trie == nil { - var err error - s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + if s.trie != nil { + return s.trie + } + var err error + // check if enable state epoch + if s.db.enableAccStateEpoch(false, s.address) { + log.Debug("Open StorageTrie with shadow nodes", "addr", s.address, "targetEpoch", s.targetEpoch) + s.trie, err = db.OpenStorageTrieWithShadowNode(s.addrHash, s.data.Root, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) if err != nil { - s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) - s.setError(fmt.Errorf("can't create storage trie: %v", err)) + log.Error("OpenStorageTrieWithShadowNode err", "targetEpoch", s.targetEpoch, "err", err) + s.trie, _ = db.OpenStorageTrieWithShadowNode(s.addrHash, common.Hash{}, s.targetEpoch, s.db.openShadowStorage(s.addrHash)) + s.setError(fmt.Errorf("can't create storage trie with shadowNode: %v", err)) } + return s.trie + } + + log.Debug("Open StorageTrie normal", "addr", s.address, "targetEpoch", s.targetEpoch, "addr", s.address) + s.trie, err = 
db.OpenStorageTrie(s.addrHash, s.data.Root) + if err != nil { + s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } return s.trie } +func (s *StateObject) getPendingReviveTrie(db Database) Trie { + if s.pendingReviveTrie == nil { + s.pendingReviveTrie = s.db.db.CopyTrie(s.getTrie(db)) + } + return s.pendingReviveTrie +} + +func (s *StateObject) getDirtyReviveTrie(db Database) Trie { + if s.dirtyReviveTrie == nil { + s.dirtyReviveTrie = s.db.db.CopyTrie(s.getPendingReviveTrie(db)) + } + return s.dirtyReviveTrie +} + // GetState retrieves a value from the account storage trie. -func (s *StateObject) GetState(db Database, key common.Hash) common.Hash { +func (s *StateObject) GetState(db Database, key common.Hash) (common.Hash, error) { // If the fake storage is set, only lookup the state here(in the debugging mode) if s.fakeStorage != nil { - return s.fakeStorage[key] + return s.fakeStorage[key], nil } // If we have a dirty value for this state entry, return it - value, dirty := s.dirtyStorage[key] - if dirty { - return value + if value, dirty := s.dirtyStorage[key]; dirty { + s.accessState(key) + return value, nil + } + if s.db.enableAccStateEpoch(true, s.address) { + if revived, revive := s.queryFromReviveState(db, s.dirtyReviveState, key); revive { + s.accessState(key) + return revived, nil + } } + // Otherwise return the entry's original value - return s.GetCommittedState(db, key) + committed, err := s.GetCommittedState(db, key) + if err == nil && committed != (common.Hash{}) { + s.accessState(key) + } + return committed, err } func (s *StateObject) getOriginStorage(key common.Hash) (common.Hash, bool) { @@ -223,24 +287,25 @@ func (s *StateObject) setOriginStorage(key common.Hash, value common.Hash) { } // GetCommittedState retrieves a value from the committed account storage trie. 
-func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Hash { +func (s *StateObject) GetCommittedState(db Database, key common.Hash) (common.Hash, error) { // If the fake storage is set, only lookup the state here(in the debugging mode) if s.fakeStorage != nil { - return s.fakeStorage[key] + return s.fakeStorage[key], nil } // If we have a pending write or clean cached, return that if value, pending := s.pendingStorage[key]; pending { - return value + return value, nil + } + if s.db.enableAccStateEpoch(true, s.address) { + if revived, revive := s.queryFromReviveState(db, s.pendingReviveState, key); revive { + return revived, nil + } } if value, cached := s.getOriginStorage(key); cached { - return value + return value, nil } // If no live objects are available, attempt to use snapshots - var ( - enc []byte - err error - ) if s.db.snap != nil { // If the object was destructed in *this* block (and potentially resurrected), // the storage has been cleared out, and we should *not* consult the previous @@ -249,29 +314,59 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Has // have been handles via pendingStorage above. 
// 2) we don't have new values, and can deliver empty response back if _, destructed := s.db.snapDestructs[s.address]; destructed { - return common.Hash{} + return common.Hash{}, nil } start := time.Now() - enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + enc, err := s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) if metrics.EnabledExpensive { s.db.SnapshotStorageReads += time.Since(start) } + + // snapshot val encode is different from trie, so handle independent + if err == nil { + var value common.Hash + if len(enc) > 0 { + sv, err := snapshot.ParseSnapValFromBytes(enc) + if err != nil { + s.setError(err) + } + if err == nil && s.db.enableAccStateEpoch(true, s.address) && + types.EpochExpired(sv.Epoch, s.targetEpoch) { + // query from dirty revive trie, got the newest expired info + _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return common.Hash{}, NewExpiredStateError(s.address, key, enErr).Reason("snap query") + } + return common.Hash{}, NewSnapExpiredStateError(s.address, key, sv.Epoch) + } + value.SetBytes(sv.Val.Bytes()) + } + + s.setOriginStorage(key, value) + return value, nil + } } // If snapshot unavailable or reading from it failed, load from the database - if s.db.snap == nil || err != nil { - start := time.Now() - // if metrics.EnabledExpensive { - // meter = &s.db.StorageReads - // } - enc, err = s.getTrie(db).TryGet(key.Bytes()) - if metrics.EnabledExpensive { - s.db.StorageReads += time.Since(start) - } - if err != nil { - s.setError(err) - return common.Hash{} + start := time.Now() + //if metrics.EnabledExpensive { + // meter = &s.db.StorageReads + //} + enc, err := s.getTrie(db).TryGet(key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + // query from dirty revive trie, got the newest expired info + _, err = 
s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return common.Hash{}, NewExpiredStateError(s.address, key, enErr) + } + return common.Hash{}, NewExpiredStateError(s.address, key, enErr) } + s.setError(err) + return common.Hash{}, nil } var value common.Hash if len(enc) > 0 { @@ -282,28 +377,49 @@ func (s *StateObject) GetCommittedState(db Database, key common.Hash) common.Has value.SetBytes(content) } s.setOriginStorage(key, value) - return value + return value, nil } // SetState updates a value in account storage. -func (s *StateObject) SetState(db Database, key, value common.Hash) { +func (s *StateObject) SetState(db Database, key, value common.Hash) error { // If the fake storage is set, put the temporary state update here. if s.fakeStorage != nil { s.fakeStorage[key] = value - return + return nil } // If the new value is the same as old, don't set - prev := s.GetState(db, key) + prev, err := s.GetState(db, key) + if exErr, ok := err.(*ExpiredStateError); ok { + exErr.Reason("query from insert") + return exErr + } + if err != nil { + return err + } if prev == value { - return + s.accessState(key) + return nil + } + // when state insert, check if valid to insert new state + if s.db.enableAccStateEpoch(true, s.address) && prev == (common.Hash{}) { + _, err = s.getDirtyReviveTrie(db).TryGet(key.Bytes()) + if err != nil { + if enErr, ok := err.(*trie.ExpiredNodeError); ok { + return NewInsertExpiredStateError(s.address, key, enErr) + } + s.setError(err) + return nil + } } // New value is different, update and journal the change + s.accessState(key) s.db.journal.append(storageChange{ account: &s.address, key: key, prevalue: prev, }) s.setState(key, value) + return nil } // SetStorage replaces the entire state storage with the given one. 
@@ -338,6 +454,13 @@ func (s *StateObject) finalise(prefetch bool) { slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure } } + for key, value := range s.dirtyReviveState { + s.pendingReviveState[key] = value + } + for key, value := range s.dirtyAccessedState { + count := s.pendingAccessedState[key] + s.pendingAccessedState[key] = count + value + } prefetcher := s.db.prefetcher if prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != emptyRoot { @@ -346,6 +469,16 @@ func (s *StateObject) finalise(prefetch bool) { if len(s.dirtyStorage) > 0 { s.dirtyStorage = make(Storage) } + if len(s.dirtyReviveState) > 0 { + s.dirtyReviveState = make(map[string]common.Hash) + } + if len(s.dirtyAccessedState) > 0 { + s.dirtyAccessedState = make(map[common.Hash]int) + } + if s.dirtyReviveTrie != nil { + s.pendingReviveTrie = s.dirtyReviveTrie + s.dirtyReviveTrie = nil + } } // updateTrie writes cached storage modifications into the object's storage trie. 
@@ -353,7 +486,7 @@ func (s *StateObject) finalise(prefetch bool) { func (s *StateObject) updateTrie(db Database) Trie { // Make sure all dirty slots are finalized into the pending storage area s.finalise(false) // Don't prefetch anymore, pull directly if need be - if len(s.pendingStorage) == 0 { + if len(s.pendingStorage) == 0 && len(s.pendingReviveState) == 0 { return s.trie } // Track the amount of time wasted on updating the storage trie @@ -364,36 +497,46 @@ func (s *StateObject) updateTrie(db Database) Trie { s.db.MetricsMux.Unlock() }(time.Now()) } - // Insert all the pending updates into the trie - tr := s.getTrie(db) + // Insert all the pending updates into the pending trie + tr := s.getPendingReviveTrie(db) usedStorage := make([][]byte, 0, len(s.pendingStorage)) - dirtyStorage := make(map[common.Hash][]byte) + dirtyStorage := make(map[common.Hash]common.Hash) + accessStorage := make(map[common.Hash]struct{}) + for k := range s.pendingAccessedState { + accessStorage[k] = struct{}{} + } for key, value := range s.pendingStorage { // Skip noop changes, persist actual changes if value == s.originStorage[key] { continue } s.originStorage[key] = value - var v []byte - if value != (common.Hash{}) { - // Encoding []byte cannot fail, ok to ignore the error. - v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) - } - dirtyStorage[key] = v + dirtyStorage[key] = value + delete(accessStorage, key) } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for key, value := range dirtyStorage { - if len(value) == 0 { + var v []byte + if value != (common.Hash{}) { + // Encoding []byte cannot fail, ok to ignore the error. 
+ v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + } + if len(v) == 0 { s.setError(tr.TryDelete(key[:])) } else { - s.setError(tr.TryUpdate(key[:], value)) + s.setError(tr.TryUpdate(key[:], v)) } usedStorage = append(usedStorage, common.CopyBytes(key[:])) } + // refresh accessed slots' epoch + for key := range accessStorage { + s.setError(tr.TryUpdateEpoch(key[:])) + usedStorage = append(usedStorage, common.CopyBytes(key[:])) + } }() if s.db.snap != nil { // If state snapshotting is active, cache the data til commit @@ -409,7 +552,11 @@ func (s *StateObject) updateTrie(db Database) Trie { } s.db.snapStorageMux.Unlock() for key, value := range dirtyStorage { - storage[string(key[:])] = value + enc, err := snapshot.NewSnapValBytes(s.targetEpoch, value) + if err != nil { + s.setError(err) + } + storage[string(key[:])] = enc } }() } @@ -423,6 +570,20 @@ func (s *StateObject) updateTrie(db Database) Trie { if len(s.pendingStorage) > 0 { s.pendingStorage = make(Storage) } + if len(s.pendingReviveState) > 0 { + s.pendingReviveState = make(map[string]common.Hash) + } + if len(s.pendingAccessedState) > 0 { + s.pendingAccessedState = make(map[common.Hash]int) + } + if s.pendingReviveTrie != nil { + s.pendingReviveTrie = nil + } + + // reset trie as pending trie, will commit later + if tr != nil { + s.trie = s.db.db.CopyTrie(tr) + } return tr } @@ -468,6 +629,10 @@ func (s *StateObject) CommitTrie(db Database) (int, error) { defer func(start time.Time) { s.db.StorageCommits += time.Since(start) }(time.Now()) } root, committed, err := s.trie.Commit(nil) + if err != nil { + log.Error("obj CommitTrie", "addr", s.address, "root", root, "err", err) + } + log.Debug("obj CommitTrie", "addr", s.address, "root", root, "err", err) if err == nil { s.data.Root = root } @@ -524,6 +689,29 @@ func (s *StateObject) deepCopy(db *StateDB) *StateObject { stateObject.suicided = s.suicided stateObject.dirtyCode = s.dirtyCode stateObject.deleted = s.deleted + + if s.dirtyReviveTrie != 
nil { + stateObject.dirtyReviveTrie = db.db.CopyTrie(s.dirtyReviveTrie) + } + if s.pendingReviveTrie != nil { + stateObject.pendingReviveTrie = db.db.CopyTrie(s.pendingReviveTrie) + } + stateObject.dirtyReviveState = make(map[string]common.Hash, len(s.dirtyReviveState)) + for k, v := range s.dirtyReviveState { + stateObject.dirtyReviveState[k] = v + } + stateObject.pendingReviveState = make(map[string]common.Hash, len(s.pendingReviveState)) + for k, v := range s.pendingReviveState { + stateObject.pendingReviveState[k] = v + } + stateObject.dirtyAccessedState = make(map[common.Hash]int, len(s.dirtyAccessedState)) + for k, v := range s.dirtyAccessedState { + stateObject.dirtyAccessedState[k] = v + } + stateObject.pendingAccessedState = make(map[common.Hash]int, len(s.pendingAccessedState)) + for k, v := range s.pendingAccessedState { + stateObject.pendingAccessedState[k] = v + } return stateObject } @@ -615,3 +803,48 @@ func (s *StateObject) Nonce() uint64 { func (s *StateObject) Value() *big.Int { panic("Value on StateObject should never be called") } + +func (s *StateObject) ReviveStorageTrie(proofCache trie.MPTProofCache) error { + dr := s.getDirtyReviveTrie(s.db.db) + s.db.journal.append(reviveStorageTrieNodeChange{ + address: &s.address, + }) + // revive nub and cache revive state + for _, nub := range dr.ReviveTrie(proofCache.CacheNubs()) { + kv, err := nub.ResolveKV() + if err != nil { + return err + } + for k, enc := range kv { + var value common.Hash + if len(enc) > 0 { + _, content, _, err := rlp.Split(enc) + if err != nil { + return err + } + value.SetBytes(content) + } + s.dirtyReviveState[k] = value + } + } + return nil +} + +func (s *StateObject) accessState(key common.Hash) { + if !s.db.enableAccStateEpoch(false, s.address) { + return + } + s.db.journal.append(accessedStorageStateChange{ + address: &s.address, + slot: &key, + }) + count := s.dirtyAccessedState[key] + s.dirtyAccessedState[key] = count + 1 +} + +// TODO(0xbundler): add hash key cache 
later +func (s *StateObject) queryFromReviveState(db Database, reviveState map[string]common.Hash, key common.Hash) (common.Hash, bool) { + hashKey := string(s.getTrie(db).HashKey(key.Bytes())) + val, ok := reviveState[hashKey] + return val, ok +} diff --git a/core/state/state_test.go b/core/state/state_test.go index 4cc5c33a85..af9e48648b 100644 --- a/core/state/state_test.go +++ b/core/state/state_test.go @@ -104,10 +104,10 @@ func TestNull(t *testing.T) { s.state.AccountsIntermediateRoot() s.state.Commit(nil) - if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { + if value, _ := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { t.Errorf("expected empty current value, got %x", value) } - if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) { + if value, _ := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) { t.Errorf("expected empty committed value, got %x", value) } } @@ -130,19 +130,19 @@ func TestSnapshot(t *testing.T) { s.state.SetState(stateobjaddr, storageaddr, data2) s.state.RevertToSnapshot(snapshot) - if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { + if v, _ := s.state.GetState(stateobjaddr, storageaddr); v != data1 { t.Errorf("wrong storage value %v, want %v", v, data1) } - if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) } // revert up to the genesis state and ensure correct content s.state.RevertToSnapshot(genesis) - if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong storage value %v, want %v", v, common.Hash{}) } - if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + if v, _ := 
s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) } } diff --git a/core/state/statedb.go b/core/state/statedb.go index 617dbfa1b7..3a2fb772c0 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -26,6 +26,8 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/gopool" "github.com/ethereum/go-ethereum/core/rawdb" @@ -87,6 +89,7 @@ type StateDB struct { fullProcessed bool pipeCommit bool + shadowNodeDB trie.ShadowNodeDatabase snaps *snapshot.Tree snap snapshot.Snapshot snapAccountMux sync.Mutex // Mutex for snap account access @@ -128,6 +131,9 @@ type StateDB struct { validRevisions []revision nextRevisionId int + targetEpoch types.StateEpoch + targetBlk *big.Int + // Measurements gathered during execution for debugging purposes MetricsMux sync.Mutex AccountReads time.Duration @@ -148,14 +154,34 @@ type StateDB struct { StorageDeleted int } -// New creates a new state from a given trie. +// NewWithStateEpoch creates a new state from a given trie. 
+func NewWithStateEpoch(config *params.ChainConfig, targetBlock *big.Int, root common.Hash, db Database, snaps *snapshot.Tree, sntree *trie.ShadowNodeSnapTree) (*StateDB, error) { + targetEpoch := types.GetStateEpoch(config, targetBlock) + stateDB, err := newStateDB(root, db, snaps, targetEpoch) + if err != nil { + return nil, err + } + + log.Debug("NewWithStateEpoch", "targetBlock", targetBlock, "targetEpoch", targetEpoch, "root", root) + // init target block and shadowNodeRW + stateDB.targetBlk = targetBlock + stateDB.shadowNodeDB, err = trie.NewShadowNodeDatabase(sntree, targetBlock, root) + if err != nil { + return nil, err + } + return stateDB, nil +} + +// New creates a new state from a given trie, it inits at Epoch0 func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { - return newStateDB(root, db, snaps) + return newStateDB(root, db, snaps, types.StateEpoch0) } // NewWithSharedPool creates a new state with sharedStorge on layer 1.5 +// Deprecated: disable in state expiry, it inits at Epoch0 +// TODO(0xbundler) cannot use share pool in state revive now, need optimise later func NewWithSharedPool(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { - statedb, err := newStateDB(root, db, snaps) + statedb, err := newStateDB(root, db, snaps, types.StateEpoch0) if err != nil { return nil, err } @@ -163,7 +189,7 @@ func NewWithSharedPool(root common.Hash, db Database, snaps *snapshot.Tree) (*St return statedb, nil } -func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { +func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree, targetEpoch types.StateEpoch) (*StateDB, error) { sdb := &StateDB{ db: db, originalRoot: root, @@ -175,6 +201,7 @@ func newStateDB(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, preimages: make(map[common.Hash][]byte), journal: newJournal(), hasher: crypto.NewKeccakState(), + targetEpoch: targetEpoch, } if sdb.snaps != nil { @@ 
-473,12 +500,12 @@ func (s *StateDB) GetCodeHash(addr common.Address) common.Hash { } // GetState retrieves a value from the given account's storage trie. -func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { +func (s *StateDB) GetState(addr common.Address, hash common.Hash) (common.Hash, error) { stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.GetState(s.db, hash) } - return common.Hash{} + return common.Hash{}, nil } // GetProof returns the Merkle proof for a given account. @@ -496,7 +523,18 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { return proof, err } -// GetStorageProof returns the Merkle proof for given storage slot. +// GetStorageWitness returns only the Merkle proof for given storage slot. +func (s *StateDB) GetStorageWitness(a common.Address, prefixKeyHex []byte, key common.Hash) ([][]byte, error) { + var proof proofList + trie := s.StorageReviveTrie(a) + if trie == nil { + return proof, errors.New("storage trie for requested address does not exist") + } + err := trie.ProveStorageWitness(crypto.Keccak256(key.Bytes()), prefixKeyHex, &proof) // TODO (asyukii): Might not need the Keccak256 hash, revisit this + return proof, err +} + +// TODO: GetStorageProof returns the combined Merkle proof and Shadow Tree proof for given storage slot. func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { var proof proofList trie := s.StorageTrie(a) @@ -508,12 +546,12 @@ func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, } // GetCommittedState retrieves a value from the given account's committed storage trie. 
-func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { +func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) (common.Hash, error) { stateObject := s.getStateObject(addr) if stateObject != nil { return stateObject.GetCommittedState(s.db, hash) } - return common.Hash{} + return common.Hash{}, nil } // Database retrieves the low level database supporting the lower level trie ops. @@ -533,6 +571,16 @@ func (s *StateDB) StorageTrie(addr common.Address) Trie { return cpy.getTrie(s.db) } +func (s *StateDB) StorageReviveTrie(addr common.Address) Trie { + stateObject := s.getStateObject(addr) + if stateObject == nil { + return nil + } + cpy := stateObject.deepCopy(s) + cpy.updateTrie(s.db) + return cpy.getPendingReviveTrie(s.db) +} + func (s *StateDB) HasSuicided(addr common.Address) bool { stateObject := s.getStateObject(addr) if stateObject != nil { @@ -582,11 +630,12 @@ func (s *StateDB) SetCode(addr common.Address, code []byte) { } } -func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { +func (s *StateDB) SetState(addr common.Address, key, value common.Hash) error { stateObject := s.GetOrNewStateObject(addr) if stateObject != nil { - stateObject.SetState(s.db, key, value) + return stateObject.SetState(s.db, key, value) } + return nil } // SetStorage replaces the entire storage for the specified account with given @@ -847,6 +896,8 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB { preimages: make(map[common.Hash][]byte, len(s.preimages)), journal: newJournal(), hasher: crypto.NewKeccakState(), + targetEpoch: s.targetEpoch, + targetBlk: s.targetBlk, } // Copy the dirty states, logs, and preimages for addr := range s.journal.dirties { @@ -906,6 +957,11 @@ func (s *StateDB) copyInternal(doPrefetch bool) *StateDB { // know that they need to explicitly terminate an active copy). 
state.prefetcher = state.prefetcher.copy() } + + if s.shadowNodeDB != nil { + state.shadowNodeDB = s.shadowNodeDB + } + if s.snaps != nil { // In order for the miner to be able to use and make additions // to the snapshot tree, we need to copy that aswell. @@ -1365,6 +1421,10 @@ func (s *StateDB) LightCommit() (common.Hash, *types.DiffLayer, error) { // Commit writes the state to the underlying in-memory trie database. func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() error) (common.Hash, *types.DiffLayer, error) { + if s.targetEpoch > 0 && s.shadowNodeDB == nil { + return common.Hash{}, nil, errors.New("cannot commit shadow node") + } + if s.dbErr != nil { s.StopPrefetcher() return common.Hash{}, nil, fmt.Errorf("commit aborted due to earlier error: %v", s.dbErr) @@ -1557,6 +1617,7 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er diffLayer.Destructs, diffLayer.Accounts, diffLayer.Storages = s.SnapToDiffLayer() // Only update if there's a state transition (skip empty Clique blocks) if parent := s.snap.Root(); parent != s.expectedRoot { + // TODO snap support epoch index err := s.snaps.Update(s.expectedRoot, parent, s.snapDestructs, s.snapAccounts, s.snapStorage, verified) if err != nil { @@ -1601,6 +1662,13 @@ func (s *StateDB) Commit(failPostCommitFunc func(), postCommitFuncs ...func() er root = s.expectedRoot } + log.Debug("statedb commit", "originalRoot", s.originalRoot, "root", root, "targetBlk", s.targetBlk, "targetEpoch", s.targetEpoch) + if s.shadowNodeDB != nil { + if err := s.shadowNodeDB.Commit(s.targetBlk, root); err != nil { + return common.Hash{}, nil, err + } + } + return root, diffLayer, nil } @@ -1742,3 +1810,97 @@ func (s *StateDB) GetDirtyAccounts() []common.Address { func (s *StateDB) GetStorage(address common.Address) *sync.Map { return s.storagePool.getStorage(address) } + +// ReviveTrie revive a trie with a given witness list +func (s *StateDB) ReviveStorageTrie(witnessList 
types.WitnessList) error { + if !s.enableStateEpoch(true) { + return errors.New("cannot revive any state before epoch2") + } + for i := range witnessList { + wit := witnessList[i] + // got specify witness, verify proof and check if revive success + switch wit.WitnessType { + case types.StorageTrieWitnessType: + data, err := wit.WitnessData() + if err != nil { + return err + } + stWit, ok := data.(*types.StorageTrieWitness) + if !ok { + return errors.New("got StorageTrieWitnessType data error") + } + proofCaches := make([]trie.MPTProofCache, len(stWit.ProofList)) + for j := range stWit.ProofList { + proofCaches[j] = trie.MPTProofCache{ + MPTProof: stWit.ProofList[j], + } + if err := proofCaches[j].VerifyProof(); err != nil { + return err + } + + stateObject := s.getStateObject(stWit.Address) + if stateObject == nil { + return errors.New("contract object not found") + } + if err := stateObject.ReviveStorageTrie(proofCaches[j]); err != nil { + return err + } + } + default: + return errors.New("unsupported WitnessType") + } + } + + return nil +} + +func (s *StateDB) openShadowStorage(addr common.Hash) trie.ShadowNodeStorage { + return trie.NewShadowNodeStorage4Trie(addr, s.shadowNodeDB) +} + +// enableStateEpoch return if enable state expiry hard fork, if inExpired, return if after epoch1 +func (s *StateDB) enableStateEpoch(inExpired bool) bool { + if !inExpired { + return s.targetEpoch > types.StateEpoch0 + } + + return s.targetEpoch > types.StateEpoch1 +} + +// enableAccStateEpoch return if enable account state expiry hard fork, if inExpired, return if after epoch1 +func (s *StateDB) enableAccStateEpoch(inExpired bool, addr common.Address) bool { + // TODO(0xbundler): temporary code, add IsToSystemContract in whitelist for avoid expiry system contract, + // it uses in testnet, because there no crossChain msg to update them + if systemContracts[addr] { + return false + } + return s.enableStateEpoch(inExpired) +} + +// TODO(0xbundler): temporary code, remove in release 
version +const ( + // genesis contracts + ValidatorContract = "0x0000000000000000000000000000000000001000" + SlashContract = "0x0000000000000000000000000000000000001001" + SystemRewardContract = "0x0000000000000000000000000000000000001002" + LightClientContract = "0x0000000000000000000000000000000000001003" + TokenHubContract = "0x0000000000000000000000000000000000001004" + RelayerIncentivizeContract = "0x0000000000000000000000000000000000001005" + RelayerHubContract = "0x0000000000000000000000000000000000001006" + GovHubContract = "0x0000000000000000000000000000000000001007" + TokenManagerContract = "0x0000000000000000000000000000000000001008" + CrossChainContract = "0x0000000000000000000000000000000000002000" + StakingContract = "0x0000000000000000000000000000000000002001" +) + +var systemContracts = map[common.Address]bool{ + common.HexToAddress(ValidatorContract): true, + common.HexToAddress(SlashContract): true, + common.HexToAddress(SystemRewardContract): true, + common.HexToAddress(LightClientContract): true, + common.HexToAddress(RelayerHubContract): true, + common.HexToAddress(GovHubContract): true, + common.HexToAddress(TokenHubContract): true, + common.HexToAddress(RelayerIncentivizeContract): true, + common.HexToAddress(CrossChainContract): true, +} diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 4b3a91cde6..e364c509fe 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -452,10 +452,12 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { // Check storage. 
if obj := state.getStateObject(addr); obj != nil { state.ForEachStorage(addr, func(key, value common.Hash) bool { - return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + val, _ := checkstate.GetState(addr, key) + return checkeq("GetState("+key.Hex()+")", val, value) }) checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { - return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + val, _ := checkstate.GetState(addr, key) + return checkeq("GetState("+key.Hex()+")", val, value) }) } if err != nil { @@ -529,10 +531,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) } - if val := state.GetState(addr, skey); val != sval { + if val, _ := state.GetState(addr, skey); val != sval { t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := state.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("initial committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } // Copy the non-committed state database and check pre/post commit balance @@ -543,10 +545,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy pre-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("first copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } @@ 
-559,10 +561,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != sval { + if val, _ := copyOne.GetCommittedState(addr, skey); val != sval { t.Fatalf("first copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } // Copy the copy and check the balance once more @@ -573,10 +575,10 @@ func TestCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyTwo.GetCommittedState(addr, skey); val != sval { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != sval { t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } } @@ -603,10 +605,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := state.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("initial code mismatch: have %x, want %x", code, []byte("hello")) } - if val := state.GetState(addr, skey); val != sval { + if val, _ := state.GetState(addr, skey); val != sval { t.Fatalf("initial non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := state.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := state.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("initial committed storage slot mismatch: have %x, 
want %x", val, common.Hash{}) } // Copy the non-committed state database and check pre/post commit balance @@ -617,10 +619,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyOne.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("first copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyOne.GetState(addr, skey); val != sval { + if val, _ := copyOne.GetState(addr, skey); val != sval { t.Fatalf("first copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyOne.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("first copy committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } // Copy the copy and check the balance once more @@ -631,10 +633,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy pre-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy pre-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != (common.Hash{}) { t.Fatalf("second copy pre-commit committed storage slot mismatch: have %x, want %x", val, common.Hash{}) } @@ -647,10 +649,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyTwo.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("second copy post-commit code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyTwo.GetState(addr, skey); val != sval { + if val, _ := copyTwo.GetState(addr, skey); val != sval { t.Fatalf("second copy post-commit non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := 
copyTwo.GetCommittedState(addr, skey); val != sval { + if val, _ := copyTwo.GetCommittedState(addr, skey); val != sval { t.Fatalf("second copy post-commit committed storage slot mismatch: have %x, want %x", val, sval) } // Copy the copy-copy and check the balance once more @@ -661,10 +663,10 @@ func TestCopyCopyCommitCopy(t *testing.T) { if code := copyThree.GetCode(addr); !bytes.Equal(code, []byte("hello")) { t.Fatalf("third copy code mismatch: have %x, want %x", code, []byte("hello")) } - if val := copyThree.GetState(addr, skey); val != sval { + if val, _ := copyThree.GetState(addr, skey); val != sval { t.Fatalf("third copy non-committed storage slot mismatch: have %x, want %x", val, sval) } - if val := copyThree.GetCommittedState(addr, skey); val != sval { + if val, _ := copyThree.GetCommittedState(addr, skey); val != sval { t.Fatalf("third copy committed storage slot mismatch: have %x, want %x", val, sval) } } diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go index cd51820e9e..6350b6b54d 100644 --- a/core/state/trie_prefetcher.go +++ b/core/state/trie_prefetcher.go @@ -473,6 +473,7 @@ func (sf *subfetcher) loop() { trie, err = sf.db.OpenTrie(sf.root) } else { // address is useless + // TODO(0xbundler): fix fetcher later trie, err = sf.db.OpenStorageTrie(sf.accountHash, sf.root) } if err != nil { @@ -491,6 +492,7 @@ func (sf *subfetcher) loop() { sf.trie, err = sf.db.OpenTrie(sf.root) } else { // address is useless + // TODO(0xbundler): fix fetcher later sf.trie, err = sf.db.OpenStorageTrie(sf.accountHash, sf.root) } if err != nil { diff --git a/core/state_processor.go b/core/state_processor.go index b42938adf9..0dc1d72607 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -122,7 +122,7 @@ func (p *LightStateProcessor) Process(block *types.Block, statedb *state.StateDB // prepare new statedb statedb.StopPrefetcher() parent := p.bc.GetHeader(block.ParentHash(), block.NumberU64()-1) - statedb, err = 
state.New(parent.Root, p.bc.stateCache, p.bc.snaps) + statedb, err = state.NewWithStateEpoch(p.config, block.Number(), parent.Root, p.bc.stateCache, p.bc.snaps, p.bc.shadowNodeTree) if err != nil { return statedb, nil, nil, 0, err } @@ -261,6 +261,7 @@ func (p *LightStateProcessor) LightProcess(diffLayer *types.DiffLayer, block *ty //update storage latestRoot := common.BytesToHash(latestAccount.Root) if latestRoot != previousAccount.Root { + // TODO(0xbundler): fix light process later accountTrie, err := statedb.Database().OpenStorageTrie(addrHash, previousAccount.Root) if err != nil { errChan <- err diff --git a/core/state_transition.go b/core/state_transition.go index 8083a4ea61..90f8e1610b 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -17,6 +17,7 @@ package core import ( + "errors" "fmt" "math" "math/big" @@ -80,6 +81,7 @@ type Message interface { IsFake() bool Data() []byte AccessList() types.AccessList + WitnessList() types.WitnessList } // ExecutionResult includes all output after executing given evm @@ -118,7 +120,7 @@ func (result *ExecutionResult) Revert() []byte { } // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. 
-func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation bool, isHomestead, isEIP2028 bool) (uint64, error) { +func IntrinsicGas(data []byte, accessList types.AccessList, witnessList types.WitnessList, isContractCreation bool, isHomestead, isEIP2028 bool) (uint64, error) { // Set the starting gas for the raw transaction var gas uint64 if isContractCreation && isHomestead { @@ -155,6 +157,17 @@ func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation b gas += uint64(len(accessList)) * params.TxAccessListAddressGas gas += uint64(accessList.StorageKeys()) * params.TxAccessListStorageKeyGas } + + if witnessList != nil { + witGas, err := types.WitnessIntrinsicGas(witnessList) + if err != nil { + return 0, err + } + if (math.MaxUint64 - gas) < witGas { + return 0, ErrGasUintOverflow + } + gas += witGas + } return gas, nil } @@ -259,6 +272,20 @@ func (st *StateTransition) preCheck() error { } } } + + // check witness and hard fork + if st.msg.WitnessList() != nil { + if !st.evm.ChainConfig().IsElwood(st.evm.Context.BlockNumber) { + return errors.New("cannot allow witness before Elwood fork") + } + witnessList := st.msg.WitnessList() + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } + return st.buyGas() } @@ -315,7 +342,7 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { } } // Check clauses 4-5, subtract intrinsic gas if everything is correct - gas, err := IntrinsicGas(st.data, st.msg.AccessList(), contractCreation, rules.IsHomestead, rules.IsIstanbul) + gas, err := IntrinsicGas(st.data, st.msg.AccessList(), st.msg.WitnessList(), contractCreation, rules.IsHomestead, rules.IsIstanbul) if err != nil { return nil, err } @@ -333,6 +360,12 @@ func (st *StateTransition) TransitionDb() (*ExecutionResult, error) { if rules.IsBerlin { st.state.PrepareAccessList(msg.From(), msg.To(), vm.ActivePrecompiles(rules), msg.AccessList()) } + + // revive state 
before execution + if rules.IsElwood { + st.state.ReviveStorageTrie(msg.WitnessList()) + } + var ( ret []byte vmerr error // vm errors do not effect consensus and are therefore not assigned to err diff --git a/core/state_transition_test.go b/core/state_transition_test.go new file mode 100644 index 0000000000..b11d64bfad --- /dev/null +++ b/core/state_transition_test.go @@ -0,0 +1,117 @@ +package core + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" +) + +func keybytesToHex(str []byte) []byte { + l := len(str)*2 + 1 + var nibbles = make([]byte, l) + for i, b := range str { + nibbles[i*2] = b / 16 + nibbles[i*2+1] = b % 16 + } + nibbles[l-1] = 16 + return nibbles +} + +func makeMerkleProofWitness(addr *common.Address, keyLen, witSize, proofCount, proofLen int) types.ReviveWitness { + proofList := make([]types.MPTProof, witSize) + for i := range proofList { + proof := make([][]byte, proofCount) + for j := range proof { + proof[j] = bytes.Repeat([]byte{'p'}, proofLen) + } + proofList[i] = types.MPTProof{ + RootKeyHex: keybytesToHex(bytes.Repeat([]byte{'k'}, keyLen)), + Proof: proof, + } + } + wit := types.StorageTrieWitness{ + Address: *addr, + ProofList: proofList, + } + + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } + return types.ReviveWitness{ + WitnessType: types.StorageTrieWitnessType, + Data: enc, + } +} + +func TestIntrinsicGas_WitnessList(t *testing.T) { + address := common.HexToAddress("d4584b5f6229b7be90727b0fc8c6b91bb427821f") + + test_data := []struct { + // input + data []byte + accessList types.AccessList + witnessList types.WitnessList + isContractCreation bool + isHomestead bool + isEIP2028 bool + // expect + gas uint64 + }{ + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 0, 100, 
512), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 53464, + }, + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 1, 0, 512), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 56792, + }, + { + data: common.Hex2Bytes("1234567890"), + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 100, 1, 1, 0), + }, + isContractCreation: true, + isHomestead: true, + isEIP2028: true, + gas: 56868, + }, + { + data: nil, + accessList: nil, + witnessList: []types.ReviveWitness{ + makeMerkleProofWitness(&address, 30, 2, 2, 32), + makeMerkleProofWitness(&address, 20, 1, 1, 36), + }, + isContractCreation: false, + isHomestead: true, + isEIP2028: true, + gas: 27804, + }, + } + + for _, item := range test_data { + gas, err := IntrinsicGas(item.data, item.accessList, item.witnessList, item.isContractCreation, item.isHomestead, item.isEIP2028) + assert.NoError(t, err) + assert.Equal(t, item.gas, gas) + } +} diff --git a/core/tx_pool.go b/core/tx_pool.go index b5f1d3a2fd..13c6602f8f 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -157,7 +157,7 @@ const ( type blockChain interface { CurrentBlock() *types.Block GetBlock(hash common.Hash, number uint64) *types.Block - StateAt(root common.Hash) (*state.StateDB, error) + StateAt(root common.Hash, number *big.Int) (*state.StateDB, error) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription } @@ -263,6 +263,7 @@ type TxPool struct { istanbul bool // Fork indicator whether we are in the istanbul stage. eip2718 bool // Fork indicator whether we are using EIP-2718 type transactions. eip1559 bool // Fork indicator whether we are using EIP-1559 type transactions. + isElwood bool // Fork indicator whether we are using BEP-216 type transactions. 
currentState *state.StateDB // Current state in the blockchain head pendingNonces *txNoncer // Pending state tracking virtual nonces @@ -641,7 +642,10 @@ func (pool *TxPool) local() map[common.Address]types.Transactions { func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { // Accept only legacy transactions until EIP-2718/2930 activates. if !pool.eip2718 && tx.Type() != types.LegacyTxType { - return ErrTxTypeNotSupported + // If isElwood, accept types.ReviveStateTxType + if !(pool.isElwood || tx.Type() == types.ReviveStateTxType) { + return ErrTxTypeNotSupported + } } // Reject dynamic fee transactions until EIP-1559 activates. if !pool.eip1559 && tx.Type() == types.DynamicFeeTxType { @@ -706,13 +710,26 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } // Ensure the transaction has more gas than the basic tx fee. - intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true, pool.istanbul) + witnessList := tx.WitnessList() + intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), witnessList, tx.To() == nil, true, pool.istanbul) if err != nil { return err } if tx.Gas() < intrGas { return ErrIntrinsicGas } + + // check witness and hard fork + if witnessList != nil { + if !pool.isElwood { + return errors.New("cannot allow witness before Elwood fork") + } + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } return nil } @@ -1414,7 +1431,8 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { if newHead == nil { newHead = pool.chain.CurrentBlock().Header() // Special case during testing } - statedb, err := pool.chain.StateAt(newHead.Root) + next := new(big.Int).Add(newHead.Number, big.NewInt(1)) + statedb, err := pool.chain.StateAt(newHead.Root, next) if err != nil { log.Error("Failed to reset txpool state", "err", err) return @@ -1429,10 +1447,10 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) { 
pool.addTxsLocked(reinject, false) // Update all fork indicator by next pending block number. - next := new(big.Int).Add(newHead.Number, big.NewInt(1)) pool.istanbul = pool.chainconfig.IsIstanbul(next) pool.eip2718 = pool.chainconfig.IsBerlin(next) pool.eip1559 = pool.chainconfig.IsLondon(next) + pool.isElwood = pool.chainconfig.IsElwood(next) } // promoteExecutables moves transactions that have become processable from the diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index d3b40f1f4a..9b25781b2c 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -73,7 +73,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return bc.CurrentBlock() } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/core/types/access_list_tx.go b/core/types/access_list_tx.go index 8ad5e739e9..b2c4da78f3 100644 --- a/core/types/access_list_tx.go +++ b/core/types/access_list_tx.go @@ -106,6 +106,10 @@ func (tx *AccessListTx) value() *big.Int { return tx.Value } func (tx *AccessListTx) nonce() uint64 { return tx.Nonce } func (tx *AccessListTx) to() *common.Address { return tx.To } +func (tx *AccessListTx) witnessList() WitnessList { + return nil +} + func (tx *AccessListTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git a/core/types/dynamic_fee_tx.go b/core/types/dynamic_fee_tx.go index 53f246ea1f..54de7dcfed 100644 --- a/core/types/dynamic_fee_tx.go +++ b/core/types/dynamic_fee_tx.go @@ -94,6 +94,10 @@ func (tx *DynamicFeeTx) value() *big.Int { return tx.Value } func (tx *DynamicFeeTx) nonce() uint64 { return tx.Nonce } func (tx *DynamicFeeTx) to() *common.Address { return tx.To } +func (tx *DynamicFeeTx) witnessList() WitnessList { + return nil +} + func (tx *DynamicFeeTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git 
a/core/types/legacy_tx.go b/core/types/legacy_tx.go index cb86bed772..f0857acc8c 100644 --- a/core/types/legacy_tx.go +++ b/core/types/legacy_tx.go @@ -103,6 +103,10 @@ func (tx *LegacyTx) value() *big.Int { return tx.Value } func (tx *LegacyTx) nonce() uint64 { return tx.Nonce } func (tx *LegacyTx) to() *common.Address { return tx.To } +func (tx *LegacyTx) witnessList() WitnessList { + return nil +} + func (tx *LegacyTx) rawSignatureValues() (v, r, s *big.Int) { return tx.V, tx.R, tx.S } diff --git a/core/types/receipt.go b/core/types/receipt.go index 90c9dbcd02..27b81aa95e 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -198,7 +198,7 @@ func (r *Receipt) DecodeRLP(s *rlp.Stream) error { return errEmptyTypedReceipt } r.Type = b[0] - if r.Type == AccessListTxType || r.Type == DynamicFeeTxType { + if r.Type == AccessListTxType || r.Type == DynamicFeeTxType || r.Type == ReviveStateTxType { var dec receiptRLP if err := rlp.DecodeBytes(b[1:], &dec); err != nil { return err @@ -234,7 +234,7 @@ func (r *Receipt) decodeTyped(b []byte) error { return errEmptyTypedReceipt } switch b[0] { - case DynamicFeeTxType, AccessListTxType: + case DynamicFeeTxType, AccessListTxType, ReviveStateTxType: var data receiptRLP err := rlp.DecodeBytes(b[1:], &data) if err != nil { @@ -407,6 +407,9 @@ func (rs Receipts) EncodeIndex(i int, w *bytes.Buffer) { case DynamicFeeTxType: w.WriteByte(DynamicFeeTxType) rlp.Encode(w, data) + case ReviveStateTxType: + w.WriteByte(ReviveStateTxType) + rlp.Encode(w, data) default: // For unsupported types, write nothing. 
Since this is for // DeriveSha, the error will be caught matching the derived hash diff --git a/core/types/revive_state_tx.go b/core/types/revive_state_tx.go new file mode 100644 index 0000000000..4fac02a033 --- /dev/null +++ b/core/types/revive_state_tx.go @@ -0,0 +1,128 @@ +package types + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" +) + +type WitnessList []ReviveWitness + +// ReviveStateTx is the transaction for revive state. +type ReviveStateTx struct { + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + WitnessList WitnessList // revive witness + + V, R, S *big.Int // signature values +} + +func (tx *ReviveStateTx) txType() byte { + return ReviveStateTxType +} + +func (tx *ReviveStateTx) copy() TxData { + cpy := &ReviveStateTx{ + Nonce: tx.Nonce, + To: copyAddressPtr(tx.To), + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. 
+ Value: new(big.Int), + GasPrice: new(big.Int), + WitnessList: make(WitnessList, len(tx.WitnessList)), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + + for i := range tx.WitnessList { + cpy.WitnessList[i] = tx.WitnessList[i].Copy() + } + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +func (tx *ReviveStateTx) chainID() *big.Int { + return deriveChainId(tx.V) +} + +func (tx *ReviveStateTx) accessList() AccessList { + return nil +} + +func (tx *ReviveStateTx) witnessList() WitnessList { + return tx.WitnessList +} + +func (tx *ReviveStateTx) data() []byte { + return tx.Data +} + +func (tx *ReviveStateTx) gas() uint64 { + return tx.Gas +} + +func (tx *ReviveStateTx) gasPrice() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) gasTipCap() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) gasFeeCap() *big.Int { + return tx.GasPrice +} + +func (tx *ReviveStateTx) value() *big.Int { + return tx.Value +} + +func (tx *ReviveStateTx) nonce() uint64 { + return tx.Nonce +} + +func (tx *ReviveStateTx) to() *common.Address { + return tx.To +} + +func (tx *ReviveStateTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *ReviveStateTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} + +func WitnessIntrinsicGas(wits WitnessList) (uint64, error) { + totalGas := uint64(0) + for i := 0; i < len(wits); i++ { + totalGas += wits[i].Size() * params.TxWitnessListStorageGasPerByte + addGas, err := wits[i].AdditionalIntrinsicGas() + if err != nil { + return 0, err + } + totalGas += addGas + } + return totalGas, nil +} diff --git a/core/types/revive_witness.go b/core/types/revive_witness.go new file mode 100644 index 0000000000..815df6f5fb --- /dev/null +++ 
b/core/types/revive_witness.go @@ -0,0 +1,144 @@ +package types + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" +) + +var ( + ErrUnknownWitnessType = errors.New("unknown revive witness type") + ErrWitnessEmptyData = errors.New("witness data is empty") + ErrStorageTrieWitnessEmptyProofList = errors.New("StorageTrieWitness: empty proof list") + ErrStorageTrieWitnessEmptyInnerProof = errors.New("StorageTrieWitness: empty inner proof") + ErrStorageTrieWitnessWrongProofSize = errors.New("StorageTrieWitness: wrong proof size") +) + +const ( + StorageTrieWitnessType = iota +) + +// MPTProof in order to degrade revive partial trie node, +// only allow on path proof, not support tree path, +// will verify the whole path later +// Attention: The proof could revive multi-vals, although it's a single trie path witness +type MPTProof struct { + RootKeyHex []byte // prefix key in nibbles format, max 65 bytes. 
TODO: optimize witness size + Proof [][]byte // list of RLP-encoded nodes +} + +type StorageTrieWitness struct { + Address common.Address // target account address + ProofList []MPTProof // revive multiple slots (same address) +} + +func (s *StorageTrieWitness) AdditionalIntrinsicGas() (uint64, error) { + count := 0 + words := 0 + for i := range s.ProofList { + for j := range s.ProofList[i].Proof { + count++ + words += (len(s.ProofList[i].Proof[j]) + 31) / 32 + } + } + + return uint64(count)*params.TxWitnessListVerifyMPTBaseGas + uint64(words)*params.TxWitnessListVerifyMPTGasPerWord, nil +} + +// VerifyWitness only check format, merkle proof check later +func (s *StorageTrieWitness) VerifyWitness() error { + if len(s.ProofList) == 0 { + return ErrStorageTrieWitnessEmptyProofList + } + for i := range s.ProofList { + if len(s.ProofList[i].Proof) == 0 { + return ErrStorageTrieWitnessEmptyInnerProof + } + for j := range s.ProofList[i].Proof { + // The smallest size is a valueNode, The largest size is the full fullNode + if len(s.ProofList[i].Proof[j]) < 32 || len(s.ProofList[i].Proof[j]) > 544 { + return ErrStorageTrieWitnessWrongProofSize + } + } + } + + return nil +} + +// ReviveWitnessData the common method of witness +type ReviveWitnessData interface { + // AdditionalIntrinsicGas got additional gas consumption + AdditionalIntrinsicGas() (uint64, error) + // VerifyWitness check if valid witness format, used before state revival + VerifyWitness() error +} + +// ReviveWitness for revive witness +// Attention: it's not thread safe +type ReviveWitness struct { + WitnessType byte // only support Merkle Proof for now + Data []byte // witness data, it's rlp format + cache ReviveWitnessData `rlp:"-" json:"-"` // cache if not encode to rlp or json, it used for performance +} + +// Size estimate witness byte size +func (r *ReviveWitness) Size() uint64 { + return uint64(len(r.Data) + 1) +} + +// Copy deep copy +func (r *ReviveWitness) Copy() ReviveWitness { + witness := 
ReviveWitness{ + WitnessType: r.WitnessType, + Data: make([]byte, len(r.Data)), + } + copy(witness.Data, r.Data) + return witness +} + +func (r *ReviveWitness) WitnessData() (ReviveWitnessData, error) { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return nil, err + } + } + return r.cache, nil +} + +func (r *ReviveWitness) AdditionalIntrinsicGas() (uint64, error) { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return 0, err + } + } + return r.cache.AdditionalIntrinsicGas() +} + +func (r *ReviveWitness) VerifyWitness() error { + if r.cache == nil { + if err := r.parseWitness(); err != nil { + return err + } + } + return r.cache.VerifyWitness() +} + +func (r *ReviveWitness) parseWitness() error { + if len(r.Data) == 0 { + return ErrWitnessEmptyData + } + switch r.WitnessType { + case StorageTrieWitnessType: + var cache StorageTrieWitness + if err := rlp.DecodeBytes(r.Data, &cache); err != nil { + return err + } + r.cache = &cache + default: + return ErrUnknownWitnessType + } + + return nil +} diff --git a/core/types/revive_witness_test.go b/core/types/revive_witness_test.go new file mode 100644 index 0000000000..025258b21f --- /dev/null +++ b/core/types/revive_witness_test.go @@ -0,0 +1,127 @@ +package types + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" +) + +func makeSimpleReviveWitness(witType byte, data []byte) ReviveWitness { + return ReviveWitness{ + WitnessType: witType, + Data: data, + } +} + +func makeReviveWitnessFromStorageTrieWitness(wit StorageTrieWitness) ReviveWitness { + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } + return ReviveWitness{ + WitnessType: StorageTrieWitnessType, + Data: enc, + } +} + +func makeStorageTrieWitness(addr common.Address, proofCount int, proofLen ...int) StorageTrieWitness { + proofList := make([]MPTProof, proofCount) + for i := 0; i < proofCount; i++ 
{ + proof := make([][]byte, len(proofLen)) + for j := range proofLen { + proof[j] = bytes.Repeat([]byte{'f'}, proofLen[j]) + } + proofList[i] = MPTProof{ + RootKeyHex: nil, + Proof: proof, + } + } + wit := StorageTrieWitness{ + Address: addr, + ProofList: proofList, + } + + return wit +} + +func TestVerifyWitness(t *testing.T) { + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + testData := []struct { + wit StorageTrieWitness + expect error + }{ + { + wit: makeStorageTrieWitness(addr, 0, 0), + expect: ErrStorageTrieWitnessEmptyProofList, + }, + { + wit: makeStorageTrieWitness(addr, 1, 31), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeStorageTrieWitness(addr, 1, 545), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeStorageTrieWitness(addr, 1, 32), + expect: nil, + }, + { + wit: makeStorageTrieWitness(addr, 1, 544), + expect: nil, + }, + { + wit: makeStorageTrieWitness(addr, 10, 32, 544, 33, 543, 128, 99), + expect: nil, + }, + } + + for i := range testData { + assert.Equal(t, testData[i].expect, testData[i].wit.VerifyWitness(), i) + } +} + +func TestReviveWitness_VerifyWitness(t *testing.T) { + + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + testData := []struct { + wit ReviveWitness + expect error + }{ + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 0, 0)), + expect: ErrStorageTrieWitnessEmptyProofList, + }, + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 1, 31)), + expect: ErrStorageTrieWitnessWrongProofSize, + }, + { + wit: makeSimpleReviveWitness(0, nil), + expect: ErrWitnessEmptyData, + }, + { + wit: makeSimpleReviveWitness(1, bytes.Repeat([]byte{'e'}, 10)), + expect: ErrUnknownWitnessType, + }, + { + wit: makeSimpleReviveWitness(0, bytes.Repeat([]byte{'e'}, 10)), + expect: nil, + }, + { + wit: makeReviveWitnessFromStorageTrieWitness(makeStorageTrieWitness(addr, 10, 33, 544, 99)), + expect: nil, + }, + 
} + + for i := range testData { + if i == 4 { + assert.Error(t, testData[i].wit.VerifyWitness()) + continue + } + assert.Equal(t, testData[i].expect, testData[i].wit.VerifyWitness(), i) + } +} diff --git a/core/types/state_epoch.go b/core/types/state_epoch.go new file mode 100644 index 0000000000..eb677009d7 --- /dev/null +++ b/core/types/state_epoch.go @@ -0,0 +1,45 @@ +package types + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/params" +) + +var ( + DefaultStateEpochPeriod = uint64(7_008_000) + StateEpoch0 = StateEpoch(0) + StateEpoch1 = StateEpoch(1) +) + +type StateEpoch uint16 + +// GetStateEpoch computes the current state epoch by hard fork and block number +// state epoch will indicate if the state is accessible or expiry. +// Before ClaudeBlock indicates state epoch0. +// ClaudeBlock indicates start state epoch1. +// ElwoodBlock indicates start state epoch2 and start epoch rotate by StateEpochPeriod. +// When N>=2 and epochN started, epoch(N-2)'s state will expire. 
+func GetStateEpoch(config *params.ChainConfig, blockNumber *big.Int) StateEpoch { + epochPeriod := DefaultStateEpochPeriod + if config.Parlia != nil && config.Parlia.StateEpochPeriod != 0 { + epochPeriod = config.Parlia.StateEpochPeriod + } + if config.IsElwood(blockNumber) { + epochPeriodInt := new(big.Int).SetUint64(epochPeriod) + ret := new(big.Int).Sub(blockNumber, config.ElwoodBlock) + ret.Div(ret, epochPeriodInt) + ret.Add(ret, common.Big2) + return StateEpoch(ret.Uint64()) + } else if config.IsClaude(blockNumber) { + return 1 + } else { + return 0 + } +} + +// EpochExpired checks whether the previous epoch has expired compared to the current epoch +func EpochExpired(pre StateEpoch, cur StateEpoch) bool { + return cur >= 2 && pre < cur-1 +} diff --git a/core/types/state_epoch_test.go b/core/types/state_epoch_test.go new file mode 100644 index 0000000000..892ae89bef --- /dev/null +++ b/core/types/state_epoch_test.go @@ -0,0 +1,101 @@ +package types + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/params" + "github.com/stretchr/testify/assert" +) + +var epochPeriod = new(big.Int).SetUint64(DefaultStateEpochPeriod) + +func TestStateForkConfig(t *testing.T) { + temp := &params.ChainConfig{} + assert.NoError(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(1), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ElwoodBlock: big.NewInt(1), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(0), + ElwoodBlock: big.NewInt(0), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10000), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(2), + ElwoodBlock: big.NewInt(1), + } + assert.Error(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(0), + ElwoodBlock: big.NewInt(1), + } + 
assert.Error(t, temp.CheckConfigForkOrder()) + + temp = &params.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10001), + } + assert.NoError(t, temp.CheckConfigForkOrder()) +} + +func TestSimpleStateEpoch(t *testing.T) { + temp := &params.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(20000), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, StateEpoch0, GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(1000))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(19999))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(20000))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(20000), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) +} + +func TestNoZeroStateEpoch(t *testing.T) { + temp := &params.ChainConfig{ + ClaudeBlock: big.NewInt(1), + ElwoodBlock: big.NewInt(2), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, StateEpoch(1), GetStateEpoch(temp, big.NewInt(1))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(2))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(2), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) +} + +func TestNearestStateEpoch(t *testing.T) { + temp := &params.ChainConfig{ + ClaudeBlock: big.NewInt(10000), + ElwoodBlock: big.NewInt(10001), + } + assert.NoError(t, temp.CheckConfigForkOrder()) + + assert.Equal(t, StateEpoch(0), GetStateEpoch(temp, big.NewInt(0))) + assert.Equal(t, 
StateEpoch(1), GetStateEpoch(temp, big.NewInt(10000))) + assert.Equal(t, StateEpoch(2), GetStateEpoch(temp, big.NewInt(10001))) + assert.Equal(t, StateEpoch(3), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), epochPeriod))) + assert.Equal(t, StateEpoch(102), GetStateEpoch(temp, new(big.Int).Add(big.NewInt(10001), new(big.Int).Mul(big.NewInt(100), epochPeriod)))) +} diff --git a/core/types/transaction.go b/core/types/transaction.go index 95a9da87af..4984eb709a 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -45,6 +45,8 @@ const ( LegacyTxType = iota AccessListTxType DynamicFeeTxType + + ReviveStateTxType = 127 // reserve for BSC revive state ) // Transaction is an Ethereum transaction. @@ -74,6 +76,7 @@ type TxData interface { chainID() *big.Int accessList() AccessList + witnessList() WitnessList data() []byte gas() uint64 gasPrice() *big.Int @@ -191,6 +194,10 @@ func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { var inner DynamicFeeTx err := rlp.DecodeBytes(b[1:], &inner) return &inner, err + case ReviveStateTxType: + var inner ReviveStateTx + err := rlp.DecodeBytes(b[1:], &inner) + return &inner, err default: return nil, ErrTxTypeNotSupported } @@ -268,6 +275,10 @@ func (tx *Transaction) Data() []byte { return tx.inner.data() } // AccessList returns the access list of the transaction. func (tx *Transaction) AccessList() AccessList { return tx.inner.accessList() } +func (tx *Transaction) WitnessList() WitnessList { + return tx.inner.witnessList() +} + // Gas returns the gas limit of the transaction. func (tx *Transaction) Gas() uint64 { return tx.inner.gas() } @@ -655,48 +666,51 @@ func (t *TransactionsByPriceAndNonce) Forward(tx *Transaction) { // // NOTE: In a future PR this will be removed. 
type Message struct { - to *common.Address - from common.Address - nonce uint64 - amount *big.Int - gasLimit uint64 - gasPrice *big.Int - gasFeeCap *big.Int - gasTipCap *big.Int - data []byte - accessList AccessList - isFake bool -} - -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice, gasFeeCap, gasTipCap *big.Int, data []byte, accessList AccessList, isFake bool) Message { + to *common.Address + from common.Address + nonce uint64 + amount *big.Int + gasLimit uint64 + gasPrice *big.Int + gasFeeCap *big.Int + gasTipCap *big.Int + data []byte + accessList AccessList + witnessList WitnessList + isFake bool +} + +func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice, gasFeeCap, gasTipCap *big.Int, data []byte, accessList AccessList, witnessList WitnessList, isFake bool) Message { return Message{ - from: from, - to: to, - nonce: nonce, - amount: amount, - gasLimit: gasLimit, - gasPrice: gasPrice, - gasFeeCap: gasFeeCap, - gasTipCap: gasTipCap, - data: data, - accessList: accessList, - isFake: isFake, + from: from, + to: to, + nonce: nonce, + amount: amount, + gasLimit: gasLimit, + gasPrice: gasPrice, + gasFeeCap: gasFeeCap, + gasTipCap: gasTipCap, + data: data, + accessList: accessList, + witnessList: witnessList, + isFake: isFake, } } // AsMessage returns the transaction as a core.Message. 
func (tx *Transaction) AsMessage(s Signer, baseFee *big.Int) (Message, error) { msg := Message{ - nonce: tx.Nonce(), - gasLimit: tx.Gas(), - gasPrice: new(big.Int).Set(tx.GasPrice()), - gasFeeCap: new(big.Int).Set(tx.GasFeeCap()), - gasTipCap: new(big.Int).Set(tx.GasTipCap()), - to: tx.To(), - amount: tx.Value(), - data: tx.Data(), - accessList: tx.AccessList(), - isFake: false, + nonce: tx.Nonce(), + gasLimit: tx.Gas(), + gasPrice: new(big.Int).Set(tx.GasPrice()), + gasFeeCap: new(big.Int).Set(tx.GasFeeCap()), + gasTipCap: new(big.Int).Set(tx.GasTipCap()), + to: tx.To(), + amount: tx.Value(), + data: tx.Data(), + accessList: tx.AccessList(), + witnessList: tx.WitnessList(), + isFake: false, } // If baseFee provided, set gasPrice to effectiveGasPrice. if baseFee != nil { @@ -716,17 +730,18 @@ func (tx *Transaction) AsMessageNoNonceCheck(s Signer) (Message, error) { return msg, err } -func (m Message) From() common.Address { return m.from } -func (m Message) To() *common.Address { return m.to } -func (m Message) GasPrice() *big.Int { return m.gasPrice } -func (m Message) GasFeeCap() *big.Int { return m.gasFeeCap } -func (m Message) GasTipCap() *big.Int { return m.gasTipCap } -func (m Message) Value() *big.Int { return m.amount } -func (m Message) Gas() uint64 { return m.gasLimit } -func (m Message) Nonce() uint64 { return m.nonce } -func (m Message) Data() []byte { return m.data } -func (m Message) AccessList() AccessList { return m.accessList } -func (m Message) IsFake() bool { return m.isFake } +func (m Message) From() common.Address { return m.from } +func (m Message) To() *common.Address { return m.to } +func (m Message) GasPrice() *big.Int { return m.gasPrice } +func (m Message) GasFeeCap() *big.Int { return m.gasFeeCap } +func (m Message) GasTipCap() *big.Int { return m.gasTipCap } +func (m Message) Value() *big.Int { return m.amount } +func (m Message) Gas() uint64 { return m.gasLimit } +func (m Message) Nonce() uint64 { return m.nonce } +func (m Message) 
Data() []byte { return m.data } +func (m Message) AccessList() AccessList { return m.accessList } +func (m Message) WitnessList() WitnessList { return m.witnessList } +func (m Message) IsFake() bool { return m.isFake } // copyAddressPtr copies an address. func copyAddressPtr(a *common.Address) *common.Address { diff --git a/core/types/transaction_marshalling.go b/core/types/transaction_marshalling.go index aad31a5a97..346fbb10f3 100644 --- a/core/types/transaction_marshalling.go +++ b/core/types/transaction_marshalling.go @@ -46,6 +46,9 @@ type txJSON struct { ChainID *hexutil.Big `json:"chainId,omitempty"` AccessList *AccessList `json:"accessList,omitempty"` + // Revive State transaction fields: + WitnessList *WitnessList `json:"witnessList,omitempty"` + // Only used for encoding: Hash common.Hash `json:"hash"` } @@ -94,6 +97,17 @@ func (t *Transaction) MarshalJSON() ([]byte, error) { enc.V = (*hexutil.Big)(tx.V) enc.R = (*hexutil.Big)(tx.R) enc.S = (*hexutil.Big)(tx.S) + case *ReviveStateTx: + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.WitnessList = &tx.WitnessList + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) } return json.Marshal(&enc) } @@ -263,6 +277,54 @@ func (t *Transaction) UnmarshalJSON(input []byte) error { } } + case ReviveStateTxType: + var itx ReviveStateTx + inner = &itx + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == 
nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.WitnessList == nil { + return errors.New("missing required field 'WitnessList' in transaction") + } + itx.WitnessList = *dec.WitnessList + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, true); err != nil { + return err + } + } default: return ErrTxTypeNotSupported } diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 1d0d2a4c75..184b91c063 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -40,6 +40,8 @@ type sigCache struct { func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { var signer Signer switch { + case config.IsClaude(blockNumber): + signer = NewBEP215Signer(config.ChainID) case config.IsLondon(blockNumber): signer = NewLondonSigner(config.ChainID) case config.IsBerlin(blockNumber): @@ -63,6 +65,9 @@ func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { // have the current block number available, use MakeSigner instead. 
func LatestSigner(config *params.ChainConfig) Signer { if config.ChainID != nil { + if config.ClaudeBlock != nil { + return NewBEP215Signer(config.ChainID) + } if config.LondonBlock != nil { return NewLondonSigner(config.ChainID) } @@ -329,6 +334,89 @@ func (s eip2930Signer) Hash(tx *Transaction) common.Hash { } } +type BEP215Signer struct { + EIP155Signer +} + +func NewBEP215Signer(chainId *big.Int) BEP215Signer { + return BEP215Signer{NewEIP155Signer(chainId)} +} +func (s BEP215Signer) ChainID() *big.Int { + return s.chainId +} + +func (s BEP215Signer) Equal(s2 Signer) bool { + bep215, ok := s2.(BEP215Signer) + return ok && bep215.chainId.Cmp(s.chainId) == 0 +} + +func (s BEP215Signer) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType && tx.Type() != ReviveStateTxType { + return common.Address{}, ErrTxTypeNotSupported + } + + if !tx.Protected() { + return HomesteadSigner{}.Sender(tx) + } + if tx.ChainId().Cmp(s.chainId) != 0 { + return common.Address{}, ErrInvalidChainId + } + V, R, S := tx.RawSignatureValues() + V = new(big.Int).Sub(V, s.chainIdMul) + V.Sub(V, big8) + return recoverPlain(s.Hash(tx), R, S, V, true) +} + +// SignatureValues returns signature values. This signature +// needs to be in the [R || S || V] format where V is 0 or 1. +func (s BEP215Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { + if tx.Type() != LegacyTxType && tx.Type() != ReviveStateTxType { + return nil, nil, nil, ErrTxTypeNotSupported + } + R, S, V = decodeSignature(sig) + if s.chainId.Sign() != 0 { + V = big.NewInt(int64(sig[64] + 35)) + V.Add(V, s.chainIdMul) + } + return R, S, V, nil +} + +// Hash returns the hash to be signed by the sender. +// It does not uniquely identify the transaction. 
+func (s BEP215Signer) Hash(tx *Transaction) common.Hash { + switch tx.Type() { + case LegacyTxType: + return rlpHash([]interface{}{ + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + s.chainId, uint(0), uint(0), + }) + case ReviveStateTxType: + return prefixedRlpHash( + tx.Type(), + []interface{}{ + s.chainId, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + tx.WitnessList(), + }) + default: + // This _should_ not happen, but in case someone sends in a bad + // json struct via RPC, it's probably more prudent to return an + // empty hash instead of killing the node with a panic + //panic("Unsupported transaction type: %d", tx.typ) + return common.Hash{} + } +} + // EIP155Signer implements Signer using the EIP-155 rules. This accepts transactions which // are replay-protected as well as unprotected homestead transactions. type EIP155Signer struct { diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 3177a04d45..7949b4b87f 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -27,6 +27,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" @@ -563,6 +565,112 @@ func TestTransactionCoding(t *testing.T) { } } +// TestTransactionCoding tests serializing/de-serializing to/from rlp and JSON, signer sender & chainId. 
+func TestReviveStateTxAndSigner(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("could not generate key: %v", err) + } + var ( + signer = NewBEP215Signer(common.Big1) + from = crypto.PubkeyToAddress(key.PublicKey) + addr = common.HexToAddress("0x0000000000000000000000000000000000000001") + recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + ) + wit := StorageTrieWitness{ + Address: addr, + ProofList: []MPTProof{{ + RootKeyHex: []byte{0x09, 0x5e, 0x7b, 0xae, 0xa6, 0xa6, 0xc7, 0xc4, 0xc2}, + Proof: [][]byte{common.Hex2Bytes("6a6c7c4c2dfe7c4c2dac326af552d87baea6a6c7c4c2")}, + }}, + } + + enc, err := rlp.EncodeToBytes(wit) + if err != nil { + panic(err) + } + witness := WitnessList{{ + WitnessType: 0, + Data: enc, + }} + for i := uint64(0); i < 500; i++ { + var txdata TxData + switch i % 5 { + case 0: + // Legacy tx. + txdata = &LegacyTx{ + Nonce: i, + To: &recipient, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 1: + // Legacy tx contract creation. + txdata = &LegacyTx{ + Nonce: i, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 2: + // Tx with non-zero witness in revive state. + txdata = &ReviveStateTx{ + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + WitnessList: witness, + Data: []byte("abcdef"), + } + case 3: + // Tx with empty revive state. + txdata = &ReviveStateTx{ + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + Data: []byte("abcdef"), + } + case 4: + // Contract creation with revive state. 
+ txdata = &ReviveStateTx{ + Nonce: i, + GasPrice: big.NewInt(10), + Gas: 123457, + Data: []byte("abcdef"), + WitnessList: witness, + } + } + tx, err := SignNewTx(key, signer, txdata) + if err != nil { + t.Fatalf("could not sign transaction: %v", err) + } + // RLP + parsedTx, err := encodeDecodeBinary(tx) + if err != nil { + t.Fatal(err) + } + if err := assertEqual(parsedTx, tx); err != nil { + t.Fatal(err) + } + + // JSON + parsedTx, err = encodeDecodeJSON(tx) + if err != nil { + t.Fatal(err) + } + if err := assertEqual(parsedTx, tx); err != nil { + t.Fatal(err) + } + + assert.Equal(t, common.Big1, tx.ChainId()) + sender, err := signer.Sender(tx) + assert.NoError(t, err) + assert.Equal(t, from, sender) + } +} + func encodeDecodeJSON(tx *Transaction) (*Transaction, error) { data, err := json.Marshal(tx) if err != nil { diff --git a/core/vm/errors.go b/core/vm/errors.go index 004f8ef1c8..03ecf2092a 100644 --- a/core/vm/errors.go +++ b/core/vm/errors.go @@ -19,6 +19,8 @@ package vm import ( "errors" "fmt" + + "github.com/ethereum/go-ethereum/common" ) // List evm execution errors @@ -70,3 +72,23 @@ type ErrInvalidOpCode struct { } func (e *ErrInvalidOpCode) Error() string { return fmt.Sprintf("invalid opcode: %s", e.opcode) } + +type EVMError struct { + from common.Address + to common.Address + opcode OpCode + Err error +} + +func NewEVMErr(contract *Contract, op OpCode, err error) *EVMError { + return &EVMError{ + from: contract.Caller(), + to: contract.Address(), + opcode: op, + Err: err, + } +} + +func (e *EVMError) Error() string { + return fmt.Sprintf("EVM err, from %v to %v at %v, got: %v", e.from, e.to, e.opcode, e.Err) +} diff --git a/core/vm/evm.go b/core/vm/evm.go index eb483fd6de..1cf6978412 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -22,6 +22,8 @@ import ( "sync/atomic" "time" + "github.com/ethereum/go-ethereum/log" + "github.com/holiman/uint256" "github.com/ethereum/go-ethereum/common" @@ -135,6 +137,9 @@ type EVM struct { // available gas is 
calculated in gasCall* according to the 63/64 rule and later // applied in opCall*. callGasTemp uint64 + + // ErrorCollection all op code and err list will collect in here + ErrorCollection []*EVMError } // NewEVM returns a new EVM. The returned EVM is not thread safe and should @@ -152,6 +157,7 @@ func NewEVM(blockCtx BlockContext, txCtx TxContext, statedb StateDB, chainConfig evm.depth = 0 evm.interpreter = NewEVMInterpreter(evm, config) + evm.ErrorCollection = []*EVMError{} return evm } @@ -259,6 +265,16 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas //} else { // evm.StateDB.DiscardSnapshot(snapshot) } + + errors := evm.Errors() + if err != nil || len(errors) > 0 { + log.Error("execution got err", "from", caller.Address(), "to", addr) + for _, e := range errors { + log.Error("op err", "err", e) + } + log.Error("return err", "err", err) + } + return ret, gas, err } @@ -531,3 +547,11 @@ func (evm *EVM) Create2(caller ContractRef, code []byte, gas uint64, endowment * // ChainConfig returns the environment's chain configuration func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig } + +func (evm *EVM) AppendErr(err *EVMError) { + evm.ErrorCollection = append(evm.ErrorCollection, err) +} + +func (evm *EVM) Errors() []*EVMError { + return evm.ErrorCollection +} diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index 95173cb274..89c76e5cf7 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -19,6 +19,8 @@ package vm import ( "errors" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/params" @@ -95,9 +97,12 @@ var ( func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { var ( - y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) ) + 
current, err := evm.StateDB.GetState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } // The legacy gas metering only takes into consideration the current state // Legacy rules should be applied if we are in Petersburg (removal of EIP-1283) // OR Constantinople is not active @@ -135,7 +140,14 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi if current == value { // noop (1) return params.NetSstoreNoopGas, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return params.NetSstoreInitGas, nil @@ -182,15 +194,25 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.Back(0) - current = evm.StateDB.GetState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) ) + current, err := evm.StateDB.GetState(contract.Address(), x.Bytes32()) + if err != nil { + return 0, err + } value := common.Hash(y.Bytes32()) if current == value { // noop (1) return params.SloadGasEIP2200, nil } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return params.SstoreSetGasEIP2200, nil diff --git a/core/vm/instructions.go 
b/core/vm/instructions.go index eacc77587a..e2bf1f0825 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -523,7 +523,10 @@ func opMstore8(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([] func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) - val := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) + val, err := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) + if err != nil { + return nil, err + } loc.SetBytes(val.Bytes()) return nil, nil } @@ -534,8 +537,10 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b } loc := scope.Stack.pop() val := scope.Stack.pop() - interpreter.evm.StateDB.SetState(scope.Contract.Address(), - loc.Bytes32(), val.Bytes32()) + if err := interpreter.evm.StateDB.SetState(scope.Contract.Address(), + loc.Bytes32(), val.Bytes32()); err != nil { + return nil, err + } return nil, nil } diff --git a/core/vm/interface.go b/core/vm/interface.go index ad9b05d666..8bc0891a60 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -43,9 +43,9 @@ type StateDB interface { SubRefund(uint64) GetRefund() uint64 - GetCommittedState(common.Address, common.Hash) common.Hash - GetState(common.Address, common.Hash) common.Hash - SetState(common.Address, common.Hash, common.Hash) + GetCommittedState(common.Address, common.Hash) (common.Hash, error) + GetState(common.Address, common.Hash) (common.Hash, error) + SetState(common.Address, common.Hash, common.Hash) error Suicide(common.Address) bool HasSuicided(common.Address) bool @@ -74,6 +74,7 @@ type StateDB interface { AddPreimage(common.Hash, []byte) ForEachStorage(common.Address, func(common.Hash, common.Hash) bool) error + ReviveStorageTrie(witnessList types.WitnessList) error } // CallContext provides a basic interface for the EVM calling conventions. 
The EVM diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 030a063948..a78ba9666e 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -20,6 +20,8 @@ import ( "hash" "sync" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/log" @@ -172,6 +174,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( // so that it get's executed _after_: the capturestate needs the stacks before // they are returned to the pools defer func() { + if err != nil { + in.evm.AppendErr(NewEVMErr(contract, op, err)) + } returnStack(stack) }() contract.Input = input @@ -235,6 +240,10 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( dynamicCost, err = operation.dynamicGas(in.evm, contract, stack, mem, memorySize) cost += dynamicCost // for tracing if err != nil || !contract.UseGas(dynamicCost) { + // capture expired state error + if _, ok := err.(*state.ExpiredStateError); ok { + break + } return nil, ErrOutOfGas } if memorySize > 0 { diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go index 551e1f5f11..7570476133 100644 --- a/core/vm/operations_acl.go +++ b/core/vm/operations_acl.go @@ -19,6 +19,8 @@ package vm import ( "errors" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/params" @@ -32,11 +34,14 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.peek() - slot = common.Hash(x.Bytes32()) - current = evm.StateDB.GetState(contract.Address(), slot) - cost = uint64(0) + y, x = stack.Back(1), stack.peek() + slot = common.Hash(x.Bytes32()) + cost = uint64(0) ) + current, err := evm.StateDB.GetState(contract.Address(), slot) + if 
err != nil { + return 0, err + } // Check slot presence in the access list if addrPresent, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { cost = params.ColdSloadCostEIP2929 @@ -56,7 +61,14 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { // return params.SloadGasEIP2200, nil return cost + params.WarmStorageReadCostEIP2929, nil // SLOAD_GAS } - original := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + original, err := evm.StateDB.GetCommittedState(contract.Address(), x.Bytes32()) + if err != nil { + // if origin is expired and revive, just small gas + if _, ok := err.(*state.ExpiredStateError); ok { + return params.NetSstoreDirtyGas, nil + } + return 0, err + } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) return cost + params.SstoreSetGasEIP2200, nil diff --git a/eth/api.go b/eth/api.go index f81dfa922b..b1bea41d87 100644 --- a/eth/api.go +++ b/eth/api.go @@ -291,7 +291,7 @@ func (api *PublicDebugAPI) DumpBlock(blockNr rpc.BlockNumber) (state.Dump, error if block == nil { return state.Dump{}, fmt.Errorf("block #%d not found", blockNr) } - stateDb, err := api.eth.BlockChain().StateAt(block.Root()) + stateDb, err := api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.Dump{}, err } @@ -379,7 +379,7 @@ func (api *PublicDebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, sta if block == nil { return state.IteratorDump{}, fmt.Errorf("block #%d not found", number) } - stateDb, err = api.eth.BlockChain().StateAt(block.Root()) + stateDb, err = api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.IteratorDump{}, err } @@ -389,7 +389,7 @@ func (api *PublicDebugAPI) AccountRange(blockNrOrHash rpc.BlockNumberOrHash, sta if block == nil { return state.IteratorDump{}, fmt.Errorf("block %s not found", hash.Hex()) } - stateDb, err = api.eth.BlockChain().StateAt(block.Root()) + stateDb, err = 
api.eth.BlockChain().StateAt(block.Root(), block.Number()) if err != nil { return state.IteratorDump{}, err } diff --git a/eth/api_backend.go b/eth/api_backend.go index 7c0f1f0482..71728f2864 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -152,7 +152,7 @@ func (b *EthAPIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B if header == nil { return nil, nil, errors.New("header not found") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Number) return stateDb, header, err } @@ -171,7 +171,7 @@ func (b *EthAPIBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockN if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, nil, errors.New("hash is not currently canonical") } - stateDb, err := b.eth.BlockChain().StateAt(header.Root) + stateDb, err := b.eth.BlockChain().StateAt(header.Root, header.Number) return stateDb, header, err } return nil, nil, errors.New("invalid arguments; neither block nor hash specified") diff --git a/eth/backend.go b/eth/backend.go index fc0ca6534c..855939dd56 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -204,6 +204,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { TrieCleanRejournal: config.TrieCleanCacheRejournal, TrieDirtyLimit: config.TrieDirtyCache, TrieDirtyDisabled: config.NoPruning, + NoPruning: config.NoPruning, TrieTimeLimit: config.TrieTimeout, NoTries: config.TriesVerifyMode != core.LocalVerify, SnapshotLimit: config.SnapshotCache, @@ -216,7 +217,9 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) { bcOps = append(bcOps, core.EnableLightProcessor) } if config.PipeCommit { - bcOps = append(bcOps, core.EnablePipelineCommit) + // TODO(0xbundler): may support pipeCommit in state expiry later + log.Info("temporary not support pipeCommit in state expiry") + //bcOps = append(bcOps, core.EnablePipelineCommit) } 
if config.PersistDiff { bcOps = append(bcOps, core.EnablePersistDiff(config.DiffBlock)) diff --git a/eth/catalyst/api_test.go b/eth/catalyst/api_test.go index 7e931960e5..978b37de65 100644 --- a/eth/catalyst/api_test.go +++ b/eth/catalyst/api_test.go @@ -257,7 +257,7 @@ func TestEth2NewBlock(t *testing.T) { ethservice.BlockChain().SubscribeRemovedLogsEvent(rmLogsCh) for i := 0; i < 10; i++ { - statedb, _ := ethservice.BlockChain().StateAt(parent.Root()) + statedb, _ := ethservice.BlockChain().StateAt(parent.Root(), new(big.Int).Add(parent.Number(), common.Big1)) nonce := statedb.GetNonce(testAddr) tx, _ := types.SignTx(types.NewContractCreation(nonce, new(big.Int), 1000000, big.NewInt(2*params.InitialBaseFee), logCode), types.LatestSigner(ethservice.BlockChain().Config()), testKey) ethservice.TxPool().AddLocal(tx) @@ -420,7 +420,7 @@ func TestFullAPI(t *testing.T) { logCode = common.Hex2Bytes("60606040525b7f24ec1d3ff24c2f6ff210738839dbc339cd45a5294d85c79361016243157aae7b60405180905060405180910390a15b600a8060416000396000f360606040526008565b00") ) for i := 0; i < 10; i++ { - statedb, _ := ethservice.BlockChain().StateAt(parent.Root()) + statedb, _ := ethservice.BlockChain().StateAt(parent.Root(), new(big.Int).Add(parent.Number(), common.Big1)) nonce := statedb.GetNonce(testAddr) tx, _ := types.SignTx(types.NewContractCreation(nonce, new(big.Int), 1000000, big.NewInt(2*params.InitialBaseFee), logCode), types.LatestSigner(ethservice.BlockChain().Config()), testKey) ethservice.TxPool().AddLocal(tx) diff --git a/eth/protocols/eth/handler_test.go b/eth/protocols/eth/handler_test.go index 55e612b801..8dff04b198 100644 --- a/eth/protocols/eth/handler_test.go +++ b/eth/protocols/eth/handler_test.go @@ -467,10 +467,10 @@ func testGetNodeData(t *testing.T, protocol uint) { // Sanity check whether all state matches. 
accounts := []common.Address{testAddr, acc1Addr, acc2Addr} for i := uint64(0); i <= backend.chain.CurrentBlock().NumberU64(); i++ { - root := backend.chain.GetBlockByNumber(i).Root() - reconstructed, _ := state.New(root, state.NewDatabase(reconstructDB), nil) + block := backend.chain.GetBlockByNumber(i) + reconstructed, _ := state.New(block.Root(), state.NewDatabase(reconstructDB), nil) for j, acc := range accounts { - state, _ := backend.chain.StateAt(root) + state, _ := backend.chain.StateAt(block.Root(), block.Number()) bw := state.GetBalance(acc) bh := reconstructed.GetBalance(acc) diff --git a/eth/protocols/snap/handler.go b/eth/protocols/snap/handler.go index 314776dffe..030992d2fd 100644 --- a/eth/protocols/snap/handler.go +++ b/eth/protocols/snap/handler.go @@ -529,6 +529,7 @@ func ServiceGetTrieNodesQuery(chain *core.BlockChain, req *GetTrieNodesPacket, s if err != nil || account == nil { break } + // TODO default using epoch0 trie, but need sndb to query shadow nodes stTrie, err := trie.NewSecure(common.BytesToHash(account.Root), triedb) loads++ // always account database reads, even for failures if err != nil { diff --git a/eth/protocols/snap/sync_test.go b/eth/protocols/snap/sync_test.go index 1dfba03c86..dd10edea67 100644 --- a/eth/protocols/snap/sync_test.go +++ b/eth/protocols/snap/sync_test.go @@ -1604,6 +1604,7 @@ func verifyTrie(db ethdb.KeyValueStore, root common.Hash, t *testing.T) { } accounts++ if acc.Root != emptyRoot { + // TODO default using epoch0 trie, but need sndb to query shadow nodes storeTrie, err := trie.NewSecure(acc.Root, triedb) if err != nil { t.Fatal(err) diff --git a/eth/state_accessor.go b/eth/state_accessor.go index b0b9f38f64..b9bd32a92e 100644 --- a/eth/state_accessor.go +++ b/eth/state_accessor.go @@ -53,10 +53,11 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state database state.Database report = true origin = block.NumberU64() + epoch = types.GetStateEpoch(eth.blockchain.Config(), 
block.Number()) ) // Check the live database first if we have the state fully available, use that. if checkLive { - statedb, err = eth.blockchain.StateAt(block.Root()) + statedb, err = eth.blockchain.StateAt(block.Root(), block.Number()) if err == nil { return statedb, nil } @@ -66,7 +67,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // Create an ephemeral trie.Database for isolating the live one. Otherwise // the internal junks created by tracing will be persisted into the disk. database = state.NewDatabaseWithConfig(eth.chainDb, &trie.Config{Cache: 16}) - if statedb, err = state.New(block.Root(), database, nil); err == nil { + if statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), block.Number(), block.Root(), database, nil, eth.blockchain.ShadowNodeTree()); err == nil { log.Info("Found disk backend for state trie", "root", block.Root(), "number", block.Number()) return statedb, nil } @@ -89,7 +90,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state // we would rewind past a persisted block (specific corner case is chain // tracing from the genesis). 
if !checkLive { - statedb, err = state.New(current.Root(), database, nil) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { return statedb, nil } @@ -105,7 +106,7 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state } current = parent - statedb, err = state.New(current.Root(), database, nil) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), current.Root(), database, nil, eth.blockchain.ShadowNodeTree()) if err == nil { break } @@ -148,13 +149,13 @@ func (eth *Ethereum) StateAtBlock(block *types.Block, reexec uint64, base *state return nil, fmt.Errorf("stateAtBlock commit failed, number %d root %v: %w", current.NumberU64(), current.Root().Hex(), err) } - statedb, err = state.New(root, database, nil) + statedb, err = state.NewWithStateEpoch(eth.blockchain.Config(), current.Number(), root, database, nil, eth.blockchain.ShadowNodeTree()) if err != nil { return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) } database.TrieDB().Reference(root, common.Hash{}) if parent != (common.Hash{}) { - database.TrieDB().Dereference(parent) + database.TrieDB().Dereference(parent, epoch) } parent = root } diff --git a/eth/tracers/api.go b/eth/tracers/api.go index 40aec6b3be..09b38fb53a 100644 --- a/eth/tracers/api.go +++ b/eth/tracers/api.go @@ -345,8 +345,10 @@ func (api *API) traceChain(ctx context.Context, start, end *types.Block, config } // clean out any derefs derefsMu.Lock() + numberBigInt := new(big.Int).SetUint64(number) + epoch := types.GetStateEpoch(api.backend.ChainConfig(), numberBigInt) for _, h := range derefTodo { - statedb.Database().TrieDB().Dereference(h) + statedb.Database().TrieDB().Dereference(h, epoch) } derefTodo = derefTodo[:0] derefsMu.Unlock() @@ -375,7 +377,7 @@ func (api *API) traceChain(ctx context.Context, start, end *types.Block, config // 
Release the parent state because it's already held by the tracer if parent != (common.Hash{}) { - trieDb.Dereference(parent) + trieDb.Dereference(parent, epoch) } // Prefer disk if the trie db memory grows too much s1, s2 := trieDb.Size() diff --git a/eth/tracers/api_test.go b/eth/tracers/api_test.go index c71ebfa6bc..0e36527234 100644 --- a/eth/tracers/api_test.go +++ b/eth/tracers/api_test.go @@ -141,7 +141,7 @@ func (b *testBackend) ChainDb() ethdb.Database { } func (b *testBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, checkLive bool, preferDisk bool) (*state.StateDB, error) { - statedb, err := b.chain.StateAt(block.Root()) + statedb, err := b.chain.StateAt(block.Root(), block.Number()) if err != nil { return nil, errStateNotFound } @@ -153,7 +153,7 @@ func (b *testBackend) StateAtTransaction(ctx context.Context, block *types.Block if parent == nil { return nil, vm.BlockContext{}, nil, errBlockNotFound } - statedb, err := b.chain.StateAt(parent.Root()) + statedb, err := b.chain.StateAt(parent.Root(), block.Number()) if err != nil { return nil, vm.BlockContext{}, nil, errStateNotFound } diff --git a/eth/tracers/js/goja.go b/eth/tracers/js/goja.go index e91e222a67..348e0dc6b6 100644 --- a/eth/tracers/js/goja.go +++ b/eth/tracers/js/goja.go @@ -702,8 +702,12 @@ func (do *dbObj) GetState(addrSlice goja.Value, hashSlice goja.Value) goja.Value return nil } hash := common.BytesToHash(h) - state := do.db.GetState(addr, hash).Bytes() - res, err := do.toBuf(do.vm, state) + val, err := do.db.GetState(addr, hash) + if err != nil { + do.vm.Interrupt(err) + return nil + } + res, err := do.toBuf(do.vm, val.Bytes()) if err != nil { do.vm.Interrupt(err) return nil diff --git a/eth/tracers/logger/logger.go b/eth/tracers/logger/logger.go index aea44801d8..87f62db44a 100644 --- a/eth/tracers/logger/logger.go +++ b/eth/tracers/logger/logger.go @@ -188,10 +188,8 @@ func (l *StructLogger) CaptureState(pc uint64, op vm.OpCode, gas, 
cost uint64, s } // capture SLOAD opcodes and record the read entry in the local storage if op == vm.SLOAD && stackLen >= 1 { - var ( - address = common.Hash(stackData[stackLen-1].Bytes32()) - value = l.env.StateDB.GetState(contract.Address(), address) - ) + address := common.Hash(stackData[stackLen-1].Bytes32()) + value, _ := l.env.StateDB.GetState(contract.Address(), address) l.storage[contract.Address()][address] = value storage = l.storage[contract.Address()].Copy() } else if op == vm.SSTORE && stackLen >= 2 { diff --git a/eth/tracers/logger/logger_test.go b/eth/tracers/logger/logger_test.go index 6b1e740814..8c3bcfe85d 100644 --- a/eth/tracers/logger/logger_test.go +++ b/eth/tracers/logger/logger_test.go @@ -48,9 +48,11 @@ type dummyStatedb struct { state.StateDB } -func (*dummyStatedb) GetRefund() uint64 { return 1337 } -func (*dummyStatedb) GetState(_ common.Address, _ common.Hash) common.Hash { return common.Hash{} } -func (*dummyStatedb) SetState(_ common.Address, _ common.Hash, _ common.Hash) {} +func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetState(_ common.Address, _ common.Hash) (common.Hash, error) { + return common.Hash{}, nil +} +func (*dummyStatedb) SetState(_ common.Address, _ common.Hash, _ common.Hash) error { return nil } func TestStoreCapture(t *testing.T) { var ( diff --git a/eth/tracers/native/prestate.go b/eth/tracers/native/prestate.go index b513f383b9..532c6de6dd 100644 --- a/eth/tracers/native/prestate.go +++ b/eth/tracers/native/prestate.go @@ -174,5 +174,10 @@ func (t *prestateTracer) lookupStorage(addr common.Address, key common.Hash) { if _, ok := t.prestate[addr].Storage[key]; ok { return } - t.prestate[addr].Storage[key] = t.env.StateDB.GetState(addr, key) + // TODO when trace meet state err, just pass + val, err := t.env.StateDB.GetState(addr, key) + if err != nil { + return + } + t.prestate[addr].Storage[key] = val } diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index 
db388a3d9d..d708064b34 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -559,6 +559,24 @@ func (ec *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64 return uint64(hex), nil } +// Result struct for EstimateGasAndReviveState +type EstimateGasAndReviveStateResult struct { + Hex hexutil.Uint64 `json:"gas"` + ReviveWitness []types.ReviveWitness `json:"reviveWitness"` +} + +// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It also attempts to revive expired +// storage trie and returns the revive witness list. +func (ec *Client) EstimateGasAndReviveState(ctx context.Context, msg ethereum.CallMsg) (*EstimateGasAndReviveStateResult, error) { + var result EstimateGasAndReviveStateResult + err := ec.c.CallContext(ctx, &result, "eth_estimateGasAndReviveState", toCallArg(msg)) + if err != nil { + return &result, err + } + return &result, nil +} + // SendTransaction injects a signed transaction into the pending pool for execution. 
// // If the transaction was a contract creation use the TransactionReceipt method to get the diff --git a/ethclient/ethclient_test.go b/ethclient/ethclient_test.go index b488487620..b0b4074997 100644 --- a/ethclient/ethclient_test.go +++ b/ethclient/ethclient_test.go @@ -376,6 +376,9 @@ func TestEthClient(t *testing.T) { "TestDiffAccounts": { func(t *testing.T) { testDiffAccounts(t, client) }, }, + "TestEstimateGasAndReviveState": { + func(t *testing.T) { testEstimateGasAndReviveState(t, client) }, + }, // DO not have TestAtFunctions now, because we do not have pending block now } @@ -620,6 +623,26 @@ func testCallContract(t *testing.T, client *rpc.Client) { } } +func testEstimateGasAndReviveState(t *testing.T, client *rpc.Client) { + ec := NewClient(client) + + // EstimateGas + msg := ethereum.CallMsg{ + From: testAddr, + To: &common.Address{}, + Gas: 21000, + Value: big.NewInt(1), + } + result, err := ec.EstimateGasAndReviveState(context.Background(), msg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Hex != 21000 { + t.Fatalf("unexpected gas price: %v", result.Hex) + } + // TODO(asyukii): add more test cases here +} + func testDiffAccounts(t *testing.T, client *rpc.Client) { ec := NewClient(client) ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond) diff --git a/go.mod b/go.mod index 8405c9661a..f894cd9d85 100644 --- a/go.mod +++ b/go.mod @@ -77,12 +77,14 @@ require ( require ( github.com/Azure/azure-pipeline-go v0.2.2 // indirect github.com/Azure/go-autorest/autorest/adal v0.8.0 // indirect + github.com/RoaringBitmap/roaring v1.2.3 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.2 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.1.1 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.1.1 // indirect github.com/aws/smithy-go v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + 
github.com/bits-and-blooms/bitset v1.2.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91 // indirect @@ -101,6 +103,7 @@ require ( github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mitchellh/mapstructure v1.4.1 // indirect github.com/mitchellh/pointerstructure v1.2.0 // indirect + github.com/mschoch/smat v0.2.0 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect github.com/opentracing/opentracing-go v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 5ce59554c8..c161ec913f 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= +github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/VictoriaMetrics/fastcache v1.6.0 h1:C/3Oi3EiBCqufydp1neRZkqcwmEiuRT9c3fqvvgKm5o= github.com/VictoriaMetrics/fastcache v1.6.0/go.mod h1:0qHz5QP0GMX4pfmMA/zt5RgfNuXJrTP0zS7DqpHGGTw= github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= @@ -95,6 +97,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.2.0 
h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjLyS07ChA= +github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bmizerany/pat v0.0.0-20170815010413-6226ea591a40/go.mod h1:8rLXio+WjiTceGBHIoTvn60HIbs7Hm7bcHjyrSqYB9c= github.com/bnb-chain/ics23 v0.1.0 h1:DvjGOts2FBfbxB48384CYD1LbcrfjThFz8kowY/7KxU= github.com/bnb-chain/ics23 v0.1.0/go.mod h1:cU6lTGolbbLFsGCgceNB2AzplH1xecLp6+KXvxM32nI= @@ -402,6 +406,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg= +github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= +github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/naoina/go-stringutil v0.1.0 h1:rCUeRUHjBjGTSHl0VC00jUPLz8/F9dDzYI70Hzifhks= diff --git a/graphql/graphql.go b/graphql/graphql.go index cc3d36235d..511a18bdd4 100644 --- a/graphql/graphql.go +++ b/graphql/graphql.go @@ -130,7 +130,7 @@ func (a *Account) Storage(ctx context.Context, args struct{ Slot common.Hash }) if err != nil { return common.Hash{}, err } - return state.GetState(a.address, args.Slot), nil + return state.GetState(a.address, args.Slot) } // Log represents an individual log message. All arguments are mandatory. 
diff --git a/interfaces.go b/interfaces.go index 76c1ef6908..9d69dc98b8 100644 --- a/interfaces.go +++ b/interfaces.go @@ -141,7 +141,8 @@ type CallMsg struct { Value *big.Int // amount of wei sent along with the call Data []byte // input data, usually an ABI-encoded contract method invocation - AccessList types.AccessList // EIP-2930 access list. + AccessList types.AccessList // EIP-2930 access list. + WitnessList types.WitnessList // EIP-2930 access list. } // A ContractCaller provides contract calls, essentially transactions that are executed by diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 1c12fcdabd..2b2a9c16c4 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -686,7 +686,11 @@ func (s *PublicBlockChainAPI) GetProof(ctx context.Context, address common.Addre if storageError != nil { return nil, storageError } - storageProof[i] = StorageResult{key, (*hexutil.Big)(state.GetState(address, common.HexToHash(key)).Big()), toHexSlice(proof)} + val, stateErr := state.GetState(address, common.HexToHash(key)) + if stateErr != nil { + return nil, stateErr + } + storageProof[i] = StorageResult{key, (*hexutil.Big)(val.Big()), toHexSlice(proof)} } else { storageProof[i] = StorageResult{key, &hexutil.Big{}, []string{}} } @@ -839,7 +843,10 @@ func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.A if state == nil || err != nil { return nil, err } - res := state.GetState(address, common.HexToHash(key)) + res, err := state.GetState(address, common.HexToHash(key)) + if err != nil { + return nil, err + } return res[:], state.Error() } @@ -950,6 +957,61 @@ func DoCall(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash return result, nil } +func DoCallExpired(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, timeout time.Duration, globalGasCap uint64) (*core.ExecutionResult, []*vm.EVMError, *state.StateDB, error) { + defer func(start 
time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) + + state, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if state == nil || err != nil { + return nil, nil, nil, err + } + if err := overrides.Apply(state); err != nil { + return nil, nil, nil, err + } + // Setup context so it may be cancelled the call has completed + // or, in case of unmetered gas, setup a context with a timeout. + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + // Make sure the context is cancelled when the call has completed + // this makes sure resources are cleaned up. + defer cancel() + + // Get a new instance of the EVM. + msg, err := args.ToMessage(globalGasCap, header.BaseFee) + if err != nil { + return nil, nil, nil, err + } + evm, vmError, err := b.GetEVM(ctx, msg, state, header, &vm.Config{NoBaseFee: true}) + if err != nil { + return nil, nil, nil, err + } + // Wait for the context to be done and cancel the evm. Even if the + // EVM has finished, cancelling may be done (repeatedly) + gopool.Submit(func() { + <-ctx.Done() + evm.Cancel() + }) + + // Execute the message. 
+ gp := new(core.GasPool).AddGas(math.MaxUint64) + result, err := core.ApplyMessage(evm, msg, gp) + if err := vmError(); err != nil { + return nil, evm.ErrorCollection, nil, err + } + + // If the timer caused an abort, return an appropriate error message + if evm.Cancelled() { + return nil, evm.ErrorCollection, nil, fmt.Errorf("execution aborted (timeout = %v)", timeout) + } + if err != nil { + return result, evm.ErrorCollection, nil, fmt.Errorf("err: %w (supplied gas %d)", err, msg.Gas()) + } + return result, evm.ErrorCollection, state, nil +} + func newRevertError(result *core.ExecutionResult) *revertError { reason, errUnpack := abi.UnpackRevert(result.Revert()) err := errors.New("execution reverted") @@ -1128,6 +1190,222 @@ func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args TransactionA return DoEstimateGas(ctx, s.b, args, bNrOrHash, s.b.RPCGasCap()) } +// Result structs for EstimateGasAndReviveState +type EstimateGasAndReviveStateResult struct { + Hex hexutil.Uint64 `json:"gas"` + ReviveWitness []types.ReviveWitness `json:"reviveWitness"` +} + +// DoEstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It also attempts to revive expired +// storage trie and returns the revive witness list. +func DoEstimateGasAndReviveState(ctx context.Context, b Backend, args TransactionArgs, blockNrOrHash rpc.BlockNumberOrHash, gasCap uint64) (*EstimateGasAndReviveStateResult, error) { + + var result EstimateGasAndReviveStateResult + + // Initialize witnessList + if args.WitnessList == nil { + args.WitnessList = (*types.WitnessList)(&[]types.ReviveWitness{}) + } + witLen := len(*args.WitnessList) + + // Binary search the gas requirement, as it may be higher than the amount used + var ( + lo uint64 = params.TxGas - 1 + hi uint64 + cap uint64 + ) + // Use zero address if sender unspecified. 
+ if args.From == nil { + args.From = new(common.Address) + } + // Determine the highest gas limit can be used during the estimation. + if args.Gas != nil && uint64(*args.Gas) >= params.TxGas { + hi = uint64(*args.Gas) + } else { + // Retrieve the block to act as the gas ceiling + block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) + if err != nil { + return nil, err + } + if block == nil { + return nil, errors.New("block not found") + } + hi = block.GasLimit() + } + // Normalize the max fee per gas the call is willing to spend. + var feeCap *big.Int + if args.GasPrice != nil && (args.MaxFeePerGas != nil || args.MaxPriorityFeePerGas != nil) { + return nil, errors.New("both gasPrice and (maxFeePerGas or maxPriorityFeePerGas) specified") + } else if args.GasPrice != nil { + feeCap = args.GasPrice.ToInt() + } else if args.MaxFeePerGas != nil { + feeCap = args.MaxFeePerGas.ToInt() + } else { + feeCap = common.Big0 + } + // Recap the highest gas limit with account's available balance. + if feeCap.BitLen() != 0 { + stateDb, _, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if err != nil { + return nil, err + } + balance := stateDb.GetBalance(*args.From) // from can't be nil + available := new(big.Int).Set(balance) + if args.Value != nil { + if args.Value.ToInt().Cmp(available) >= 0 { + return nil, errors.New("insufficient funds for transfer") + } + available.Sub(available, args.Value.ToInt()) + } + allowance := new(big.Int).Div(available, feeCap) + + // If the allowance is larger than maximum uint64, skip checking + if allowance.IsUint64() && hi > allowance.Uint64() { + transfer := args.Value + if transfer == nil { + transfer = new(hexutil.Big) + } + log.Warn("Gas estimation capped by limited funds", "original", hi, "balance", balance, + "sent", transfer.ToInt(), "maxFeePerGas", feeCap, "fundable", allowance) + hi = allowance.Uint64() + } + } + // Recap the highest gas allowance with specified gascap. 
+ if gasCap != 0 && hi > gasCap { + log.Debug("Caller gas above allowance, capping", "requested", hi, "cap", gasCap) + hi = gasCap + } + cap = hi + + // Create a helper to check if a gas allowance results in an executable transaction + expiedNodeCache := make(map[common.Address]map[string]bool) + executable := func(gas uint64) (bool, *core.ExecutionResult, bool, error) { + args.Gas = (*hexutil.Uint64)(&gas) + + result, evmErrors, callState, err := DoCallExpired(ctx, b, args, blockNrOrHash, nil, 0, gasCap) // TODO (asyukii): Use a different call function to return EVM errors + + // Create MPTProof + resolveWitness := false + if len(evmErrors) > 0 { + addressToProofMap := make(map[common.Address][]types.MPTProof) + for _, evmErr := range evmErrors { + if stateErr, ok := evmErr.Err.(*state.ExpiredStateError); ok { + if _, ec := expiedNodeCache[stateErr.Addr]; !ec { + expiedNodeCache[stateErr.Addr] = make(map[string]bool) + } + if expiedNodeCache[stateErr.Addr][string(stateErr.Path)] { + // revive not works, just return + return true, nil, resolveWitness, stateErr + } + + expiedNodeCache[stateErr.Addr][string(stateErr.Path)] = true + proof, err := callState.GetStorageWitness(stateErr.Addr, stateErr.Path, stateErr.Key) + if err != nil { + return true, nil, false, err + } + if len(proof) == 0 { + continue + } + addressToProofMap[stateErr.Addr] = append(addressToProofMap[stateErr.Addr], types.MPTProof{ + RootKeyHex: stateErr.Path, + Proof: proof, + }) + } + } + + // Create a ReviveWitness object for each address and add it to witnessList + for addr, proofs := range addressToProofMap { + // Build a storageTrieWitness + storageTrieWitness := types.StorageTrieWitness{ + Address: addr, + ProofList: proofs, + } + // Encode StorageTrieWitness + enc, err := rlp.EncodeToBytes(storageTrieWitness) + if err != nil { + return true, nil, resolveWitness, err + } + // Create a ReviveWitness + reviveWitness := types.ReviveWitness{ + WitnessType: types.StorageTrieWitnessType, + Data: enc, 
+ } + // Append to witness list + *args.WitnessList = append(*args.WitnessList, reviveWitness) + resolveWitness = true + } + } + if err != nil { + if errors.Is(err, core.ErrIntrinsicGas) { + return true, nil, resolveWitness, nil // Special case, raise gas limit + } + return true, nil, resolveWitness, err // Bail out + } + return result.Failed(), result, resolveWitness, nil + } + // Execute the binary search and hone in on an executable gas limit + for lo+1 < hi { + mid := (hi + lo) / 2 + failed, _, resolveWitness, err := executable(mid) + + if resolveWitness { + if witLen == len(*args.WitnessList) { + // If witnessList is not updated, it means that the proofs are not + // sufficient to revive the state. + return nil, fmt.Errorf("cannot generate enough proofs to revive the state") + } + witLen = len(*args.WitnessList) + continue + } + + // If the error is not nil(consensus error), it means the provided message + // call or transaction will never be accepted no matter how much gas it is + // assigned. Return the error directly, don't struggle any more. + if err != nil { + return nil, err + } + if failed { + lo = mid + } else { + hi = mid + } + } + // Reject the transaction as invalid if it still fails at the highest allowance + if hi == cap { + failed, result, _, err := executable(hi) + if err != nil { + return nil, err + } + if failed { + if result != nil && result.Err != vm.ErrOutOfGas { + if len(result.Revert()) > 0 { + return nil, newRevertError(result) + } + return nil, result.Err + } + // Otherwise, the specified gas cap is too low + return nil, fmt.Errorf("gas required exceeds allowance (%d)", cap) + } + } + result = EstimateGasAndReviveStateResult{ + Hex: hexutil.Uint64(hi), + ReviveWitness: *args.WitnessList, + } + return &result, nil +} + +// EstimateGasAndReviveState returns an estimate of the amount of gas needed to execute the +// given transaction against the current pending block. It will also revive the state +// temporarily to estimate the gas. 
+func (s *PublicBlockChainAPI) EstimateGasAndReviveState(ctx context.Context, args TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash) (*EstimateGasAndReviveStateResult, error) { + bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash + } + return DoEstimateGasAndReviveState(ctx, s.b, args, bNrOrHash, s.b.RPCGasCap()) +} + // GetDiffAccounts returns changed accounts in a specific block number. func (s *PublicBlockChainAPI) GetDiffAccounts(ctx context.Context, blockNr rpc.BlockNumber) ([]common.Address, error) { if s.b.Chain() == nil { @@ -1197,11 +1475,11 @@ func (s *PublicBlockChainAPI) needToReplay(ctx context.Context, block *types.Blo if err != nil { return false, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - parentState, err := s.b.Chain().StateAt(parent.Root()) + parentState, err := s.b.Chain().StateAt(parent.Root(), parent.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64()-1, err) } - currentState, err := s.b.Chain().StateAt(block.Root()) + currentState, err := s.b.Chain().StateAt(block.Root(), block.Number()) if err != nil { return false, fmt.Errorf("statedb not found for block number (%d): %v", block.NumberU64(), err) } @@ -1227,7 +1505,7 @@ func (s *PublicBlockChainAPI) replay(ctx context.Context, block *types.Block, ac if err != nil { return nil, nil, fmt.Errorf("block not found for block number (%d): %v", block.NumberU64()-1, err) } - statedb, err := s.b.Chain().StateAt(parent.Root()) + statedb, err := s.b.Chain().StateAt(parent.Root(), block.Number()) if err != nil { return nil, nil, fmt.Errorf("state not found for block number (%d): %v", block.NumberU64()-1, err) } @@ -1481,25 +1759,26 @@ func (s *PublicBlockChainAPI) rpcMarshalBlock(ctx context.Context, b *types.Bloc // RPCTransaction represents a transaction that will serialize to the RPC representation of a transaction 
type RPCTransaction struct { - BlockHash *common.Hash `json:"blockHash"` - BlockNumber *hexutil.Big `json:"blockNumber"` - From common.Address `json:"from"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice *hexutil.Big `json:"gasPrice"` - GasFeeCap *hexutil.Big `json:"maxFeePerGas,omitempty"` - GasTipCap *hexutil.Big `json:"maxPriorityFeePerGas,omitempty"` - Hash common.Hash `json:"hash"` - Input hexutil.Bytes `json:"input"` - Nonce hexutil.Uint64 `json:"nonce"` - To *common.Address `json:"to"` - TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` - Value *hexutil.Big `json:"value"` - Type hexutil.Uint64 `json:"type"` - Accesses *types.AccessList `json:"accessList,omitempty"` - ChainID *hexutil.Big `json:"chainId,omitempty"` - V *hexutil.Big `json:"v"` - R *hexutil.Big `json:"r"` - S *hexutil.Big `json:"s"` + BlockHash *common.Hash `json:"blockHash"` + BlockNumber *hexutil.Big `json:"blockNumber"` + From common.Address `json:"from"` + Gas hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + GasFeeCap *hexutil.Big `json:"maxFeePerGas,omitempty"` + GasTipCap *hexutil.Big `json:"maxPriorityFeePerGas,omitempty"` + Hash common.Hash `json:"hash"` + Input hexutil.Bytes `json:"input"` + Nonce hexutil.Uint64 `json:"nonce"` + To *common.Address `json:"to"` + TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` + Value *hexutil.Big `json:"value"` + Type hexutil.Uint64 `json:"type"` + Accesses *types.AccessList `json:"accessList,omitempty"` + Witness *types.WitnessList `json:"witnessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` } // newRPCTransaction returns a transaction that will serialize to the RPC @@ -1528,6 +1807,10 @@ func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber result.TransactionIndex = (*hexutil.Uint64)(&index) } switch tx.Type() { + case types.ReviveStateTxType: + wl := tx.WitnessList() + 
result.Witness = &wl + result.ChainID = (*hexutil.Big)(tx.ChainId()) case types.AccessListTxType: al := tx.AccessList() result.Accesses = &al @@ -1862,7 +2145,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceiptsByBlockNumber(ctx conte tx := txs[idx] var signer types.Signer = types.FrontierSigner{} if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) + signer = types.NewBEP215Signer(tx.ChainId()) } from, _ := types.Sender(signer, tx) diff --git a/internal/ethapi/transaction_args.go b/internal/ethapi/transaction_args.go index 9284439162..b6de1a599c 100644 --- a/internal/ethapi/transaction_args.go +++ b/internal/ethapi/transaction_args.go @@ -52,6 +52,9 @@ type TransactionArgs struct { // Introduced by AccessListTxType transaction. AccessList *types.AccessList `json:"accessList,omitempty"` ChainID *hexutil.Big `json:"chainId,omitempty"` + + // Introduced by ReviveStateTxType transaction. + WitnessList *types.WitnessList `json:"witnessList,omitempty"` } // from retrieves the transaction sender address. 
@@ -245,7 +248,11 @@ func (args *TransactionArgs) ToMessage(globalGasCap uint64, baseFee *big.Int) (t if args.AccessList != nil { accessList = *args.AccessList } - msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, gasFeeCap, gasTipCap, data, accessList, true) + var witnessList types.WitnessList + if args.WitnessList != nil { + witnessList = *args.WitnessList + } + msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, gasFeeCap, gasTipCap, data, accessList, witnessList, true) return msg, nil } diff --git a/les/api_backend.go b/les/api_backend.go index 0e02a03050..1ce9140bac 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -141,7 +141,7 @@ func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.B if header == nil { return nil, nil, errors.New("header not found") } - return light.NewState(ctx, header, b.eth.odr), header, nil + return light.NewState(ctx, b.ChainConfig(), header, b.eth.odr), header, nil } func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) { @@ -156,7 +156,7 @@ func (b *LesApiBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockN if blockNrOrHash.RequireCanonical && b.eth.blockchain.GetCanonicalHash(header.Number.Uint64()) != hash { return nil, nil, errors.New("hash is not currently canonical") } - return light.NewState(ctx, header, b.eth.odr), header, nil + return light.NewState(ctx, b.ChainConfig(), header, b.eth.odr), header, nil } return nil, nil, errors.New("invalid arguments; neither block nor hash specified") } diff --git a/les/odr_test.go b/les/odr_test.go index ad77abf5b9..18624d8280 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -100,7 +100,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon st, err = state.New(header.Root, state.NewDatabase(db), nil) } else { header := lc.GetHeaderByHash(bhash) - st = light.NewState(ctx, header, lc.Odr()) 
+ st = light.NewState(ctx, config, header, lc.Odr()) } if err == nil { bal := st.GetBalance(addr) @@ -135,7 +135,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai from := statedb.GetOrNewStateObject(bankAddr) from.SetBalance(math.MaxBig256) - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} context := core.NewEVMBlockContext(header, bc, nil) txContext := core.NewEVMTxContext(msg) @@ -148,9 +148,9 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai } } else { header := lc.GetHeaderByHash(bhash) - state := light.NewState(ctx, header, lc.Odr()) + state := light.NewState(ctx, config, header, lc.Odr()) state.SetBalance(bankAddr, math.MaxBig256) - msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(bankAddr, &testContractAddr, 0, new(big.Int), 100000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} context := core.NewEVMBlockContext(header, lc, nil) txContext := core.NewEVMTxContext(msg) vmenv := vm.NewEVM(context, txContext, state, config, vm.Config{NoBaseFee: true}) diff --git a/les/server_requests.go b/les/server_requests.go index 7564420ce6..4a5456ba86 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -428,6 +428,7 @@ func handleGetProofs(msg Decoder) (serveRequestFn, uint64, uint64, error) { p.bumpInvalid() continue } + // TODO(0xbundler): fix les later trie, err = 
statedb.OpenStorageTrie(common.BytesToHash(request.AccKey), account.Root) if trie == nil || err != nil { p.Log().Warn("Failed to open storage trie for proof", "block", header.Number, "hash", header.Hash(), "account", common.BytesToHash(request.AccKey), "root", account.Root, "err", err) diff --git a/les/state_accessor.go b/les/state_accessor.go index 112e6fd44d..b28ba33add 100644 --- a/les/state_accessor.go +++ b/les/state_accessor.go @@ -30,7 +30,7 @@ import ( // stateAtBlock retrieves the state database associated with a certain block. func (leth *LightEthereum) stateAtBlock(ctx context.Context, block *types.Block, reexec uint64) (*state.StateDB, error) { - return light.NewState(ctx, block.Header(), leth.odr), nil + return light.NewState(ctx, leth.chainConfig, block.Header(), leth.odr), nil } // stateAtTransaction returns the execution environment of a certain transaction. diff --git a/light/odr_test.go b/light/odr_test.go index fdf657a82e..0a69cb642a 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -146,7 +146,7 @@ func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc var st *state.StateDB if bc == nil { header := lc.GetHeaderByHash(bhash) - st = NewState(ctx, header, lc.Odr()) + st = NewState(ctx, lc.Config(), header, lc.Odr()) } else { header := bc.GetHeaderByHash(bhash) st, _ = state.New(header.Root, state.NewDatabase(db), nil) @@ -185,7 +185,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if bc == nil { chain = lc header = lc.GetHeaderByHash(bhash) - st = NewState(ctx, header, lc.Odr()) + st = NewState(ctx, lc.Config(), header, lc.Odr()) } else { chain = bc header = bc.GetHeaderByHash(bhash) @@ -194,7 +194,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain // Perform read-only call. 
st.SetBalance(testBankAddress, math.MaxBig256) - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, true)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, big.NewInt(params.InitialBaseFee), big.NewInt(params.InitialBaseFee), new(big.Int), data, nil, nil, true)} txContext := core.NewEVMTxContext(msg) context := core.NewEVMBlockContext(header, chain, nil) vmenv := vm.NewEVM(context, txContext, st, config, vm.Config{NoBaseFee: true}) diff --git a/light/trie.go b/light/trie.go index d41536e069..2de90fc698 100644 --- a/light/trie.go +++ b/light/trie.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/state" @@ -35,8 +37,9 @@ var ( sha3Nil = crypto.Keccak256Hash(nil) ) -func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { - state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) +func NewState(ctx context.Context, config *params.ChainConfig, head *types.Header, odr OdrBackend) *state.StateDB { + tree, _ := trie.NewShadowNodeSnapTree(odr.Database(), true) + state, _ := state.NewWithStateEpoch(config, head.Number, head.Root, NewStateDatabase(ctx, head, odr), nil, tree) return state } @@ -62,6 +65,10 @@ func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil } +func (db *odrDatabase) OpenStorageTrieWithShadowNode(addrHash, root common.Hash, curEpoch types.StateEpoch, sndb trie.ShadowNodeStorage) (state.Trie, error) { + return db.OpenStorageTrie(addrHash, root) +} + func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie { switch t := t.(type) { case *odrTrie: @@ -112,6 +119,10 @@ type 
odrTrie struct { trie *trie.Trie } +func (t *odrTrie) ProveStorageWitness(key []byte, prefixKey []byte, proofDb ethdb.KeyValueWriter) error { + panic("implement me") +} + func (t *odrTrie) TryGet(key []byte) ([]byte, error) { key = crypto.Keccak256(key) var res []byte @@ -140,6 +151,10 @@ func (t *odrTrie) TryUpdate(key, value []byte) error { }) } +func (t *odrTrie) TryUpdateEpoch(key []byte) error { + return nil +} + func (t *odrTrie) TryDelete(key []byte) error { key = crypto.Keccak256(key) return t.do(key, func() error { @@ -169,6 +184,14 @@ func (t *odrTrie) GetKey(sha []byte) []byte { return nil } +func (t *odrTrie) HashKey(key []byte) []byte { + return nil +} + +func (t *odrTrie) Epoch() types.StateEpoch { + return types.StateEpoch0 +} + func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { return errors.New("not implemented, needs client/server interface split") } @@ -198,6 +221,10 @@ func (db *odrTrie) NoTries() bool { return false } +func (t *odrTrie) ReviveTrie(proof []*trie.MPTProofNub) []*trie.MPTProofNub { + return t.trie.ReviveTrie(proof) +} + type nodeIterator struct { trie.NodeIterator t *odrTrie diff --git a/light/txpool.go b/light/txpool.go index d12694d8f9..b37870abd1 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -18,6 +18,7 @@ package light import ( "context" + "errors" "fmt" "math/big" "sync" @@ -69,6 +70,7 @@ type TxPool struct { istanbul bool // Fork indicator whether we are in the istanbul stage. eip2718 bool // Fork indicator whether we are in the eip2718 stage. + isElwood bool // Fork indicator whether we are in the BEP-216 stage. 
} // TxRelayBackend provides an interface to the mechanism that forwards transacions @@ -115,7 +117,7 @@ func NewTxPool(config *params.ChainConfig, chain *LightChain, relay TxRelayBacke // currentState returns the light state of the current head header func (pool *TxPool) currentState(ctx context.Context) *state.StateDB { - return NewState(ctx, pool.chain.CurrentHeader(), pool.odr) + return NewState(ctx, pool.config, pool.chain.CurrentHeader(), pool.odr) } // GetNonce returns the "pending" nonce of a given address. It always queries @@ -318,6 +320,7 @@ func (pool *TxPool) setNewHead(head *types.Header) { next := new(big.Int).Add(head.Number, big.NewInt(1)) pool.istanbul = pool.config.IsIstanbul(next) pool.eip2718 = pool.config.IsBerlin(next) + pool.isElwood = pool.config.IsElwood(next) } // Stop stops the light transaction pool @@ -385,13 +388,26 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error } // Should supply enough intrinsic gas - gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true, pool.istanbul) + witnessList := tx.WitnessList() + gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), witnessList, tx.To() == nil, true, pool.istanbul) if err != nil { return err } if tx.Gas() < gas { return core.ErrIntrinsicGas } + + // check witness and hard fork + if witnessList != nil { + if !pool.isElwood { + return errors.New("cannot allow witness before Elwood fork") + } + for i := range witnessList { + if err := witnessList[i].VerifyWitness(); err != nil { + return err + } + } + } return currentState.Error() } diff --git a/miner/miner.go b/miner/miner.go index 7f0f1583e8..9bd7a50c0b 100644 --- a/miner/miner.go +++ b/miner/miner.go @@ -201,7 +201,7 @@ func (miner *Miner) Pending() (*types.Block, *state.StateDB) { if block == nil { return nil, nil } - stateDb, err := miner.worker.chain.StateAt(block.Root()) + stateDb, err := miner.worker.chain.StateAt(block.Root(), block.Number()) if err != nil { return nil, 
nil } diff --git a/miner/miner_test.go b/miner/miner_test.go index cf619845dd..028b345831 100644 --- a/miner/miner_test.go +++ b/miner/miner_test.go @@ -19,6 +19,7 @@ package miner import ( "errors" + "math/big" "testing" "time" @@ -75,7 +76,7 @@ func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block return bc.CurrentBlock() } -func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) { +func (bc *testBlockChain) StateAt(common.Hash, *big.Int) (*state.StateDB, error) { return bc.statedb, nil } diff --git a/miner/worker.go b/miner/worker.go index 9f4cce2988..d6f51aedb7 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -679,8 +679,8 @@ func (w *worker) resultLoop() { continue } writeBlockTimer.UpdateSince(start) - log.Info("Successfully sealed new block", "number", block.Number(), "sealhash", sealhash, "hash", hash, - "elapsed", common.PrettyDuration(time.Since(task.createdAt))) + log.Info("Successfully sealed new block", "number", block.Number(), "hash", hash, "from", block.Coinbase(), + "sealhash", sealhash, "elapsed", common.PrettyDuration(time.Since(task.createdAt))) // Broadcast the block and announce chain insertion event w.mux.Post(core.NewMinedBlockEvent{Block: block}) @@ -698,7 +698,8 @@ func (w *worker) makeEnv(parent *types.Block, header *types.Header, coinbase com prevEnv *environment) (*environment, error) { // Retrieve the parent state to execute on top and start a prefetcher for // the miner to speed block sealing up a bit - state, err := w.chain.StateAtWithSharedPool(parent.Root()) + //state, err := w.chain.StateAtWithSharedPool(parent.Root()) + state, err := w.chain.StateAt(parent.Root(), header.Number) if err != nil { return nil, err } diff --git a/params/config.go b/params/config.go index f4c485f227..10e420d75f 100644 --- a/params/config.go +++ b/params/config.go @@ -115,10 +115,13 @@ var ( MoranBlock: big.NewInt(22107423), GibbsBlock: big.NewInt(23846001), PlanckBlock: big.NewInt(27281024), + 
//ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on mainNet + //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + StateEpochPeriod: 7_008_000, }, } @@ -142,9 +145,13 @@ var ( NanoBlock: big.NewInt(23482428), MoranBlock: big.NewInt(23603940), PlanckBlock: big.NewInt(28196022), + //ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on testnet + //ElwoodBlock: big.NewInt(-), + Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + StateEpochPeriod: 7_008_000, }, } @@ -168,10 +175,13 @@ var ( NanoBlock: nil, MoranBlock: nil, PlanckBlock: nil, + //ClaudeBlock: big.NewInt(-), // enable state expiry hard fork1 on QA net + //ElwoodBlock: big.NewInt(-), Parlia: &ParliaConfig{ - Period: 3, - Epoch: 200, + Period: 3, + Epoch: 200, + StateEpochPeriod: 7_008_000, }, } @@ -180,16 +190,16 @@ var ( // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. 
- AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), new(EthashConfig), nil, nil} + AllEthashProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil, nil} // AllCliqueProtocolChanges contains every protocol change (EIPs) introduced // and accepted by the Ethereum core developers into the Clique consensus. // // This configuration is intentionally not using keyed fields to force anyone // adding flags to the config to also have to set these fields. 
- AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), nil, nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} + AllCliqueProtocolChanges = &ChainConfig{big.NewInt(1337), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), nil, nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, nil, nil, &CliqueConfig{Period: 0, Epoch: 30000}, nil} - TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, new(EthashConfig), nil, nil} + TestChainConfig = &ChainConfig{big.NewInt(1), big.NewInt(0), nil, false, big.NewInt(0), common.Hash{}, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0), nil, nil, nil, nil, new(EthashConfig), nil, nil} TestRules = TestChainConfig.Rules(new(big.Int), false) ) @@ -286,6 +296,8 @@ type ChainConfig struct { NanoBlock *big.Int `json:"nanoBlock,omitempty" toml:",omitempty"` // nanoBlock switch block (nil = no fork, 0 = already activated) 
MoranBlock *big.Int `json:"moranBlock,omitempty" toml:",omitempty"` // moranBlock switch block (nil = no fork, 0 = already activated) PlanckBlock *big.Int `json:"planckBlock,omitempty" toml:",omitempty"` // planckBlock switch block (nil = no fork, 0 = already activated) + ClaudeBlock *big.Int `json:"claudeBlock,omitempty" toml:",omitempty"` // claudeBlock switch block (nil = no fork, 0 = already activated) + ElwoodBlock *big.Int `json:"elwoodBlock,omitempty" toml:",omitempty"` // elwoodBlock switch block (nil = no fork, 0 = already activated) // Various consensus engines Ethash *EthashConfig `json:"ethash,omitempty" toml:",omitempty"` @@ -314,8 +326,9 @@ func (c *CliqueConfig) String() string { // ParliaConfig is the consensus engine configs for proof-of-staked-authority based sealing. type ParliaConfig struct { - Period uint64 `json:"period"` // Number of seconds between blocks to enforce - Epoch uint64 `json:"epoch"` // Epoch length to update validatorSet + Period uint64 `json:"period"` // Number of seconds between blocks to enforce + Epoch uint64 `json:"epoch"` // Epoch length to update validatorSet + StateEpochPeriod uint64 `json:"stateEpochPeriod"` // StateEpochPeriod it indicates the length of a state epoch, default 7_008_000 } // String implements the stringer interface, returning the consensus engine details. 
@@ -336,7 +349,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, Ramanujan: %v, Niels: %v, MirrorSync: %v, Bruno: %v, Berlin: %v, YOLO v3: %v, CatalystBlock: %v, London: %v, ArrowGlacier: %v, MergeFork:%v, Euler: %v, Gibbs: %v, Nano: %v, Moran: %v, Planck: %v, Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Petersburg: %v Istanbul: %v, Muir Glacier: %v, Ramanujan: %v, Niels: %v, MirrorSync: %v, Bruno: %v, Berlin: %v, YOLO v3: %v, CatalystBlock: %v, London: %v, ArrowGlacier: %v, MergeFork:%v, Euler: %v, Gibbs: %v, Nano: %v, Moran: %v, Planck: %v, Claude: %v, Elwood: %v, Engine: %v}", c.ChainID, c.HomesteadBlock, c.DAOForkBlock, @@ -364,6 +377,8 @@ func (c *ChainConfig) String() string { c.NanoBlock, c.MoranBlock, c.PlanckBlock, + c.ClaudeBlock, + c.ElwoodBlock, engine, ) } @@ -527,6 +542,22 @@ func (c *ChainConfig) IsOnPlanck(num *big.Int) bool { return configNumEqual(c.PlanckBlock, num) } +func (c *ChainConfig) IsClaude(num *big.Int) bool { + return isForked(c.ClaudeBlock, num) +} + +func (c *ChainConfig) IsOnClaude(num *big.Int) bool { + return configNumEqual(c.ClaudeBlock, num) +} + +func (c *ChainConfig) IsElwood(num *big.Int) bool { + return isForked(c.ElwoodBlock, num) +} + +func (c *ChainConfig) IsOnElwood(num *big.Int) bool { + return configNumEqual(c.ElwoodBlock, num) +} + // CheckCompatible checks whether scheduled fork transitions have been imported // with a mismatching chain configuration. 
func (c *ChainConfig) CheckCompatible(newcfg *ChainConfig, height uint64) *ConfigCompatError { @@ -578,6 +609,25 @@ func (c *ChainConfig) CheckConfigForkOrder() error { lastFork = cur } } + + // check state expiry's hard forks + if c.ClaudeBlock != nil || c.ElwoodBlock != nil { + if c.ClaudeBlock == nil { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v ElwoodBlock %v", + c.ClaudeBlock, c.ElwoodBlock) + } + + if c.ClaudeBlock.Cmp(common.Big0) <= 0 { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v", + c.ClaudeBlock) + } + + if c.ElwoodBlock != nil && c.ClaudeBlock.Cmp(c.ElwoodBlock) >= 0 { + return fmt.Errorf("unsupported state expiry fork number ClaudeBlock: %v ElwoodBlock %v", + c.ClaudeBlock, c.ElwoodBlock) + } + } + return nil } @@ -658,6 +708,12 @@ func (c *ChainConfig) checkCompatible(newcfg *ChainConfig, head *big.Int) *Confi if isForkIncompatible(c.PlanckBlock, newcfg.PlanckBlock, head) { return newCompatError("planck fork block", c.PlanckBlock, newcfg.PlanckBlock) } + if isForkIncompatible(c.ClaudeBlock, newcfg.ClaudeBlock, head) { + return newCompatError("claude fork block", c.ClaudeBlock, newcfg.ClaudeBlock) + } + if isForkIncompatible(c.ElwoodBlock, newcfg.ElwoodBlock, head) { + return newCompatError("elwood fork block", c.ElwoodBlock, newcfg.ElwoodBlock) + } return nil } @@ -730,6 +786,8 @@ type Rules struct { IsNano bool IsMoran bool IsPlanck bool + IsClaude bool + IsElwood bool } // Rules ensures c's ChainID is not nil. 
@@ -754,5 +812,7 @@ func (c *ChainConfig) Rules(num *big.Int, isMerge bool) Rules { IsNano: c.IsNano(num), IsMoran: c.IsMoran(num), IsPlanck: c.IsPlanck(num), + IsClaude: c.IsClaude(num), + IsElwood: c.IsElwood(num), } } diff --git a/params/protocol_params.go b/params/protocol_params.go index e244c24231..e3f3aa0ae5 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -86,10 +86,13 @@ const ( SelfdestructRefundGas uint64 = 24000 // Refunded following a selfdestruct operation. MemoryGas uint64 = 3 // Times the address of the (highest referenced byte in memory + 1). NOTE: referencing happens on read, write and in instructions such as RETURN and CALL. - TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. - TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) - TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list - TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxDataNonZeroGasFrontier uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. 
+ TxDataNonZeroGasEIP2028 uint64 = 16 // Per byte of non zero data attached to a transaction after EIP 2028 (part in Istanbul) + TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list + TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxWitnessListStorageGasPerByte uint64 = 16 // Per byte gas in BEP-215 witness list + TxWitnessListVerifyMPTBaseGas uint64 = 60 // Base gas in BEP-215 witness list verify, refer to EVM Sha256 gas + TxWitnessListVerifyMPTGasPerWord uint64 = 12 // Per-word price in BEP-215 witness list verify, refer to EVM Sha256 gas // These have been changed during the course of the chain CallGasFrontier uint64 = 40 // Once per CALL operation & message call transaction. diff --git a/tests/state_test_util.go b/tests/state_test_util.go index e71cd2bf3c..bca38048ab 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -364,7 +364,7 @@ func (tx *stTransaction) toMessage(ps stPostState, baseFee *big.Int) (core.Messa } msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, gasPrice, - tx.MaxFeePerGas, tx.MaxPriorityFeePerGas, data, accessList, false) + tx.MaxFeePerGas, tx.MaxPriorityFeePerGas, data, accessList, nil, false) return msg, nil } diff --git a/tests/transaction_test_util.go b/tests/transaction_test_util.go index 82ee01de15..f63ed97dca 100644 --- a/tests/transaction_test_util.go +++ b/tests/transaction_test_util.go @@ -55,7 +55,7 @@ func (tt *TransactionTest) Run(config *params.ChainConfig) error { return nil, nil, err } // Intrinsic gas - requiredGas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, isHomestead, isIstanbul) + requiredGas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.WitnessList(), tx.To() == nil, isHomestead, isIstanbul) if err != nil { return nil, nil, err } diff --git a/trie/committer.go b/trie/committer.go index 5d9a27a503..a7df16d438 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -119,6 
+119,13 @@ func (c *committer) commit(n node, db *Database) (node, int, error) { return hn, childCommitted + 1, nil } return collapsed, childCommitted, nil + case *rootNode: + collapsed := cn.copy() + hashedNode := c.store(collapsed, db) + if hn, ok := hashedNode.(hashNode); ok { + return hn, 1, nil + } + return collapsed, 0, nil case hashNode: return cn, 0, nil default: @@ -250,6 +257,8 @@ func estimateSize(n node) int { return 1 + len(n) case hashNode: return 1 + len(n) + case *rootNode: + return 5 + len(n.TrieRoot) + len(n.ShadowTreeRoot) default: panic(fmt.Sprintf("node type %T", n)) } diff --git a/trie/database.go b/trie/database.go index db465d4e9e..6c07654683 100644 --- a/trie/database.go +++ b/trie/database.go @@ -28,6 +28,7 @@ import ( "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" @@ -109,14 +110,25 @@ func (n rawNode) EncodeRLP(w io.Writer) error { return err } +func (n rawNode) nodeType() int { + return rawNodeType +} + // rawFullNode represents only the useful data content of a full node, with the // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. 
-type rawFullNode [17]node +type rawFullNode struct { + children [BranchNodeLength]node + epoch types.StateEpoch `rlp:"-" json:"-"` +} func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawFullNode) nodeType() int { + return rawFullNodeType +} + func (n rawFullNode) EncodeRLP(w io.Writer) error { eb := rlp.NewEncoderBuffer(w) n.encode(eb) @@ -127,13 +139,18 @@ func (n rawFullNode) EncodeRLP(w io.Writer) error { // caches and flags stripped out to minimize its data storage. This type honors // the same RLP encoding as the original parent. type rawShortNode struct { - Key []byte - Val node + Key []byte + Val node + epoch types.StateEpoch `rlp:"-" json:"-"` } func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } +func (n rawShortNode) nodeType() int { + return rawShortNodeType +} + // cachedNode is all the information we know about a single cached trie node // in the memory database write layer. 
type cachedNode struct { @@ -189,18 +206,46 @@ func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { } } +func (n *cachedNode) getEpoch() (types.StateEpoch, error) { + switch n := n.node.(type) { + case *rawShortNode: + return n.epoch, nil + case *rawFullNode: + return n.epoch, nil + case *rootNode: + return n.Epoch, nil + default: + return 0, fmt.Errorf("unknown node type: %T", n) // TODO(asyukii): may never reach this case, consider panic + } +} + +func (n *cachedNode) updateEpoch(epoch types.StateEpoch) { + switch n := n.node.(type) { + case *rawShortNode: + n.epoch = epoch + case *rawFullNode: + n.epoch = epoch + case *rootNode: + n.Epoch = epoch + default: + return + } +} + // forGatherChildren traverses the node hierarchy of a collapsed storage node and // invokes the callback for all the hashnode children. func forGatherChildren(n node, onChild func(hash common.Hash)) { switch n := n.(type) { case *rawShortNode: forGatherChildren(n.Val, onChild) - case rawFullNode: + case *rawFullNode: for i := 0; i < 16; i++ { - forGatherChildren(n[i], onChild) + forGatherChildren(n.children[i], onChild) } case hashNode: onChild(common.BytesToHash(n)) + case *rootNode: + onChild(n.TrieRoot) case valueNode, nil, rawNode: default: panic(fmt.Sprintf("unknown node type: %T", n)) @@ -213,19 +258,22 @@ func simplifyNode(n node) node { switch n := n.(type) { case *shortNode: // Short nodes discard the flags and cascade - return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)} + return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val), epoch: n.epoch} case *fullNode: // Full nodes discard the flags and cascade - node := rawFullNode(n.Children) - for i := 0; i < len(node); i++ { - if node[i] != nil { - node[i] = simplifyNode(node[i]) + node := &rawFullNode{ + children: n.Children, + epoch: n.epoch, + } + for i := 0; i < len(node.children); i++ { + if node.children[i] != nil { + node.children[i] = simplifyNode(node.children[i]) } } return node - case valueNode, 
hashNode, rawNode: + case valueNode, hashNode, rawNode, *rootNode: return n default: @@ -247,7 +295,7 @@ func expandNode(hash hashNode, n node) node { }, } - case rawFullNode: + case *rawFullNode: // Full nodes need child expansion node := &fullNode{ flags: nodeFlag{ @@ -255,13 +303,13 @@ func expandNode(hash hashNode, n node) node { }, } for i := 0; i < len(node.Children); i++ { - if n[i] != nil { - node.Children[i] = expandNode(nil, n[i]) + if n.children[i] != nil { + node.Children[i] = expandNode(nil, n.children[i]) } } return node - case valueNode, hashNode: + case valueNode, hashNode, *rootNode: return n default: @@ -324,6 +372,15 @@ func (db *Database) insert(hash common.Hash, size int, node node) { // If the node's already cached, skip if _, ok := db.dirties[hash]; ok { + // update the epoch + switch n := node.(type) { + case *shortNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + case *fullNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + case *rootNode: + db.dirties[hash].updateEpoch(n.getEpoch()) + } return } memcacheDirtyWriteMeter.Mark(int64(size)) @@ -522,7 +579,7 @@ func (db *Database) reference(child common.Hash, parent common.Hash) { } // Dereference removes an existing reference from a root node. 
-func (db *Database) Dereference(root common.Hash) { +func (db *Database) Dereference(root common.Hash, epoch types.StateEpoch) { // Sanity check to ensure that the meta-root is not removed if root == (common.Hash{}) { log.Error("Attempted to dereference the trie cache meta root") @@ -532,7 +589,7 @@ func (db *Database) Dereference(root common.Hash) { defer db.lock.Unlock() nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now() - db.dereference(root, common.Hash{}) + db.dereference(root, common.Hash{}, epoch) db.gcnodes += uint64(nodes - len(db.dirties)) db.gcsize += storage - db.dirtiesSize @@ -546,8 +603,12 @@ func (db *Database) Dereference(root common.Hash) { "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize) } +func checkBEP206PruneRule(childEpoch, parentEpoch, currEpoch types.StateEpoch) bool { + return types.EpochExpired(childEpoch, currEpoch) && (types.EpochExpired(parentEpoch, currEpoch) || parentEpoch >= childEpoch+2) +} + // dereference is the private locked version of Dereference. -func (db *Database) dereference(child common.Hash, parent common.Hash) { +func (db *Database) dereference(child common.Hash, parent common.Hash, epoch types.StateEpoch) { // Dereference the parent-child node := db.dirties[parent] @@ -558,6 +619,7 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { db.childrenSize -= (common.HashLength + 2) // uint16 counter } } + parentEpoch, getParentEpochErr := node.getEpoch() // If the child does not exist, it's a previously committed node. node, ok := db.dirties[child] if !ok { @@ -571,7 +633,15 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { // no problem in itself, but don't make maxint parents out of it. 
node.parents-- } - if node.parents == 0 { + childEpoch, getChildEpochErr := node.getEpoch() + canPruneExpired := false + if childEpoch != 0 && getParentEpochErr == nil && getChildEpochErr == nil { // TODO(asyukii): temporary fix, will not prune epoch 0 nodes because of account trie nodes + canPruneExpired = checkBEP206PruneRule(childEpoch, parentEpoch, epoch) + } + if canPruneExpired || node.parents == 0 { // Delete nodes if expired or no more parents node referencing this node + if canPruneExpired { + log.Info("Dereferencing expired trie node") + } // Remove the node from the flush-list switch child { case db.oldest: @@ -586,7 +656,7 @@ func (db *Database) dereference(child common.Hash, parent common.Hash) { } // Dereference all children and delete the node node.forChilds(func(hash common.Hash) { - db.dereference(hash, child) + db.dereference(hash, child, epoch) }) delete(db.dirties, child) db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size)) diff --git a/trie/dummy_trie.go b/trie/dummy_trie.go index 42e79a3719..4d9abc410a 100644 --- a/trie/dummy_trie.go +++ b/trie/dummy_trie.go @@ -32,6 +32,10 @@ func (t *EmptyTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWrit return nil } +func (t *EmptyTrie) ProveStorageWitness(key, from []byte, proofDb ethdb.KeyValueWriter) error { + return nil +} + // NewSecure creates a dummy trie func NewEmptyTrie() *EmptyTrie { return &EmptyTrie{} @@ -45,6 +49,10 @@ func (t *EmptyTrie) TryGet(key []byte) ([]byte, error) { return nil, nil } +func (t *EmptyTrie) TryUpdateEpoch(key []byte) error { + return nil +} + func (t *EmptyTrie) TryGetNode(path []byte) ([]byte, int, error) { return nil, 0, nil } @@ -70,6 +78,14 @@ func (t *EmptyTrie) GetKey(shaKey []byte) []byte { return nil } +func (t *EmptyTrie) Epoch() types.StateEpoch { + return types.StateEpoch0 +} + +func (t *EmptyTrie) HashKey(key []byte) []byte { + return nil +} + func (t *EmptyTrie) Commit(onleaf LeafCallback) (root common.Hash, committed 
int, err error) { return common.Hash{}, 0, nil } @@ -99,3 +115,7 @@ func (t *EmptyTrie) NodeIterator(start []byte) NodeIterator { func (t *EmptyTrie) TryUpdateAccount(key []byte, account *types.StateAccount) error { return nil } + +func (t *EmptyTrie) ReviveTrie(proof []*MPTProofNub) []*MPTProofNub { + return nil +} diff --git a/trie/errors.go b/trie/errors.go index 567b80078c..e11ba816b8 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -19,6 +19,8 @@ package trie import ( "fmt" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" ) @@ -33,3 +35,19 @@ type MissingNodeError struct { func (err *MissingNodeError) Error() string { return fmt.Sprintf("missing trie node %x (path %x)", err.NodeHash, err.Path) } + +type ExpiredNodeError struct { + Path []byte // hex-encoded path to the expired node + Epoch types.StateEpoch +} + +func NewExpiredNodeError(path []byte, epoch types.StateEpoch) error { + return &ExpiredNodeError{ + Path: path, + Epoch: epoch, + } +} + +func (err *ExpiredNodeError) Error() string { + return "expired trie node" +} diff --git a/trie/hasher.go b/trie/hasher.go index e9f45f8341..edad7f7aa7 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -19,6 +19,8 @@ package trie import ( "sync" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" "golang.org/x/crypto/sha3" @@ -139,6 +141,28 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached return collapsed, cached } +// shadowFullNodeToHash hash shadowBranchNode +func (h *hasher) shadowBranchNodeToHash(n *shadowBranchNode) *common.Hash { + n.encode(h.encbuf) + enc := h.encodedBytes() + return h.hashCommon(enc) +} + +func (h *hasher) shadowNodeHashListToHash(hashList []*common.Hash) *common.Hash { + if len(hashList) == 0 { + return nil + } + w := h.encbuf + offset := w.List() + for _, hash := range hashList { + w.WriteBytes(hash[:]) + } + w.ListEnd(offset) + + enc := 
h.encodedBytes() + return h.hashCommon(enc) +} + // shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode // should have hex-type Key, which will be converted (without modification) // into compact form for RLP encoding. @@ -190,6 +214,15 @@ func (h *hasher) hashData(data []byte) hashNode { return n } +// hashCommon hashes the provided data +func (h *hasher) hashCommon(data []byte) *common.Hash { + var n common.Hash + h.sha.Reset() + h.sha.Write(data) + h.sha.Read(n[:]) + return &n +} + // proofHash is used to construct trie proofs, and returns the 'collapsed' // node (for later RLP encoding) aswell as the hashed node -- unless the // node is smaller than 32 bytes, in which case it will be returned as is. diff --git a/trie/node.go b/trie/node.go index 6ce6551ded..fdf08be1ec 100644 --- a/trie/node.go +++ b/trie/node.go @@ -22,26 +22,47 @@ import ( "strings" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" ) +const ( + BranchNodeLength = 17 +) + +const ( + shortNodeType = iota + fullNodeType + hashNodeType + valueNodeType + rawNodeType + rawShortNodeType + rawFullNodeType + rootNodeType +) + var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} type node interface { cache() (hashNode, bool) encode(w rlp.EncoderBuffer) fstring(string) string + nodeType() int } type ( fullNode struct { - Children [17]node // Actual trie node data to encode/decode (needs custom encoder) - flags nodeFlag + Children [BranchNodeLength]node // Actual trie node data to encode/decode (needs custom encoder) + flags nodeFlag + epoch types.StateEpoch `rlp:"-" json:"-"` + shadowNode shadowBranchNode `rlp:"-" json:"-"` } shortNode struct { - Key []byte - Val node - flags nodeFlag + Key []byte + Val node + flags nodeFlag + epoch types.StateEpoch `rlp:"-" json:"-"` + shadowNode shadowExtensionNode `rlp:"-" json:"-"` } hashNode []byte valueNode 
[]byte @@ -58,6 +79,38 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { return eb.Flush() } +func (n *fullNode) GetShadowNode() *shadowBranchNode { + return &n.shadowNode +} + +func (n *fullNode) GetChildEpoch(index int) types.StateEpoch { + if index < 16 { + return n.GetShadowNode().EpochMap[index] + } + return n.epoch +} + +func (n *fullNode) UpdateChildEpoch(index int, epoch types.StateEpoch) { + if index < 16 { + n.GetShadowNode().EpochMap[index] = epoch + } +} + +func (n *fullNode) ChildExpired(prefix []byte, index int, currentEpoch types.StateEpoch) (bool, error) { + childEpoch := n.GetChildEpoch(index) + if types.EpochExpired(childEpoch, currentEpoch) { + return true, &ExpiredNodeError{ + Path: prefix, + Epoch: childEpoch, + } + } + return false, nil +} + +func (n *shortNode) GetShadowNode() *shadowExtensionNode { + return &n.shadowNode +} + func (n *fullNode) copy() *fullNode { copy := *n; return © } func (n *shortNode) copy() *shortNode { copy := *n; return © } @@ -78,6 +131,11 @@ func (n *shortNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } +func (n *fullNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } +func (n *shortNode) setEpoch(epoch types.StateEpoch) { n.epoch = epoch } +func (n *fullNode) getEpoch() types.StateEpoch { return n.epoch } +func (n *shortNode) getEpoch() types.StateEpoch { return n.epoch } + func (n *fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) for i, node := range &n.Children { @@ -99,6 +157,22 @@ func (n valueNode) fstring(ind string) string { return fmt.Sprintf("%x ", []byte(n)) } +func (n *shortNode) nodeType() int { + return shortNodeType +} + +func (n *fullNode) nodeType() int { + return fullNodeType +} + +func (n hashNode) nodeType() int { + return hashNodeType +} + +func (n valueNode) nodeType() int { + return valueNodeType +} + // mustDecodeNode is a wrapper of 
decodeNode and panic if any error is encountered. func mustDecodeNode(hash, buf []byte) node { n, err := decodeNode(hash, buf) @@ -145,6 +219,9 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { case 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") + case 3: + n, err := DecodeRootNode(buf) + return n, wrapError(err, "root") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) } @@ -155,21 +232,24 @@ func decodeShort(hash, elems []byte) (node, error) { if err != nil { return nil, err } - flag := nodeFlag{hash: hash} + n := &shortNode{flags: nodeFlag{hash: hash}} key := compactToHex(kbuf) + n.Key = key if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) if err != nil { return nil, fmt.Errorf("invalid value node: %v", err) } - return &shortNode{key, valueNode(val), flag}, nil + n.Val = valueNode(val) + return n, nil } r, _, err := decodeRef(rest) if err != nil { return nil, wrapError(err, "val") } - return &shortNode{key, r, flag}, nil + n.Val = r + return n, nil } func decodeFull(hash, elems []byte) (*fullNode, error) { diff --git a/trie/node_enc.go b/trie/node_enc.go index cade35b707..c8990b17d1 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -59,9 +59,9 @@ func (n valueNode) encode(w rlp.EncoderBuffer) { w.WriteBytes(n) } -func (n rawFullNode) encode(w rlp.EncoderBuffer) { +func (n *rawFullNode) encode(w rlp.EncoderBuffer) { offset := w.List() - for _, c := range n { + for _, c := range n.children { if c != nil { c.encode(w) } else { diff --git a/trie/proof.go b/trie/proof.go index b413eeaf1a..be492c1db8 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -21,6 +21,8 @@ import ( "errors" "fmt" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/memorydb" @@ -39,32 +41,11 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e key = keybytesToHex(key) 
var nodes []node tn := t.root - for len(key) > 0 && tn != nil { - switch n := tn.(type) { - case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { - // The trie doesn't contain the key. - tn = nil - } else { - tn = n.Val - key = key[len(n.Key):] - } - nodes = append(nodes, n) - case *fullNode: - tn = n.Children[key[0]] - key = key[1:] - nodes = append(nodes, n) - case hashNode: - var err error - tn, err = t.resolveHash(n, nil) - if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - return err - } - default: - panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) - } + _, err := t.traverseNodes(tn, key, &nodes) + if err != nil { + return err } + hasher := newHasher(false) defer returnHasherToPool(hasher) @@ -99,6 +80,58 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri return t.trie.Prove(key, fromLevel, proofDb) } +// ProveStorageWitness constructs a merkle proof for a storage key. If the prefix key is specified, +// the proof will start from the node that contains the prefix key to get the partial proof. +// The result contains all encoded nodes from the starting node to the node that contains +// the value. The value itself is also included in the last node and can be retrieved by +// verifying the proof. 
+func (t *Trie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { + + if len(key) == 0 { + return fmt.Errorf("key is empty") + } + + key = keybytesToHex(key) + + // traverse down using the prefixKeyHex + var nodes []node + tn := t.root + startNode, err := t.traverseNodes(tn, prefixKeyHex, nil) // obtain the node that contains the prefixKeyHex + if err != nil { + return err + } + + key = key[len(prefixKeyHex):] // obtain the suffix key + + // traverse through the suffix key + _, err = t.traverseNodes(startNode, key, &nodes) + if err != nil { + return err + } + + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + // construct the proof + for _, n := range nodes { + var hn node + n, hn = hasher.proofHash(n) + if hash, ok := hn.(hashNode); ok { + enc := nodeToBytes(n) + if !ok { + hash = hasher.hashData(enc) + } + proofDb.Put(hash, enc) + } + } + + return nil +} + +func (t *SecureTrie) ProveStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueWriter) error { + return t.trie.ProveStorageWitness(key, prefixKeyHex, proofDb) +} + // VerifyProof checks merkle proofs. The given proof must contain the value for // key in a trie with the given root hash. VerifyProof returns an error if the // proof contains invalid trie nodes or the wrong value. 
@@ -128,6 +161,328 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) } } +// MPTProofNub include fullNode shortNode, revive n1 first if exist, +// revive n2 later if exist, include node hash +type MPTProofNub struct { + n1PrefixKey []byte // n1's prefix hex key, max 64bytes + n1 node + n2PrefixKey []byte // n2's prefix hex key, max 64bytes + n2 node +} + +// ResolveKV revive state could revive KV from fullNode[0-15] or fullNode[16] or shortNode.Val, return KVs for cache & snap +func (m *MPTProofNub) ResolveKV() (map[string][]byte, error) { + kvMap := make(map[string][]byte) + if err := resolveKV(m.n1, m.n1PrefixKey, kvMap); err != nil { + return nil, err + } + if err := resolveKV(m.n2, m.n2PrefixKey, kvMap); err != nil { + return nil, err + } + + return kvMap, nil +} + +func resolveKV(origin node, prefixKey []byte, kvWriter map[string][]byte) error { + switch n := origin.(type) { + case nil, hashNode: + return nil + case valueNode: + kvWriter[string(hexToKeybytes(prefixKey))] = n + return nil + case *shortNode: + return resolveKV(n.Val, append(prefixKey, n.Key...), kvWriter) + case *fullNode: + for i := 0; i < BranchNodeLength-1; i++ { + if err := resolveKV(n.Children[i], append(prefixKey, byte(i)), kvWriter); err != nil { + return err + } + } + return resolveKV(n.Children[BranchNodeLength-1], prefixKey, kvWriter) + default: + panic(fmt.Sprintf("invalid node: %v", origin)) + } +} + +type MPTProofCache struct { + types.MPTProof + + cacheHexPath [][]byte // cache path for performance + cacheHashes [][]byte // cache hash for performance + cacheNodes []node // cache node for performance + cacheNubs []*MPTProofNub // cache proof nubs to check revive duplicate +} + +// VerifyProof verify proof in MPT witness +// 1. calculate hash +// 2. decode trie node +// 3. verify partial merkle proof of the witness +// 4. 
split to partial witness +func (m *MPTProofCache) VerifyProof() error { + m.cacheHashes = make([][]byte, len(m.Proof)) + m.cacheNodes = make([]node, len(m.Proof)) + m.cacheHexPath = make([][]byte, len(m.Proof)) + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + var child []byte + for i := len(m.Proof) - 1; i >= 0; i-- { + m.cacheHashes[i] = hasher.hashData(m.Proof[i]) + n, err := decodeNode(m.cacheHashes[i], m.Proof[i]) + if err != nil { + return err + } + m.cacheNodes[i] = n + + switch t := n.(type) { + case *shortNode: + m.cacheHexPath[i] = t.Key + if err := matchHashNodeInShortNode(child, t); err != nil { + return err + } + case *fullNode: + index, err := matchHashNodeInFullNode(child, t) + if err != nil { + return err + } + if index >= 0 { + m.cacheHexPath[i] = []byte{byte(index)} + } + case valueNode: + if child != nil { + return errors.New("proof wrong child in valueNode") + } + default: + return fmt.Errorf("proof got wrong trie node: %v", t.nodeType()) + } + + child = m.cacheHashes[i] + } + + // cache proof nubs + m.cacheNubs = make([]*MPTProofNub, 0, len(m.Proof)) + prefix := m.RootKeyHex + for i := 0; i < len(m.cacheNodes); i++ { + if i-1 >= 0 { + prefix = copyNewSlice(prefix, m.cacheHexPath[i-1]) + } + // prefix = append(prefix, m.cacheHexPath[i]...) 
+ n1 := m.cacheNodes[i] + nub := MPTProofNub{ + n1PrefixKey: prefix, + n1: n1, + n2: nil, + n2PrefixKey: nil, + } + + // check if satisfy partial witness rules, + // that short node must with its child, may full node or valueNode + merge, err := mergeNextNode(m.cacheNodes, i) + if err != nil { + return err + } + if merge { + i++ + prefix = copyNewSlice(prefix, m.cacheHexPath[i-1]) + nub.n2 = m.cacheNodes[i] + nub.n2PrefixKey = prefix + } + m.cacheNubs = append(m.cacheNubs, &nub) + } + + return nil +} + +func copyNewSlice(s1, s2 []byte) []byte { + ret := make([]byte, len(s1)+len(s2)) + copy(ret, s1) + copy(ret[len(s1):], s2) + return ret +} + +func (m *MPTProofCache) CacheNubs() []*MPTProofNub { + return m.cacheNubs +} + +// mergeNextNode check short node must with child in same nub +func mergeNextNode(nodes []node, i int) (bool, error) { + if i >= len(nodes) { + return false, errors.New("mergeNextNode input outbound index") + } + + n1 := nodes[i] + switch n := n1.(type) { + case *shortNode: + need, err := needNextProofNode(n, n.Val) + if err != nil { + return false, err + } + if need && i+1 >= len(nodes) { + return false, errors.New("mergeNextNode short node must with its child") + } + return need, nil + case valueNode: + return false, errors.New("mergeNextNode value node need merge with prev node") + } + + if i+1 >= len(nodes) { + return false, nil + } + return nodes[i+1].nodeType() == valueNodeType, nil +} + +// needNextProofNode check if node need merge next node into a proofNub, because TrieExtendNode must with its child to revive together +func needNextProofNode(parent, origin node) (bool, error) { + switch n := origin.(type) { + case *fullNode: + for i := 0; i < BranchNodeLength-1; i++ { + need, err := needNextProofNode(n, n.Children[i]) + if err != nil { + return false, err + } + if need { + return true, nil + } + } + return false, nil + case *shortNode: + if parent.nodeType() == shortNodeType { + return false, errors.New("needNextProofNode cannot short 
node's child is short node") + } + return needNextProofNode(n, n.Val) + case valueNode: + return false, nil + case hashNode: + if parent.nodeType() == fullNodeType { + return false, nil + } + return true, nil + default: + return false, errors.New("needNextProofNode unsupported node") + } +} + +func matchHashNodeInFullNode(child []byte, n *fullNode) (int, error) { + if child == nil { + return -1, nil + } + + for i := 0; i < BranchNodeLength-1; i++ { + switch v := n.Children[i].(type) { + case hashNode: + if bytes.Equal(child, v) { + return i, nil + } + } + } + return -1, errors.New("proof cannot find target child in fullNode") +} + +func matchHashNodeInShortNode(child []byte, n *shortNode) error { + if child == nil { + return nil + } + + switch v := n.Val.(type) { + case hashNode: + if !bytes.Equal(child, v) { + return errors.New("proof wrong child in shortNode") + } + default: + return errors.New("proof must hashNode when meet shortNode") + } + return nil +} + +// VerifyStorageWitness checks a merkle proof for a storage key. If the prefix key is specified, +// it will traverse down to the node that contains the prefix key. From there, proof will be verified. +// VerifyStorageProof returns an error if the proof contains invalid trie nodes. 
+func (t *Trie) VerifyStorageWitness(key []byte, prefixKeyHex []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { + + if len(key) == 0 { + return nil, fmt.Errorf("empty key provided") + } + + key = keybytesToHex(key) + + tn := t.root + startNode, err := t.traverseNodes(tn, prefixKeyHex, nil) + if err != nil { + return nil, err + } + + key = key[len(prefixKeyHex):] // obtain the suffix key + + hasher := newHasher(false) + defer returnHasherToPool(hasher) + + _, hn := hasher.proofHash(startNode) + wantHash, ok := hn.(hashNode) + + if !ok { // node is not hashed + return nil, nil + } + for i := 0; ; i++ { + buf, _ := proofDb.Get(wantHash[:]) + if buf == nil { + return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) + } + n, err := decodeNode(wantHash[:], buf) + if err != nil { + return nil, fmt.Errorf("bad proof node %d: %v", i, err) + } + keyrest, cld := get(n, key, true) + switch cld := cld.(type) { + case nil: + return nil, nil + case hashNode: + key = keyrest + copy(wantHash[:], cld) + case valueNode: + return cld, nil + } + } +} + +// traverseNodes traverses the trie with the given key starting at the given node. +// If the trie contains the key, the returned node is the node that contains the +// value for the key. If nodes is specified, the traversed nodes are appended to +// it. +func (t *Trie) traverseNodes(tn node, key []byte, nodes *[]node) (node, error) { + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. 
+ tn = nil + } else { + tn = n.Val + key = key[len(n.Key):] + } + if nodes != nil { + *nodes = append(*nodes, n) + } + case *fullNode: + tn = n.Children[key[0]] + key = key[1:] + if nodes != nil { + *nodes = append(*nodes, n) + } + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + return tn, err + } + default: + panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) + } + } + + return tn, nil +} + // proofToPath converts a merkle proof to trie node path. The main purpose of // this function is recovering a node path from the merkle proof stream. All // necessary nodes will be resolved and leave the remaining as hashnode. diff --git a/trie/proof_test.go b/trie/proof_test.go index 29866714c2..ede4ccea5b 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -20,11 +20,18 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" + "encoding/hex" mrand "math/rand" "sort" + "strings" "testing" "time" + "github.com/ethereum/go-ethereum/rlp" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" @@ -892,6 +899,219 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { } } +// TestStorageProof tests the storage proof generation and verification. +// This test will also test for partial proof generation and verification. 
+func TestStorageProof(t *testing.T) { + trie, vals := randomTrie(500) + for _, kv := range vals { + prefixKeys := getPrefixKeysHex(trie, kv.k) + for _, prefixKey := range prefixKeys { + proof := memorydb.New() + key := kv.k + err := trie.ProveStorageWitness(key, prefixKey, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", kv.k) + } + val, err := trie.VerifyStorageWitness(key, prefixKey, proof) + if err != nil { + t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) + } + if val != nil && !bytes.Equal(val, kv.v) { + t.Fatalf("failed to verify proof for key %x: prefix %x: %v\nraw proof: %x", key, prefixKey, err, proof) + } + } + } +} + +// TestEmptyStorageProof tests storage verification with empty proof. +// The verifier should nil for both value and error. +func TestEmptyStorageProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := []byte("k") + + val, err := trie.VerifyStorageWitness(key, nil, proof) + if val != nil && err != nil { + t.Fatalf("expected nil value and error for empty proof") + } +} + +// TestEmptyKeyStorageProof tests the storage proof with empty key. +// The prover is expected to return +func TestEmptyKeyStorageProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + err := trie.ProveStorageWitness([]byte(""), nil, proof) + if err == nil { + t.Fatalf("expected error for empty key") + } +} + +// TestEmptyPrefixKeyStorageProof tests the storage proof with empty prefix key, +// which means that all proofs generated are from the root node. 
+func TestEmptyPrefixKeyStorageProof(t *testing.T) { + trie, vals := randomTrie(500) + for _, kv := range vals { + proof := memorydb.New() + key := kv.k + + err := trie.ProveStorageWitness(key, nil, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", key) + } + val, err := trie.VerifyStorageWitness(key, nil, proof) + if err != nil { + t.Fatalf("failed to verify proof for key %x: %v\nraw proof: %x", key, err, proof) + } + if !bytes.Equal(val, kv.v) { + t.Fatalf("verified value mismatch for key %x: have %x, want %x", key, val, kv.v) + } + } +} + +// TestBadStorageProof tests a few cases which the proof is wrong. +// The proof is expected to detect the error. +func TestBadStorageProof(t *testing.T) { + + trie, vals := randomTrie(500) + for _, kv := range vals { + prefixKeys := getPrefixKeysHex(trie, kv.k) + for _, prefixKey := range prefixKeys { + proof := memorydb.New() + key := kv.k + err := trie.ProveStorageWitness(key, prefixKey, proof) + if err != nil { + t.Fatalf("missing key %x while constructing proof", key) + } + + if proof.Len() == 0 { + continue + } + + it := proof.NewIterator(nil, nil) + for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ { + it.Next() + } + itKey := it.Key() + itVal, _ := proof.Get(itKey) + proof.Delete(itKey) + it.Release() + + mutateByte(itVal) + proof.Put(crypto.Keccak256(itVal), itVal) + + if val, err := trie.VerifyStorageWitness(key, prefixKey, proof); err == nil && val != nil { + t.Fatalf("expected proof to fail for key: %x, prefix: %x", key, prefixKey) + } + } + } +} + +// TestBadKeyStorageProof tests the storage proof with a bad key. +// The verifier is expected to return nil for both value and error. 
+func TestBadKeyStorageProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := []byte("x") + trie.ProveStorageWitness(key, nil, proof) + + val, err := trie.VerifyStorageWitness(key, nil, proof) + if val != nil && err != nil { + t.Fatalf("expected nil value and error for bad key") + } +} + +// TestBadPrefixKeyStorageProof tests the storage proof with a bad prefix key. +// The verifier is expected to return nil for both value and error. +func TestBadPrefixKeyStorageProof(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := []byte("k") + + prefixKey := keybytesToHex([]byte("x")) + + trie.ProveStorageWitness(key, prefixKey, proof) + + val, err := trie.VerifyStorageWitness(key, prefixKey, proof) + if val != nil && err != nil { + t.Fatalf("expected nil value and error for bad prefix key") + } +} + +// TestKeyPrefixKeySame tests the storage proof with the same key and prefix key. +// The proof size should be 0 and the verifier should return nil for both value and error. +func TestKeyPrefixKeySame(t *testing.T) { + trie := new(Trie) + updateString(trie, "k", "v") + + proof := memorydb.New() + key := []byte("k") + + trie.ProveStorageWitness(key, key, proof) + if proof.Len() != 0 { + t.Fatalf("expected proof size to be 0 for same key and prefix key") + } + + val, err := trie.VerifyStorageWitness(key, key, proof) + if val != nil && err != nil { + t.Fatalf("expected nil value and error for same key and prefix key") + } +} + +// TestUnexpiredStorageProof tests the storage proof with a trie containing +// both expired and unexpired data. The prover is expected to give valid proof +// for the unexpired data. 
+func TestUnexpiredStorageProof(t *testing.T) { + trie := new(Trie) + + expiredData := map[string]string{ + "abcd": "A", + "abce": "B", + "abde": "C", + "abdf": "D", + } + + unexpiredData := map[string]string{ + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + // Loop through the data and insert it into the trie + for k, v := range expiredData { + updateString(trie, k, v) + } + + for k, v := range unexpiredData { + updateString(trie, k, v) + } + + prefixKey := keybytesToHex([]byte("abcd"))[:2] + + trie.ExpireByPrefix(prefixKey) + + proof := memorydb.New() + key := []byte("degi") + + trie.ProveStorageWitness(key, nil, proof) + val, err := trie.VerifyStorageWitness(key, nil, proof) + if err != nil { + t.Fatalf("failed to verify proof: %v\nraw proof: %x", err, proof) + } + if !bytes.Equal(val, []byte("H")) { + t.Fatalf("verified value mismatch: have %x, want %v", val, "H") + } +} + // mutateByte changes one byte in b. func mutateByte(b []byte) { for r := mrand.Intn(len(b)); ; { @@ -1068,6 +1288,25 @@ func nonRandomTrie(n int) (*Trie, map[string]*kv) { return trie, vals } +func nonRandomTrieWithShadowNodes(n int) (*Trie, map[string]*kv) { + trie := new(Trie) + trie.withShadowNodes = true + trie.currentEpoch = 10 // TODO (asyukii): might need to change this + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + elem := &kv{key, value, false} + trie.Update(elem.k, elem.v) // TODO (asyukii): this is not working, the shadow branch node is not being updated + vals[string(elem.k)] = elem + } + return trie, vals +} + func TestRangeProofKeysWithSharedPrefix(t *testing.T) { keys := [][]byte{ common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), @@ -1100,3 +1339,127 
@@ func TestRangeProofKeysWithSharedPrefix(t *testing.T) { t.Error("expected more to be false") } } + +func getPrefixKeysHex(t *Trie, key []byte) [][]byte { + var prefixKeys [][]byte + key = keybytesToHex(key) + tn := t.root + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + tn = nil + } else { + tn = n.Val + // Check if there is a previous key in prefixKeys + if len(prefixKeys) == 0 { + prefixKeys = append(prefixKeys, n.Key) + } else { + prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], n.Key...)) + } + key = key[len(n.Key):] + } + case *fullNode: + tn = n.Children[key[0]] + if len(prefixKeys) == 0 { + prefixKeys = append(prefixKeys, key[:1]) + } else { + prefixKeys = append(prefixKeys, append(prefixKeys[len(prefixKeys)-1], key[:1]...)) + } + key = key[1:] + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + return nil + } + default: + return nil + } + } + + return prefixKeys +} + +// TODO add more UTs +// 1. whole trie expired without root node +// 2. whole trie expired with root node +// 3. witness with one short node, val is full node +// 3. witness with one short node, val is value node +// 3. witness with one short node, val is recursive full node + short node + value node +// 4. 
witness with one full node with val node, children is short node and recursive full node + short node + value node +func TestMPTProofCache_VerifyProof_normalCase(t *testing.T) { + cache := makeMPTProofCache(nil, []string{ + "0xf90211a03697534056039e03300557bd69fe16e18ce4a6ccd5522db4dfa97dfe1fad3d3aa0b1bf1f230b98b9034738d599177ae817c08143b9395a47f300636b0dd2fb3c5ea0aa04a4966751d4c50063fe13a96a6c7924f665819733f556849b5eb9fa1d6839a0e162e080d1c12c59dc984fb2246d8ad61209264bee40d3fdd07c4ea4ff411b6aa0e5c3f2dde71bf303423f34674748567dcdf8379129653b8213f698468738d492a068a3e3059b6e7115a055a7874f81c5a1e84ddc1967527973f8c78cd86a1c9f8fa0d734bd63b7be8e8471091b792f5bbcbc7b0ce582f6d985b7a15a3c0155242c56a00143c06f57a65c8485dbae750aa51df5dff1bf7bdf28060129a20de9e51364eda07b416f79b3f4e39d0159efff351009d44002d9e83530fb5a5778eb55f5f4432ca036706b52196fa0b73feb2e7ff8f1379c7176d427dd44ad63c7b65e66693904a1a0fd6c8b815e2769ce379a20eaccdba1f145fb11f77c280553f15ee4f1ee135375a02f5233009f082177e5ed2bfa6e180bf1a7310e6bc3c079cb85a4ac6fee4ae379a03f07f1bb33fa26ebd772fa874914dc7a08581095e5159fdcf9221be6cbeb6648a097557eec1ac08c3bfe45ce8e34cd329164a33928ac83fef1009656536ef6907fa028196bfb31aa7f14a0a8000b00b0aa5d09450c32d537e45eebee70b14313ff1ca0126ce265ca7bbb0e0b01f068d1edef1544cbeb2f048c99829713c18d7abc049a80", + 
"0xf90211a01f7c019858f447dbff8ed8e4329e88600bab8f17fec9594c664e25acc95da3dba00916866833439bc250b5e08edc2b7c041634ca3b33a013f79eb299a3e33a056fa0a8fae921061f3bc81154b5d2c149d3860a3cdef00d02173ee3b5837de6b26f56a097583621a54e74994619a97cf82f823005a35ad1bf4795047726619eccd11ad4a0be39789c4abdb2a185cce40f2c77575eee2a1eab38d3168395560da15307614aa0c46a7fcea5501656d70508178f0731460edcce1c01e32ed08c1468f2593db277a0460f030038b09b8461d834ded79d77fd3ed25c4e248775752bf1c830a530ef2ba03d887abec623c6b4d93be75d1608dcacb291bf30406fbbc944a94aa4203fb1eba0475eeb07d471044af74313093cfbb0201e17a405d7af13a7ae6b245e18915515a0ea794035230f90e14f4f601d0e9f04217ed02710df363417bc3dd7dda5a84608a08aa9f0c44e5a9e359b65d4f0e40937662f07af10642a626e19ca7bc56c329706a05677c9b342c1cbcfd491cd491800d44423d1bdc08ab00eb59376a7118f7c4e23a0b761f22d67328e6c90caf8be65affb701b045f4f1f581472fc5f724e0c61328da0b6727240643009e59bba78249918a83ca8faeaf1d5b47fe0b41c356e8ae300dda0d7cbeb12faf439126bdb86f94b6262c3a70e88e4d8a47b49735f0b9d632a5df8a0428fe930556ce5ea94bcc4d092fe0f05a2a9b360175857711729ab3a7092a81780", + 
"0xf90211a0a81ff74945250ba9925753e379aa8815a3fe77926aad2d02db4d78a3da7ecb48a0e795c64d2b738b34ebb77a3307df800c9f1fe324b442cd71ffa5fd268ec12ffca020a5f968a1c8292d08135cb451daed999a44eb0cdd04526a89c38e753398c50ca0cc7165f6b984f2f2569d101d70f72af94eb79b18c126895da2d3cf557b5aca51a00a10f4dc5851e71f195cca5bd7a2a62a6c3ca03d95e91c7dc55e55f4cb726903a09ecaaf877b18a55ca4001fe46cc10389c94104abc2994cdfe5843f28814b119ca099f59b2b52ee9d9b44ab7123a86775dcfe6c50301dd7cf9c6fc6ea968c1d2a01a0fccda5c1489dd3268fcf2471f6a372fc3afe5eebecc0ba9fc3c023d85da26aeda0d730ea7aabdad2e5451e826726e53c86e6e805e46220bc7edd80bf3f7e467f96a074f3d767a84557aa9559b08b2d1f93d8205827c042d1ac616b8cf37e22de2beca03ca1fe479ea1e64c3bfeae9c603ef4f55c270de0bb4c79d045f67d3481ca2852a0bdcbc3d154db40b5faa1f6875f46d485b96bdfffb4513da631f9b302768faae7a096923f4f559c7b7f912587292cc865aba5934e9e75c9aa88f20de73e928c6c51a0f52ca9710977327dd407ba9d9f00b0075d7e312ee19686f1b53579585ff50132a076adb4cf98f9af261d1b2a147b9b2cbace1647eab537ca1a55e8d40d35775b74a0723fa2f72d6a983939bc9bf9ee22e05fc142437726450e5707c60155bc34ec0580", + 
"0xf90211a039a64fd9b31f3ea3bf9a991ca828eadb67cf9ae0ebbcdd195d297454699580e8a0c95d85a63beb9a02b56d032116d86fdf9a64065dd5d44e33acef674ae3ceb6f9a0d83ef07a99302abccdce1289d353a20494713f45d8edbd3c2e87f08788a878d9a063a19aee41fec40a98edcf4d60c759c509b512bfc5d9feae7de50c8c2d00eee7a029ce5c8c1ac9939cbf481274b8e6f5f24a430136dc5aab7deb489a9ce7db5a95a03b45c53d2e4f54e49a53eb298aee828f4803581ee60ca52940533b77c3e3fadfa082b8084dfc0337a49fd5d53ce107fa0bdcdf25ac7bbac017d7ac250b77c8764da0d4dab90004fd3bb36b2fa5f6e914a7c90820d119e5ba3ba8439270c6cfbd0a18a0bd9286c9248ae8a953d00aa9906f06b6f0364d0bb9fb615de04c9f8d5b5fd346a00955eccaa41a17fafb0ed66272b5183b3a30a973af7a43e0f25f0b640fd5df0ca0a7ba773602ace05991211770fc5555dd9c55e2518bcb005d1db022584a89132da0450b066588b44701992ab4a926d5dd185dd465abf1e51a3f60ab4a9f924cf85aa05eda0687636512339d0db99eade61c0d44358bb8027185296ad59b4cedd2935aa0cd18604dee296e5443b598e470a04efde04fbb0213c0fda8cc3af20dae2a1c34a01c0474c1bf15e4732d1c22a1303573c7ec8643f711354e853ab26038b2eebe25a081a0b2d329a9519375f71384a7130faafc3a1379db59bfb7db7135c9d27d9cc080", + 
"0xf901f1a0e14d9caa85464966f0371f427250e7bf2d86f6d41535b08ea79044391d6f0fe7a08bea233c92e4c1d05bd03d62cd93974cc29f6df54e6fa2574bf8dfb3af936a85a0dd4f4af9ab72d1bbbf59c286a731614f48bf3cdcef572cd819426fc2a30ae5d9a058465ea8e97b2f3a873646f9504aff111ab967abaea8ebb2fe91bf058ea162f3a06466c41eb770bae5c07c26cd692540e7f9af70fee2bc164e12f29083ebd2cca1a0fe0925da033ffb967ca1e1d6505c2ca740553916c3b9a205131d719b7921fb00a0973ecee958d1305f1b6f8159a9732e47f5fbd4121f60254ea44a8f028308150780a0f81717c7f8702ed39d89a34fc86827aab31db26b22d22ee399667e4a44081c95a0e276ec24ee74ee61bdbe92e8afa57813d59f26e0ad34f24d71f0627719f8a11ea00a8ec2f6922480890f9e67c62c1654dcccf1710a01a378f614e02a3785d42d7ea08c43cc0c6c690dd87cf42691e9ead799c5ddbcfc9458ebd054a46acecedbe9f7a0d3727ed077c60ed6ff1d9d5d20fdf2912a513942d0efee9ec2d97d33ab9b7f56a03f293b0c0f25b9aa2735c4e42b132008d8af2ecfaaefdc940dc94cf312f5a1dda04860bc5f19829bb277554927c5b6dbc0af93025aca49aed3be020a3080996a26a0e47475bf4f62364d8a0dd87f2c21d0b64ef4d9aaf5fc46d9b1e3af5b83f56d4580", + "0xf89180a0dd3524693059f48ce47c5dd82d0462953a5141c7abc5a63981637b71e0100bcf80a033ca98826ea65c1c4be16e1df12e78142d992bf1fc0189401632ceff3fcbbe7880a0e5daf803b0890d164e287fa43b51332286cddeb9b62280426037fc02dc2b4d5f8080a0bfa907c2e30720a07347d23cfa573dec8c9b3afa51a4a0e0783a481da0fcb1b38080808080808080", + "0xe79e208246cec5810061f4ff7efe1dcd6cb407d59abc3478830df04484584c868786d647b234389e", + }) + + err := cache.VerifyProof() + assert.NoError(t, err) + assert.Equal(t, common.Hex2Bytes("3310913fe74cfb66dbde8fe8557b48e8e65617a17c2375a581c32d49f812cde4"), cache.cacheHashes[0]) + key := common.Hex2Bytes("95eea00c49d14a895954837cd876ffa8cfad96cbaacc40fc31d6df2c902528a8") + hash := make([]byte, common.HashLength) + h := newHasher(false) + h.sha.Reset() + h.sha.Write(key) + h.sha.Read(hash) + ln := cache.cacheNubs[6].n1.(*shortNode) + hexKey := append(cache.cacheNubs[6].n1PrefixKey, ln.Key...) 
+ assert.Equal(t, hash, hexToKeybytes(hexKey)) +} + +func TestMPTProofCache_ResolveKV_normalCase(t *testing.T) { + cache := makeMPTProofCache(nil, []string{ + "0xf90211a03697534056039e03300557bd69fe16e18ce4a6ccd5522db4dfa97dfe1fad3d3aa0b1bf1f230b98b9034738d599177ae817c08143b9395a47f300636b0dd2fb3c5ea0aa04a4966751d4c50063fe13a96a6c7924f665819733f556849b5eb9fa1d6839a0e162e080d1c12c59dc984fb2246d8ad61209264bee40d3fdd07c4ea4ff411b6aa0e5c3f2dde71bf303423f34674748567dcdf8379129653b8213f698468738d492a068a3e3059b6e7115a055a7874f81c5a1e84ddc1967527973f8c78cd86a1c9f8fa0d734bd63b7be8e8471091b792f5bbcbc7b0ce582f6d985b7a15a3c0155242c56a00143c06f57a65c8485dbae750aa51df5dff1bf7bdf28060129a20de9e51364eda07b416f79b3f4e39d0159efff351009d44002d9e83530fb5a5778eb55f5f4432ca036706b52196fa0b73feb2e7ff8f1379c7176d427dd44ad63c7b65e66693904a1a0fd6c8b815e2769ce379a20eaccdba1f145fb11f77c280553f15ee4f1ee135375a02f5233009f082177e5ed2bfa6e180bf1a7310e6bc3c079cb85a4ac6fee4ae379a03f07f1bb33fa26ebd772fa874914dc7a08581095e5159fdcf9221be6cbeb6648a097557eec1ac08c3bfe45ce8e34cd329164a33928ac83fef1009656536ef6907fa028196bfb31aa7f14a0a8000b00b0aa5d09450c32d537e45eebee70b14313ff1ca0126ce265ca7bbb0e0b01f068d1edef1544cbeb2f048c99829713c18d7abc049a80", + 
"0xf90211a01f7c019858f447dbff8ed8e4329e88600bab8f17fec9594c664e25acc95da3dba00916866833439bc250b5e08edc2b7c041634ca3b33a013f79eb299a3e33a056fa0a8fae921061f3bc81154b5d2c149d3860a3cdef00d02173ee3b5837de6b26f56a097583621a54e74994619a97cf82f823005a35ad1bf4795047726619eccd11ad4a0be39789c4abdb2a185cce40f2c77575eee2a1eab38d3168395560da15307614aa0c46a7fcea5501656d70508178f0731460edcce1c01e32ed08c1468f2593db277a0460f030038b09b8461d834ded79d77fd3ed25c4e248775752bf1c830a530ef2ba03d887abec623c6b4d93be75d1608dcacb291bf30406fbbc944a94aa4203fb1eba0475eeb07d471044af74313093cfbb0201e17a405d7af13a7ae6b245e18915515a0ea794035230f90e14f4f601d0e9f04217ed02710df363417bc3dd7dda5a84608a08aa9f0c44e5a9e359b65d4f0e40937662f07af10642a626e19ca7bc56c329706a05677c9b342c1cbcfd491cd491800d44423d1bdc08ab00eb59376a7118f7c4e23a0b761f22d67328e6c90caf8be65affb701b045f4f1f581472fc5f724e0c61328da0b6727240643009e59bba78249918a83ca8faeaf1d5b47fe0b41c356e8ae300dda0d7cbeb12faf439126bdb86f94b6262c3a70e88e4d8a47b49735f0b9d632a5df8a0428fe930556ce5ea94bcc4d092fe0f05a2a9b360175857711729ab3a7092a81780", + 
"0xf90211a0a81ff74945250ba9925753e379aa8815a3fe77926aad2d02db4d78a3da7ecb48a0e795c64d2b738b34ebb77a3307df800c9f1fe324b442cd71ffa5fd268ec12ffca020a5f968a1c8292d08135cb451daed999a44eb0cdd04526a89c38e753398c50ca0cc7165f6b984f2f2569d101d70f72af94eb79b18c126895da2d3cf557b5aca51a00a10f4dc5851e71f195cca5bd7a2a62a6c3ca03d95e91c7dc55e55f4cb726903a09ecaaf877b18a55ca4001fe46cc10389c94104abc2994cdfe5843f28814b119ca099f59b2b52ee9d9b44ab7123a86775dcfe6c50301dd7cf9c6fc6ea968c1d2a01a0fccda5c1489dd3268fcf2471f6a372fc3afe5eebecc0ba9fc3c023d85da26aeda0d730ea7aabdad2e5451e826726e53c86e6e805e46220bc7edd80bf3f7e467f96a074f3d767a84557aa9559b08b2d1f93d8205827c042d1ac616b8cf37e22de2beca03ca1fe479ea1e64c3bfeae9c603ef4f55c270de0bb4c79d045f67d3481ca2852a0bdcbc3d154db40b5faa1f6875f46d485b96bdfffb4513da631f9b302768faae7a096923f4f559c7b7f912587292cc865aba5934e9e75c9aa88f20de73e928c6c51a0f52ca9710977327dd407ba9d9f00b0075d7e312ee19686f1b53579585ff50132a076adb4cf98f9af261d1b2a147b9b2cbace1647eab537ca1a55e8d40d35775b74a0723fa2f72d6a983939bc9bf9ee22e05fc142437726450e5707c60155bc34ec0580", + 
"0xf90211a039a64fd9b31f3ea3bf9a991ca828eadb67cf9ae0ebbcdd195d297454699580e8a0c95d85a63beb9a02b56d032116d86fdf9a64065dd5d44e33acef674ae3ceb6f9a0d83ef07a99302abccdce1289d353a20494713f45d8edbd3c2e87f08788a878d9a063a19aee41fec40a98edcf4d60c759c509b512bfc5d9feae7de50c8c2d00eee7a029ce5c8c1ac9939cbf481274b8e6f5f24a430136dc5aab7deb489a9ce7db5a95a03b45c53d2e4f54e49a53eb298aee828f4803581ee60ca52940533b77c3e3fadfa082b8084dfc0337a49fd5d53ce107fa0bdcdf25ac7bbac017d7ac250b77c8764da0d4dab90004fd3bb36b2fa5f6e914a7c90820d119e5ba3ba8439270c6cfbd0a18a0bd9286c9248ae8a953d00aa9906f06b6f0364d0bb9fb615de04c9f8d5b5fd346a00955eccaa41a17fafb0ed66272b5183b3a30a973af7a43e0f25f0b640fd5df0ca0a7ba773602ace05991211770fc5555dd9c55e2518bcb005d1db022584a89132da0450b066588b44701992ab4a926d5dd185dd465abf1e51a3f60ab4a9f924cf85aa05eda0687636512339d0db99eade61c0d44358bb8027185296ad59b4cedd2935aa0cd18604dee296e5443b598e470a04efde04fbb0213c0fda8cc3af20dae2a1c34a01c0474c1bf15e4732d1c22a1303573c7ec8643f711354e853ab26038b2eebe25a081a0b2d329a9519375f71384a7130faafc3a1379db59bfb7db7135c9d27d9cc080", + 
"0xf901f1a0e14d9caa85464966f0371f427250e7bf2d86f6d41535b08ea79044391d6f0fe7a08bea233c92e4c1d05bd03d62cd93974cc29f6df54e6fa2574bf8dfb3af936a85a0dd4f4af9ab72d1bbbf59c286a731614f48bf3cdcef572cd819426fc2a30ae5d9a058465ea8e97b2f3a873646f9504aff111ab967abaea8ebb2fe91bf058ea162f3a06466c41eb770bae5c07c26cd692540e7f9af70fee2bc164e12f29083ebd2cca1a0fe0925da033ffb967ca1e1d6505c2ca740553916c3b9a205131d719b7921fb00a0973ecee958d1305f1b6f8159a9732e47f5fbd4121f60254ea44a8f028308150780a0f81717c7f8702ed39d89a34fc86827aab31db26b22d22ee399667e4a44081c95a0e276ec24ee74ee61bdbe92e8afa57813d59f26e0ad34f24d71f0627719f8a11ea00a8ec2f6922480890f9e67c62c1654dcccf1710a01a378f614e02a3785d42d7ea08c43cc0c6c690dd87cf42691e9ead799c5ddbcfc9458ebd054a46acecedbe9f7a0d3727ed077c60ed6ff1d9d5d20fdf2912a513942d0efee9ec2d97d33ab9b7f56a03f293b0c0f25b9aa2735c4e42b132008d8af2ecfaaefdc940dc94cf312f5a1dda04860bc5f19829bb277554927c5b6dbc0af93025aca49aed3be020a3080996a26a0e47475bf4f62364d8a0dd87f2c21d0b64ef4d9aaf5fc46d9b1e3af5b83f56d4580", + "0xf89180a0dd3524693059f48ce47c5dd82d0462953a5141c7abc5a63981637b71e0100bcf80a033ca98826ea65c1c4be16e1df12e78142d992bf1fc0189401632ceff3fcbbe7880a0e5daf803b0890d164e287fa43b51332286cddeb9b62280426037fc02dc2b4d5f8080a0bfa907c2e30720a07347d23cfa573dec8c9b3afa51a4a0e0783a481da0fcb1b38080808080808080", + "0xe79e208246cec5810061f4ff7efe1dcd6cb407d59abc3478830df04484584c868786d647b234389e", + }) + + err := cache.VerifyProof() + assert.NoError(t, err) + key := common.Hex2Bytes("95eea00c49d14a895954837cd876ffa8cfad96cbaacc40fc31d6df2c902528a8") + hash := make([]byte, common.HashLength) + h := newHasher(false) + h.sha.Reset() + h.sha.Write(key) + h.sha.Read(hash) + t.Log("expect:", hex.EncodeToString(hash)) + for i, nub := range cache.cacheNubs { + kvMap, err := nub.ResolveKV() + assert.NoError(t, err) + if i != 6 { + assert.Equal(t, 0, len(kvMap)) + continue + } + for k, v := range kvMap { + t.Log("k:", hex.EncodeToString([]byte(k)), "v:", hex.EncodeToString(v)) + } + enc := 
kvMap[string(hash)] + _, content, _, _ := rlp.Split(enc) + assert.Equal(t, common.Hex2Bytes("d647b234389e"), content) + } +} + +func makeMPTProofCache(key []byte, proofs []string) MPTProofCache { + + proof := make([][]byte, len(proofs)) + for i := range proofs { + proof[i] = common.Hex2Bytes(strings.TrimPrefix(proofs[i], "0x")) + } + return MPTProofCache{ + MPTProof: types.MPTProof{ + RootKeyHex: key, + Proof: proof, + }, + } +} diff --git a/trie/secure_trie.go b/trie/secure_trie.go index dd7598d893..99bdc9ea31 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -63,6 +63,25 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) { return &SecureTrie{trie: *trie}, nil } +// NewSecureWithShadowNodes it opens a trie with shadow nodes, it needs to know current epoch to check if expired +// it uses sndb to query or store shadow nodes, if you using NewSecure, it opens storage trie at epoch0. +func NewSecureWithShadowNodes(curEpoch types.StateEpoch, root common.Hash, db *Database, sndb ShadowNodeStorage) (*SecureTrie, error) { + if db == nil || sndb == nil { + panic("trie.NewSecure called without a database") + } + + rn, err := resolveRootNodeTrieDb(db, root) + if err != nil { + return nil, err + } + + trie, err := NewWithShadowNode(curEpoch, rn, db, sndb) + if err != nil { + return nil, err + } + return &SecureTrie{trie: *trie}, nil +} + // Get returns the value for key stored in the trie. // The value bytes must not be modified by the caller. func (t *SecureTrie) Get(key []byte) []byte { @@ -77,7 +96,12 @@ func (t *SecureTrie) Get(key []byte) []byte { // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. 
func (t *SecureTrie) TryGet(key []byte) ([]byte, error) { - return t.trie.TryGet(t.hashKey(key)) + return t.trie.TryGet(t.HashKey(key)) +} + +func (t *SecureTrie) TryUpdateEpoch(key []byte) error { + _, err := t.trie.TryGetAndUpdateEpoch(t.HashKey(key)) + return err } // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not @@ -89,7 +113,7 @@ func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) { // TryUpdate account will abstract the write of an account to the // secure trie. func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { - hk := t.hashKey(key) + hk := t.HashKey(key) data, err := rlp.EncodeToBytes(acc) if err != nil { return err @@ -122,7 +146,7 @@ func (t *SecureTrie) Update(key, value []byte) { // // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryUpdate(key, value []byte) error { - hk := t.hashKey(key) + hk := t.HashKey(key) err := t.trie.TryUpdate(hk, value) if err != nil { return err @@ -141,7 +165,7 @@ func (t *SecureTrie) Delete(key []byte) { // TryDelete removes any existing value for key from the trie. // If a node was not found in the database, a MissingNodeError is returned. func (t *SecureTrie) TryDelete(key []byte) error { - hk := t.hashKey(key) + hk := t.HashKey(key) delete(t.getSecKeyCache(), string(hk)) return t.trie.TryDelete(hk) } @@ -155,6 +179,10 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte { return t.trie.db.preimage(common.BytesToHash(shaKey)) } +func (t *SecureTrie) Epoch() types.StateEpoch { + return t.trie.currentEpoch +} + // Commit writes all nodes and the secure hash pre-images to the trie's database. // Nodes are stored with their sha3 hash as the key. // @@ -205,10 +233,10 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { return t.trie.NodeIterator(start) } -// hashKey returns the hash of key as an ephemeral buffer. +// HashKey returns the hash of key as an ephemeral buffer. 
// The caller must not hold onto the return value because it will become -// invalid on the next call to hashKey or secKey. -func (t *SecureTrie) hashKey(key []byte) []byte { +// invalid on the next call to HashKey or secKey. +func (t *SecureTrie) HashKey(key []byte) []byte { hash := make([]byte, common.HashLength) h := newHasher(false) h.sha.Reset() @@ -228,3 +256,7 @@ func (t *SecureTrie) getSecKeyCache() map[string][]byte { } return t.secKeyCache } + +func (t *SecureTrie) ReviveTrie(proof []*MPTProofNub) []*MPTProofNub { + return t.trie.ReviveTrie(proof) +} diff --git a/trie/shadow_node.go b/trie/shadow_node.go new file mode 100644 index 0000000000..67c3b8baa9 --- /dev/null +++ b/trie/shadow_node.go @@ -0,0 +1,346 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum/log" + + "github.com/ethereum/go-ethereum/rlp" + + "github.com/ethereum/go-ethereum/ethdb" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" +) + +const ( + ShadowTreeRootNodePath = "s" +) + +type rootNode struct { + Epoch types.StateEpoch + TrieRoot common.Hash + ShadowTreeRoot common.Hash + cachedHash common.Hash `rlp:"-" json:"-"` + cachedEnc []byte `rlp:"-" json:"-"` +} + +func newEpoch0RootNode(trieRoot common.Hash) *rootNode { + return newRootNode(types.StateEpoch0, trieRoot, emptyRoot) +} + +func newRootNode(epoch types.StateEpoch, trieRoot, shadowTreeRoot common.Hash) *rootNode { + n := &rootNode{ + Epoch: epoch, + TrieRoot: trieRoot, + ShadowTreeRoot: shadowTreeRoot, + } + n.resolveCache() + return n +} + +func (n *rootNode) copy() *rootNode { copy := *n; return &copy } + +func (n *rootNode) encode(w rlp.EncoderBuffer) { + rlp.Encode(w, n) +} + +func (n *rootNode) cache() (hashNode, bool) { return n.cachedHash[:], true } + +func (n *rootNode) setEpoch(epoch types.StateEpoch) { n.Epoch = epoch } +func (n *rootNode) getEpoch() types.StateEpoch { return n.Epoch } + +func (n *rootNode) String()
string { return n.fstring("") } + +func (n *rootNode) fstring(ind string) string { + return fmt.Sprintf("rootNode{epoch:%d, trieRoot:%s, shadowTreeRoot:%s}", n.Epoch, n.TrieRoot, n.ShadowTreeRoot) +} + +func (n *rootNode) nodeType() int { + return rootNodeType +} + +func (n *rootNode) resolveCache() { + buf := rlp.NewEncoderBuffer(nil) + n.encode(buf) + n.cachedEnc = buf.ToBytes() + + // cache hash + h := newHasher(false) + h.sha.Reset() + h.sha.Write(n.cachedEnc) + h.sha.Read(n.cachedHash[:]) + returnHasherToPool(h) +} + +func DecodeRootNode(enc []byte) (*rootNode, error) { + n := &rootNode{} + if err := rlp.DecodeBytes(enc, n); err != nil { + return nil, err + } + n.resolveCache() + return n, nil +} + +type shadowExtensionNode struct { + ShadowHash *common.Hash +} + +func NewShadowExtensionNode(hash *common.Hash) shadowExtensionNode { + return shadowExtensionNode{ + ShadowHash: hash, + } +} + +func (n *shadowExtensionNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + if n.ShadowHash == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteBytes(n.ShadowHash[:]) + } + w.ListEnd(offset) +} + +func decodeShadowExtensionNode(enc []byte) (*shadowExtensionNode, error) { + var n shadowExtensionNode + elems, _, err := rlp.SplitList(enc) + if err != nil { + return nil, err + } + + sh, _, err := rlp.SplitString(elems) + if err != nil { + return nil, err + } + + if len(sh) == 0 { + n.ShadowHash = nil + } else { + hash := common.BytesToHash(sh) + n.ShadowHash = &hash + } + return &n, nil +} + +type shadowBranchNode struct { + ShadowHash *common.Hash + EpochMap [16]types.StateEpoch +} + +func NewShadowBranchNode(hash *common.Hash, epochMap [16]types.StateEpoch) shadowBranchNode { + return shadowBranchNode{hash, epochMap} +} +func (n *shadowBranchNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + if n.ShadowHash == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteBytes(n.ShadowHash[:]) + } + epochsList := w.List() + for _, e := range n.EpochMap { + 
w.WriteUint64(uint64(e)) + } + w.ListEnd(epochsList) + w.ListEnd(offset) +} + +func decodeShadowBranchNode(enc []byte) (*shadowBranchNode, error) { + var n shadowBranchNode + elems, _, err := rlp.SplitList(enc) + if err != nil { + return nil, err + } + + sh, rest, err := rlp.SplitString(elems) + if err != nil { + return nil, err + } + + if len(sh) == 0 { + n.ShadowHash = nil + } else { + hash := common.BytesToHash(sh) + n.ShadowHash = &hash + } + + if err = rlp.DecodeBytes(rest, &n.EpochMap); err != nil { + return nil, err + } + + return &n, nil +} + +type ShadowNodeStorage interface { + // Get key is the shadow node prefix path + Get(path string) ([]byte, error) + Put(path string, val []byte) error + Delete(path string) error +} + +type ShadowNodeDatabase interface { + Get(addr common.Hash, path string) ([]byte, error) + Delete(addr common.Hash, path string) error + Put(addr common.Hash, path string, val []byte) error + OpenStorage(addr common.Hash) ShadowNodeStorage + Commit(number *big.Int, blockRoot common.Hash) error +} + +type shadowNodeStorage4Trie struct { + addr common.Hash + db ShadowNodeDatabase +} + +func NewShadowNodeStorage4Trie(addr common.Hash, db ShadowNodeDatabase) ShadowNodeStorage { + return &shadowNodeStorage4Trie{ + addr: addr, + db: db, + } +} + +func (s *shadowNodeStorage4Trie) Get(path string) ([]byte, error) { + return s.db.Get(s.addr, path) +} + +func (s *shadowNodeStorage4Trie) Put(path string, val []byte) error { + return s.db.Put(s.addr, path, val) +} + +func (s *shadowNodeStorage4Trie) Delete(path string) error { + return s.db.Delete(s.addr, path) +} + +// ShadowNodeStorageRO shadow node only could modify the latest diff layers, +// if you want to modify older state, please unwind to thr older history +type ShadowNodeStorageRO struct { + diskdb ethdb.KeyValueStore + number *big.Int +} + +func (s *ShadowNodeStorageRO) Get(addr common.Hash, path string) ([]byte, error) { + return FindHistory(s.diskdb, s.number.Uint64()+1, addr, path) +} 
+ +func (s *ShadowNodeStorageRO) Delete(addr common.Hash, path string) error { + return errors.New("ShadowNodeStorageRO unsupported") +} + +func (s *ShadowNodeStorageRO) Put(addr common.Hash, path string, val []byte) error { + return errors.New("ShadowNodeStorageRO unsupported") +} + +func (s *ShadowNodeStorageRO) OpenStorage(addr common.Hash) ShadowNodeStorage { + return NewShadowNodeStorage4Trie(addr, s) +} + +func (s *ShadowNodeStorageRO) Commit(number *big.Int, blockRoot common.Hash) error { + return errors.New("ShadowNodeStorageRO unsupported") +} + +type ShadowNodeStorageRW struct { + snap shadowNodeSnapshot + tree *ShadowNodeSnapTree + dirties map[common.Hash]map[string][]byte + + stale bool + lock sync.RWMutex +} + +// NewShadowNodeDatabase first find snap by blockRoot, if got nil, try using number to instance a read only storage +func NewShadowNodeDatabase(tree *ShadowNodeSnapTree, number *big.Int, blockRoot common.Hash) (ShadowNodeDatabase, error) { + snap := tree.Snapshot(blockRoot) + if snap == nil { + // try using default snap + if snap = tree.Snapshot(emptyRoot); snap == nil { + // open read only history + log.Debug("NewShadowNodeDatabase use RO database", "number", number, "root", blockRoot) + return &ShadowNodeStorageRO{ + diskdb: tree.DB(), + number: number, + }, nil + } + log.Debug("NewShadowNodeDatabase use default database", "number", number, "root", blockRoot) + } + return &ShadowNodeStorageRW{ + snap: snap, + tree: tree, + dirties: make(map[common.Hash]map[string][]byte), + }, nil +} + +func (s *ShadowNodeStorageRW) Get(addr common.Hash, path string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + sub, exist := s.dirties[addr] + if exist { + if val, ok := sub[path]; ok { + return val, nil + } + } + + return s.snap.ShadowNode(addr, path) +} + +func (s *ShadowNodeStorageRW) Delete(addr common.Hash, path string) error { + s.lock.RLock() + defer s.lock.RUnlock() + if s.stale { + return errors.New("storage has staled") + } + _, ok := 
s.dirties[addr] + if !ok { + s.dirties[addr] = make(map[string][]byte) + } + + s.dirties[addr][path] = nil + return nil +} + +func (s *ShadowNodeStorageRW) Put(addr common.Hash, path string, val []byte) error { + prev, err := s.Get(addr, path) + if err != nil { + return err + } + if bytes.Equal(prev, val) { + return nil + } + + s.lock.RLock() + defer s.lock.RUnlock() + if s.stale { + return errors.New("storage has staled") + } + + _, ok := s.dirties[addr] + if !ok { + s.dirties[addr] = make(map[string][]byte) + } + s.dirties[addr][path] = val + return nil +} + +func (s *ShadowNodeStorageRW) OpenStorage(addr common.Hash) ShadowNodeStorage { + return NewShadowNodeStorage4Trie(addr, s) +} + +// Commit if you commit to an unknown parent, like deeper than 128 layers, will get error +func (s *ShadowNodeStorageRW) Commit(number *big.Int, blockRoot common.Hash) error { + s.lock.Lock() + defer s.lock.Unlock() + if s.stale { + return errors.New("storage has staled") + } + + s.stale = true + err := s.tree.Update(s.snap.Root(), number, blockRoot, s.dirties) + if err != nil { + return err + } + + return s.tree.Cap(blockRoot) +} diff --git a/trie/shadow_node_difflayer.go b/trie/shadow_node_difflayer.go new file mode 100644 index 0000000000..dc400e8886 --- /dev/null +++ b/trie/shadow_node_difflayer.go @@ -0,0 +1,636 @@ +package trie + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/big" + "sync" + + "github.com/ethereum/go-ethereum/log" + + "github.com/RoaringBitmap/roaring/roaring64" + "github.com/ethereum/go-ethereum/common/math" + lru "github.com/hashicorp/golang-lru" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + // MaxShadowNodeDiffDepth default is 128 layers + MaxShadowNodeDiffDepth = 128 + journalVersion uint64 = 1 + defaultDiskLayerCacheSize = 100000 +) + +// shadowNodeSnapshot record diff layer and disk layer of shadow 
nodes, support mini reorg +type shadowNodeSnapshot interface { + // Root block state root + Root() common.Hash + + // ShadowNode query shadow node from db, got RLP format + ShadowNode(addrHash common.Hash, path string) ([]byte, error) + + // Parent parent snap + Parent() shadowNodeSnapshot + + // Update create a new diff layer from here + Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) + + // Journal commit self as a journal to buffer + Journal(buffer *bytes.Buffer) (common.Hash, error) +} + +// ShadowNodeSnapTree maintain all diff layers support reorg, will flush to db when MaxShadowNodeDiffDepth reach +// every layer response to a block state change set, there no flatten layers operation. +type ShadowNodeSnapTree struct { + diskdb ethdb.KeyValueStore + + // diffLayers + diskLayer, disk layer, always not nil + layers map[common.Hash]shadowNodeSnapshot + children map[common.Hash][]common.Hash + + lock sync.RWMutex +} + +func NewShadowNodeSnapTree(diskdb ethdb.KeyValueStore, archiveMode bool) (*ShadowNodeSnapTree, error) { + diskLayer, err := loadDiskLayer(diskdb, archiveMode) + if err != nil { + return nil, err + } + layers, children, err := loadDiffLayers(diskdb, diskLayer) + if err != nil { + return nil, err + } + + layers[diskLayer.blockRoot] = diskLayer + // check if continuously after disk layer + if len(layers) > 1 && len(children[diskLayer.blockRoot]) == 0 { + return nil, errors.New("cannot found any diff layers link to disk layer") + } + return &ShadowNodeSnapTree{ + diskdb: diskdb, + layers: layers, + children: children, + }, nil +} + +// Cap keep tree depth not greater MaxShadowNodeDiffDepth, all forks parent to disk layer will delete +func (s *ShadowNodeSnapTree) Cap(blockRoot common.Hash) error { + snap := s.Snapshot(blockRoot) + if snap == nil { + return errors.New("snapshot missing") + } + nextDiff, ok := snap.(*shadowNodeDiffLayer) + if !ok { + return nil + } + for i := 0; i < 
MaxShadowNodeDiffDepth-1; i++ { + nextDiff, ok = nextDiff.Parent().(*shadowNodeDiffLayer) + // if depth less MaxShadowNodeDiffDepth, just return + if !ok { + return nil + } + } + + flatten := make([]shadowNodeSnapshot, 0) + parent := nextDiff.Parent() + for parent != nil { + flatten = append(flatten, parent) + parent = parent.Parent() + } + if len(flatten) <= 1 { + return nil + } + + last, ok := flatten[len(flatten)-1].(*shadowNodeDiskLayer) + if !ok { + return errors.New("the diff layers not link to disk layer") + } + + s.lock.Lock() + defer s.lock.Unlock() + newDiskLayer, err := s.flattenDiffs2Disk(flatten[:len(flatten)-1], last) + if err != nil { + return err + } + + // clear forks, but keep latest disk forks + for i := len(flatten) - 1; i > 0; i-- { + var childRoot common.Hash + if i > 0 { + childRoot = flatten[i-1].Root() + } else { + childRoot = nextDiff.Root() + } + root := flatten[i].Root() + s.removeSubLayers(s.children[root], &childRoot) + delete(s.layers, root) + delete(s.children, root) + } + + // reset newDiskLayer and children's parent + s.layers[newDiskLayer.Root()] = newDiskLayer + for _, child := range s.children[newDiskLayer.Root()] { + if diff, exist := s.layers[child].(*shadowNodeDiffLayer); exist { + diff.setParent(newDiskLayer) + } + } + return nil +} + +func (s *ShadowNodeSnapTree) Update(parentRoot common.Hash, blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) error { + // if there are no changes, just skip + if blockRoot == parentRoot { + return nil + } + + // Generate a new snapshot on top of the parent + parent := s.Snapshot(parentRoot) + if parent == nil { + // just point to fake disk layers + parent = s.Snapshot(emptyRoot) + if parent == nil { + return errors.New("cannot find any suitable parent") + } + parentRoot = parent.Root() + } + snap, err := parent.Update(blockNumber, blockRoot, nodeSet) + if err != nil { + return err + } + + s.lock.Lock() + defer s.lock.Unlock() + + s.layers[blockRoot] = 
snap + s.children[parentRoot] = append(s.children[parentRoot], blockRoot) + return nil +} + +func (s *ShadowNodeSnapTree) Snapshot(blockRoot common.Hash) shadowNodeSnapshot { + s.lock.RLock() + defer s.lock.RUnlock() + return s.layers[blockRoot] +} + +func (s *ShadowNodeSnapTree) DB() ethdb.KeyValueStore { + s.lock.RLock() + defer s.lock.RUnlock() + return s.diskdb +} + +func (s *ShadowNodeSnapTree) Journal() error { + s.lock.Lock() + defer s.lock.Unlock() + + // Firstly write out the metadata of journal + journal := new(bytes.Buffer) + if err := rlp.Encode(journal, journalVersion); err != nil { + return err + } + for _, snap := range s.layers { + if _, err := snap.Journal(journal); err != nil { + return err + } + } + rawdb.WriteShadowNodeSnapshotJournal(s.diskdb, journal.Bytes()) + return nil +} + +func (s *ShadowNodeSnapTree) removeSubLayers(layers []common.Hash, skip *common.Hash) { + for _, layer := range layers { + if skip != nil && layer == *skip { + continue + } + s.removeSubLayers(s.children[layer], nil) + delete(s.layers, layer) + delete(s.children, layer) + } +} + +// flattenDiffs2Disk delete all flatten and push them to db +func (s *ShadowNodeSnapTree) flattenDiffs2Disk(flatten []shadowNodeSnapshot, diskLayer *shadowNodeDiskLayer) (*shadowNodeDiskLayer, error) { + var err error + for i := len(flatten) - 1; i >= 0; i-- { + diskLayer, err = diskLayer.PushDiff(flatten[i].(*shadowNodeDiffLayer)) + if err != nil { + return nil, err + } + } + + return diskLayer, nil +} + +// loadDiskLayer load from db, could be nil when none in db +func loadDiskLayer(db ethdb.KeyValueStore, archiveMode bool) (*shadowNodeDiskLayer, error) { + val := rawdb.ReadShadowNodePlainStateMeta(db) + // if there is no disk layer, will construct a fake disk layer + if len(val) == 0 { + diskLayer, err := newShadowNodeDiskLayer(db, common.Big0, emptyRoot, archiveMode) + if err != nil { + return nil, err + } + return diskLayer, nil + } + var meta shadowNodePlainMeta + if err := 
rlp.DecodeBytes(val, &meta); err != nil { + return nil, err + } + + layer, err := newShadowNodeDiskLayer(db, meta.BlockNumber, meta.BlockRoot, archiveMode) + if err != nil { + return nil, err + } + return layer, nil +} + +func loadDiffLayers(db ethdb.KeyValueStore, diskLayer *shadowNodeDiskLayer) (map[common.Hash]shadowNodeSnapshot, map[common.Hash][]common.Hash, error) { + layers := make(map[common.Hash]shadowNodeSnapshot) + children := make(map[common.Hash][]common.Hash) + + journal := rawdb.ReadShadowNodeSnapshotJournal(db) + if len(journal) == 0 { + return layers, children, nil + } + r := rlp.NewStream(bytes.NewReader(journal), 0) + // Firstly, resolve the first element as the journal version + version, err := r.Uint64() + if err != nil { + return nil, nil, errors.New("failed to resolve journal version") + } + if version != journalVersion { + return nil, nil, errors.New("wrong journal version") + } + + parents := make(map[common.Hash]common.Hash) + for { + var ( + parent common.Hash + number big.Int + root common.Hash + js []journalShadowNode + ) + // Read the next diff journal entry + if err := r.Decode(&number); err != nil { + // The first read may fail with EOF, marking the end of the journal + if errors.Is(err, io.EOF) { + break + } + return nil, nil, fmt.Errorf("load diff number: %v", err) + } + if err := r.Decode(&parent); err != nil { + return nil, nil, fmt.Errorf("load diff parent: %v", err) + } + // Read the next diff journal entry + if err := r.Decode(&root); err != nil { + return nil, nil, fmt.Errorf("load diff root: %v", err) + } + if err := r.Decode(&js); err != nil { + return nil, nil, fmt.Errorf("load diff storage: %v", err) + } + + nodeSet := make(map[common.Hash]map[string][]byte) + for _, entry := range js { + nodes := make(map[string][]byte) + for i, key := range entry.Keys { + if len(entry.Vals[i]) > 0 { // RLP loses nil-ness, but `[]byte{}` is not a valid item, so reinterpret that + nodes[key] = entry.Vals[i] + } else { + nodes[key] = nil + 
} + } + nodeSet[entry.Hash] = nodes + } + + parents[root] = parent + layers[root] = newShadowNodeDiffLayer(&number, root, nil, nodeSet) + } + + for t, s := range layers { + parent := parents[t] + children[parent] = append(children[parent], t) + if p, ok := layers[parent]; ok { + s.(*shadowNodeDiffLayer).parent = p + } else if diskLayer != nil && parent == diskLayer.Root() { + s.(*shadowNodeDiffLayer).parent = diskLayer + } else { + return nil, nil, errors.New("cannot find it's parent") + } + } + return layers, children, nil +} + +type shadowNodeDiffLayer struct { + blockNumber *big.Int + blockRoot common.Hash + parent shadowNodeSnapshot + nodeSet map[common.Hash]map[string][]byte + + // TODO(0xbundler): add destruct handle later? + lock sync.RWMutex +} + +func newShadowNodeDiffLayer(blockNumber *big.Int, blockRoot common.Hash, parent shadowNodeSnapshot, nodeSet map[common.Hash]map[string][]byte) *shadowNodeDiffLayer { + return &shadowNodeDiffLayer{ + blockNumber: blockNumber, + blockRoot: blockRoot, + parent: parent, + nodeSet: nodeSet, + } +} + +func (s *shadowNodeDiffLayer) Root() common.Hash { + s.lock.RLock() + defer s.lock.RUnlock() + return s.blockRoot +} + +func (s *shadowNodeDiffLayer) ShadowNode(addrHash common.Hash, path string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + cm, exist := s.nodeSet[addrHash] + if exist { + if ret, ok := cm[path]; ok { + return ret, nil + } + } + + return s.parent.ShadowNode(addrHash, path) +} + +func (s *shadowNodeDiffLayer) Parent() shadowNodeSnapshot { + s.lock.RLock() + defer s.lock.RUnlock() + return s.parent +} + +// Update append new diff layer onto current, nodeChgRecord when val is []byte{}, it delete the kv +func (s *shadowNodeDiffLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { + s.lock.RLock() + if s.blockNumber.Cmp(blockNumber) >= 0 { + return nil, errors.New("update a unordered diff layer") + } + s.lock.RUnlock() 
+ return newShadowNodeDiffLayer(blockNumber, blockRoot, s, nodeSet), nil +} + +func (s *shadowNodeDiffLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + if err := rlp.Encode(buffer, s.blockNumber); err != nil { + return common.Hash{}, err + } + + if s.parent != nil { + if err := rlp.Encode(buffer, s.parent.Root()); err != nil { + return common.Hash{}, err + } + } else { + if err := rlp.Encode(buffer, emptyRoot); err != nil { + return common.Hash{}, err + } + } + + if err := rlp.Encode(buffer, s.blockRoot); err != nil { + return common.Hash{}, err + } + storage := make([]journalShadowNode, 0, len(s.nodeSet)) + for hash, nodes := range s.nodeSet { + keys := make([]string, 0, len(nodes)) + vals := make([][]byte, 0, len(nodes)) + for key, val := range nodes { + keys = append(keys, key) + vals = append(vals, val) + } + storage = append(storage, journalShadowNode{Hash: hash, Keys: keys, Vals: vals}) + } + if err := rlp.Encode(buffer, storage); err != nil { + return common.Hash{}, err + } + return s.blockRoot, nil +} + +func (s *shadowNodeDiffLayer) setParent(parent shadowNodeSnapshot) { + s.lock.Lock() + defer s.lock.Unlock() + s.parent = parent +} + +func (s *shadowNodeDiffLayer) getNodeSet() map[common.Hash]map[string][]byte { + s.lock.Lock() + defer s.lock.Unlock() + return s.nodeSet +} + +type journalShadowNode struct { + Hash common.Hash + Keys []string + Vals [][]byte +} + +type shadowNodePlainMeta struct { + BlockNumber *big.Int + BlockRoot common.Hash +} + +type shadowNodeDiskLayer struct { + // TODO(0xbundler): add history & changeSet later + diskdb ethdb.KeyValueStore + blockNumber *big.Int + blockRoot common.Hash + cache *lru.Cache + archiveMode bool // archiveMode, if true keep all history, if false just flatten all changSet to plainState + + lock sync.RWMutex +} + +func newShadowNodeDiskLayer(diskdb ethdb.KeyValueStore, blockNumber *big.Int, blockRoot common.Hash, archiveMode bool) (*shadowNodeDiskLayer, 
error) { + cache, err := lru.New(defaultDiskLayerCacheSize) + if err != nil { + return nil, err + } + return &shadowNodeDiskLayer{ + diskdb: diskdb, + blockNumber: blockNumber, + blockRoot: blockRoot, + cache: cache, + archiveMode: archiveMode, + }, nil +} + +func (s *shadowNodeDiskLayer) Root() common.Hash { + s.lock.RLock() + defer s.lock.RUnlock() + return s.blockRoot +} + +func (s *shadowNodeDiskLayer) ShadowNode(addr common.Hash, path string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + cacheKey := shadowNodeCacheKey(addr, path) + cached, exist := s.cache.Get(cacheKey) + if exist { + return cached.([]byte), nil + } + + val, err := FindHistory(s.diskdb, s.blockNumber.Uint64()+1, addr, path) + if err != nil { + return nil, err + } + + s.cache.Add(cacheKey, val) + return val, err +} + +func (s *shadowNodeDiskLayer) Parent() shadowNodeSnapshot { + return nil +} + +func (s *shadowNodeDiskLayer) Update(blockNumber *big.Int, blockRoot common.Hash, nodeSet map[common.Hash]map[string][]byte) (shadowNodeSnapshot, error) { + s.lock.RLock() + if s.blockNumber.Cmp(blockNumber) >= 0 { + return nil, errors.New("update a unordered diff layer") + } + s.lock.RUnlock() + return newShadowNodeDiffLayer(blockNumber, blockRoot, s, nodeSet), nil +} + +func (s *shadowNodeDiskLayer) Journal(buffer *bytes.Buffer) (common.Hash, error) { + return common.Hash{}, nil +} + +func (s *shadowNodeDiskLayer) PushDiff(diff *shadowNodeDiffLayer) (*shadowNodeDiskLayer, error) { + s.lock.Lock() + defer s.lock.Unlock() + + number := diff.blockNumber + if s.blockNumber.Cmp(number) >= 0 { + return nil, errors.New("push a lower block to disk") + } + batch := s.diskdb.NewBatch() + nodeSet := diff.getNodeSet() + if err := s.writeHistory(number, batch, diff.getNodeSet()); err != nil { + return nil, err + } + + // update meta + meta := shadowNodePlainMeta{ + BlockNumber: number, + BlockRoot: diff.blockRoot, + } + enc, err := rlp.EncodeToBytes(meta) + if err != nil { + return nil, err + } + 
if err = rawdb.WriteShadowNodePlainStateMeta(batch, enc); err != nil { + return nil, err + } + + if err = batch.Write(); err != nil { + return nil, err + } + diskLayer := &shadowNodeDiskLayer{ + diskdb: s.diskdb, + blockNumber: number, + blockRoot: diff.blockRoot, + cache: s.cache, + archiveMode: s.archiveMode, + } + + // reuse cache + for addr, nodes := range nodeSet { + for path, val := range nodes { + diskLayer.cache.Add(shadowNodeCacheKey(addr, path), val) + } + } + return diskLayer, nil +} + +func (s *shadowNodeDiskLayer) writeHistory(number *big.Int, batch ethdb.Batch, nodeSet map[common.Hash]map[string][]byte) error { + // if not in archiveMode, just flatten to plainState + if !s.archiveMode { + for addr, subSet := range nodeSet { + for path, val := range subSet { + // refresh plain state + if len(val) == 0 { + if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { + return err + } + } else { + if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { + return err + } + } + } + } + log.Info("shadow node history pruned, only keep plainState", "number", number, "count", len(nodeSet)) + return nil + } + + for addr, subSet := range nodeSet { + changeSet := make([]nodeChgRecord, 0, len(subSet)) + for path, val := range subSet { + if err := refreshShadowNodeHistory(s.diskdb, batch, addr, path, number.Uint64()); err != nil { + return err + } + prev := rawdb.ReadShadowNodePlainState(s.diskdb, addr, path) + // refresh plain state + if len(val) == 0 { + if err := rawdb.DeleteShadowNodePlainState(batch, addr, path); err != nil { + return err + } + } else { + if err := rawdb.WriteShadowNodePlainState(batch, addr, path, val); err != nil { + return err + } + } + + changeSet = append(changeSet, nodeChgRecord{ + Path: path, + Prev: prev, + }) + } + enc, err := rlp.EncodeToBytes(changeSet) + if err != nil { + return err + } + if err = rawdb.WriteShadowNodeChangeSet(batch, addr, number.Uint64(), enc); err != nil { + return err + } + 
} + + return nil +} + +func shadowNodeCacheKey(addr common.Hash, path string) string { + key := make([]byte, len(addr)+len(path)) + copy(key[:], addr.Bytes()) + copy(key[len(addr):], path) + return string(key) +} + +func refreshShadowNodeHistory(db ethdb.KeyValueReader, batch ethdb.Batch, addr common.Hash, path string, number uint64) error { + enc := rawdb.ReadShadowNodeHistory(db, addr, path, math.MaxUint64) + index := roaring64.New() + if len(enc) > 0 { + if _, err := index.ReadFrom(bytes.NewReader(enc)); err != nil { + return err + } + } + index.Add(number) + enc, err := index.ToBytes() + if err != nil { + return err + } + if err = rawdb.WriteShadowNodeHistory(batch, addr, path, math.MaxUint64, enc); err != nil { + return err + } + return nil +} diff --git a/trie/shadow_node_difflayer_test.go b/trie/shadow_node_difflayer_test.go new file mode 100644 index 0000000000..e91f103a00 --- /dev/null +++ b/trie/shadow_node_difflayer_test.go @@ -0,0 +1,265 @@ +package trie + +import ( + "math/big" + "strconv" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +var ( + blockRoot0 = makeHash("b0") + blockRoot1 = makeHash("b1") + blockRoot2 = makeHash("b2") + blockRoot3 = makeHash("b3") + storageRoot0 = makeHash("s0") + storageRoot1 = makeHash("s1") + storageRoot2 = makeHash("s2") + storageRoot3 = makeHash("s3") + contract1 = makeHash("c1") + contract2 = makeHash("c2") + contract3 = makeHash("c3") +) + +func TestShadowNodeDiffLayer_whenGenesis(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap = tree.Snapshot(blockRoot1) + assert.Nil(t, snap) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, makeNodeSet(contract1, []string{"hello", "world"})) + assert.NoError(t, err) + err = tree.Update(blockRoot1, common.Big2, 
blockRoot2, makeNodeSet(contract1, []string{"hello2", "world2"})) + assert.NoError(t, err) + err = tree.Cap(blockRoot1) + assert.NoError(t, err) + err = tree.Journal() + assert.NoError(t, err) + + // reload + tree, err = NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + diskLayer := tree.Snapshot(emptyRoot) + assert.NotNil(t, diskLayer) + snap = tree.Snapshot(blockRoot0) + assert.Nil(t, snap) + snap1 := tree.Snapshot(blockRoot1) + n, err := snap1.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + assert.Equal(t, diskLayer, snap1.Parent()) + assert.Equal(t, blockRoot1, snap1.Root()) + + // read from child + snap2 := tree.Snapshot(blockRoot2) + assert.Equal(t, snap1, snap2.Parent()) + assert.Equal(t, blockRoot2, snap2.Root()) + n, err = snap2.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), n) + n, err = snap2.ShadowNode(contract1, "hello2") + assert.NoError(t, err) + assert.Equal(t, []byte("world2"), n) +} + +func TestShadowNodeDiffLayer_crud(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + set1 := makeNodeSet(contract1, []string{"hello", "world", "h1", "w1"}) + appendNodeSet(set1, contract3, []string{"h3", "w3"}) + err = tree.Update(blockRoot0, common.Big1, blockRoot1, set1) + assert.NoError(t, err) + set2 := makeNodeSet(contract1, []string{"hello", "", "h1", ""}) + appendNodeSet(set2, contract2, []string{"hello", "", "h2", "w2"}) + err = tree.Update(blockRoot1, common.Big2, blockRoot2, set2) + assert.NoError(t, err) + snap := tree.Snapshot(blockRoot1) + assert.NotNil(t, snap) + val, err := snap.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + val, err = snap.ShadowNode(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte("w1"), val) + val, err = snap.ShadowNode(contract3, "h3") + assert.NoError(t, err) + 
assert.Equal(t, []byte("w3"), val) + + snap = tree.Snapshot(blockRoot2) + assert.NotNil(t, snap) + val, err = snap.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract1, "h1") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract2, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + val, err = snap.ShadowNode(contract2, "h2") + assert.NoError(t, err) + assert.Equal(t, []byte("w2"), val) + val, err = snap.ShadowNode(contract3, "h3") + assert.NoError(t, err) + assert.Equal(t, []byte("w3"), val) +} + +func TestShadowNodeDiffLayer_capDiffLayers(t *testing.T) { + diskdb := memorydb.New() + // create empty tree + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + + // push 200 diff layers + count := 1 + for i := 0; i < 200; i++ { + ns := strconv.Itoa(count) + root := makeHash("b" + ns) + parent := makeHash("b" + strconv.Itoa(count-1)) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, + root, makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 10 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + + err = tree.Cap(root) + assert.NoError(t, err) + count++ + } + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 100; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + strconv.Itoa(count-1)) + root := makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + // add 20 forks + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(parent, number, + 
makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + for j := 0; j < 10; j++ { + fs := strconv.Itoa(j) + err = tree.Update(makeHash("b"+strconv.Itoa(count-1)+"f"+fs), number, + makeHash("b"+ns+"f"+fs), makeNodeSet(contract1, []string{"hello" + ns + "f" + fs, "world" + ns + "f" + fs})) + assert.NoError(t, err) + } + count++ + } + lastRoot := makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + assert.Equal(t, 1409, len(tree.layers)) + + // push 100 diff layers, and cap + for i := 0; i < 129; i++ { + ns := strconv.Itoa(count) + parent := makeHash("b" + strconv.Itoa(count-1)) + root := makeHash("b" + ns) + number := new(big.Int).SetUint64(uint64(count)) + err = tree.Update(parent, number, root, + makeNodeSet(contract1, []string{"hello" + ns, "world" + ns})) + assert.NoError(t, err) + + count++ + } + lastRoot = makeHash("b" + strconv.Itoa(count-1)) + err = tree.Cap(lastRoot) + assert.NoError(t, err) + + assert.Equal(t, 129, len(tree.layers)) + assert.Equal(t, 128, len(tree.children)) + for parent, children := range tree.children { + if tree.layers[parent] == nil { + t.Log(tree.layers[parent]) + } + assert.NotNil(t, tree.layers[parent]) + for _, child := range children { + if tree.layers[child] == nil { + t.Log(tree.layers[child]) + } + assert.NotNil(t, tree.layers[child]) + } + } + + snap := tree.Snapshot(lastRoot) + assert.NotNil(t, snap) + for i := 1; i < count; i++ { + ns := strconv.Itoa(i) + n, err := snap.ShadowNode(contract1, "hello"+ns) + assert.NoError(t, err) + assert.Equal(t, []byte("world"+ns), n) + } + + // store + err = tree.Journal() + assert.NoError(t, err) +} + +func makeHash(s string) common.Hash { + var ret common.Hash + if len(s) >= 32 { + copy(ret[:], []byte(s)[:hashLen]) + return ret + } + for i := 0; i < hashLen; i++ { + ret[i] = '0' + } + copy(ret[hashLen-len(s):hashLen], s) + return ret +} + +func makeNodeSet(addr 
common.Hash, kvs []string) map[common.Hash]map[string][]byte { + if len(kvs)%2 != 0 { + panic("makeNodeSet: wrong params") + } + ret := make(map[common.Hash]map[string][]byte) + ret[addr] = make(map[string][]byte) + for i := 0; i < len(kvs); i += 2 { + if len(kvs) == 0 { + ret[addr][kvs[i]] = nil + continue + } + ret[addr][kvs[i]] = []byte(kvs[i+1]) + } + + return ret +} + +func appendNodeSet(ret map[common.Hash]map[string][]byte, addr common.Hash, kvs []string) { + if len(kvs)%2 != 0 { + panic("makeNodeSet: wrong params") + } + if _, ok := ret[addr]; !ok { + ret[addr] = make(map[string][]byte) + } + for i := 0; i < len(kvs); i += 2 { + if len(kvs) == 0 { + ret[addr][kvs[i]] = nil + continue + } + ret[addr][kvs[i]] = []byte(kvs[i+1]) + } +} diff --git a/trie/shadow_node_history.go b/trie/shadow_node_history.go new file mode 100644 index 0000000000..6457e88ca3 --- /dev/null +++ b/trie/shadow_node_history.go @@ -0,0 +1,86 @@ +package trie + +import ( + "bytes" + "errors" + + "github.com/RoaringBitmap/roaring/roaring64" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" +) + +func FindHistory(db ethdb.KeyValueStore, number uint64, addr common.Hash, path string) ([]byte, error) { + val, found, err := findHistory(db, number, addr, path) + if err != nil { + return nil, err + } + + if found { + return val, nil + } + + // query from plain state + return rawdb.ReadShadowNodePlainState(db, addr, path), nil +} + +type nodeChgRecord struct { + Path string + Prev []byte +} + +func findHistory(db ethdb.KeyValueStore, number uint64, addr common.Hash, path string) ([]byte, bool, error) { + // TODO(0xbundler): split shards according bitmap size later, less than 1mb + hbytes := rawdb.ReadShadowNodeHistory(db, addr, path, math.MaxUint64) + if len(hbytes) == 0 { + return nil, false, nil + } + + index := roaring64.New() 
+ if _, err := index.ReadFrom(bytes.NewReader(hbytes)); err != nil { + return nil, false, err + } + found, ok := SeekInBitmap64(index, number) + if !ok { + return nil, false, nil + } + + // TODO(0xbundler): using mdbx's DupSort? + changeSet := rawdb.ReadShadowNodeChangeSet(db, addr, found) + if len(changeSet) == 0 { + return nil, false, errors.New("cannot find target changeSet") + } + var ns []nodeChgRecord + if err := rlp.DecodeBytes(changeSet, &ns); err != nil { + return nil, false, err + } + nodeSetMap := make(map[string][]byte) + for _, n := range ns { + nodeSetMap[n.Path] = n.Prev + } + + val, exist := nodeSetMap[path] + if !exist { + return nil, false, errors.New("cannot find path's change val") + } + + return val, true, nil +} + +// SeekInBitmap - returns value in bitmap which is >= n +func SeekInBitmap64(m *roaring64.Bitmap, n uint64) (found uint64, ok bool) { + if m.IsEmpty() { + return 0, false + } + if n == 0 { + return m.Minimum(), true + } + searchRank := m.Rank(n - 1) + if searchRank >= m.GetCardinality() { + return 0, false + } + found, _ = m.Select(searchRank) + return found, true +} diff --git a/trie/shadow_node_history_test.go b/trie/shadow_node_history_test.go new file mode 100644 index 0000000000..fcd1ecb48c --- /dev/null +++ b/trie/shadow_node_history_test.go @@ -0,0 +1,81 @@ +package trie + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +func TestShadowNodeHistory_Diff2Disk(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb, true) + assert.NoError(t, err) + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + _, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + // find history + val, err := FindHistory(diskdb, common.Big2.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + val, err = 
FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte{}, val) + + // reload disk layer + diskLayer, err = loadDiskLayer(diskdb, true) + assert.NoError(t, err) + val, err = diskLayer.ShadowNode(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) +} + +func TestShadowNodeHistory_case2(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb, true) + assert.NoError(t, err) + + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + diff = newShadowNodeDiffLayer(common.Big2, blockRoot2, nil, makeNodeSet(contract1, []string{"hello", "world1"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + val, err := FindHistory(diskdb, common.Big2.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + + val, err = FindHistory(diskdb, common.Big3.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world1"), val) +} + +func TestShadowNodeHistory_disableArchive(t *testing.T) { + diskdb := memorydb.New() + diskLayer, err := loadDiskLayer(diskdb, false) + assert.NoError(t, err) + + diff := newShadowNodeDiffLayer(common.Big1, blockRoot1, nil, makeNodeSet(contract1, []string{"hello", "world"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + diff = newShadowNodeDiffLayer(common.Big2, blockRoot2, nil, makeNodeSet(contract1, []string{"hello", "world1"})) + diskLayer, err = diskLayer.PushDiff(diff) + assert.NoError(t, err) + + val, err := FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world1"), val) + + diff = newShadowNodeDiffLayer(common.Big3, blockRoot3, nil, makeNodeSet(contract1, []string{"hello", "world2"})) + diskLayer, err = diskLayer.PushDiff(diff) + 
assert.NoError(t, err) + + val, err = FindHistory(diskdb, common.Big1.Uint64(), contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world2"), val) +} diff --git a/trie/shadow_node_test.go b/trie/shadow_node_test.go new file mode 100644 index 0000000000..667cc291ee --- /dev/null +++ b/trie/shadow_node_test.go @@ -0,0 +1,245 @@ +package trie + +import ( + "bytes" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/core/types" + + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/rlp" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/ethdb/memorydb" + "github.com/stretchr/testify/assert" +) + +func TestShadowNodeRW_CRUD(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) + assert.NoError(t, err) + + err = storageDB.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + err = storageDB.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + val, err := storageDB.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) + err = storageDB.Delete(contract1, "hello") + assert.NoError(t, err) + val, err = storageDB.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) +} + +func TestShadowNodeRO_Get(t *testing.T) { + diskdb := memorydb.New() + makeDiskLayer(diskdb, common.Big2, blockRoot2, contract1, []string{"k1", "v1"}) + + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + storageRO, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) + assert.NoError(t, err) + + err = storageRO.Put(contract1, "hello", []byte("world")) + assert.Error(t, err) + err = storageRO.Delete(contract1, "hello") + assert.Error(t, err) + err = storageRO.Commit(common.Big2, blockRoot2) + assert.Error(t, err) + + val, err := storageRO.Get(contract1, "hello") 
+ assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) + val, err = storageRO.Get(contract1, "k1") + assert.NoError(t, err) + assert.Equal(t, []byte("v1"), val) +} + +func makeDiskLayer(diskdb *memorydb.Database, number *big.Int, root common.Hash, addr common.Hash, kv []string) { + if len(kv)%2 != 0 { + panic("wrong kv") + } + meta := shadowNodePlainMeta{ + BlockNumber: number, + BlockRoot: root, + } + enc, _ := rlp.EncodeToBytes(&meta) + rawdb.WriteShadowNodePlainStateMeta(diskdb, enc) + + for i := 0; i < len(kv); i += 2 { + rawdb.WriteShadowNodePlainState(diskdb, addr, kv[i], []byte(kv[i+1])) + } +} + +func TestShadowNodeRW_Commit(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) + assert.NoError(t, err) + + err = storageDB.Put(contract1, "hello", []byte("world")) + assert.NoError(t, err) + + err = storageDB.Commit(common.Big1, blockRoot1) + assert.NoError(t, err) + + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, blockRoot1) + assert.NoError(t, err) + val, err := storageDB.Get(contract1, "hello") + assert.NoError(t, err) + assert.Equal(t, []byte("world"), val) +} + +func TestNewShadowNodeStorage4Trie(t *testing.T) { + diskdb := memorydb.New() + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + storageDB, err := NewShadowNodeDatabase(tree, common.Big1, blockRoot1) + assert.NoError(t, err) + + s1 := storageDB.OpenStorage(contract1) + s2 := storageDB.OpenStorage(contract2) + val, err := s1.Get("hello") + assert.NoError(t, err) + assert.Equal(t, []byte(nil), val) + err = s1.Put("hello", []byte("world")) + assert.NoError(t, err) + val, _ = s1.Get("hello") + assert.Equal(t, []byte("world"), val) + val, _ = s2.Get("hello") + assert.Equal(t, []byte(nil), val) + err = s1.Delete("hello") + assert.NoError(t, err) + val, _ = s1.Get("hello") + assert.Equal(t, []byte(nil), val) + + s2.Put("h2", 
[]byte("w2")) + val, _ = s2.Get("h2") + assert.Equal(t, []byte("w2"), val) + + err = storageDB.Commit(common.Big1, blockRoot2) + assert.NoError(t, err) +} + +func TestShadowExtendNode_encodeDecode(t *testing.T) { + dt := []struct { + n shadowExtensionNode + }{ + { + n: shadowExtensionNode{ + ShadowHash: nil, + }, + }, + { + n: shadowExtensionNode{ + ShadowHash: &blockRoot0, + }, + }, + { + n: shadowExtensionNode{ + ShadowHash: &blockRoot1, + }, + }, + } + for _, item := range dt { + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := decodeShadowExtensionNode(enc) + assert.NoError(t, err) + assert.Equal(t, &item.n, rn) + } +} + +func TestShadowBranchNode_encodeDecode(t *testing.T) { + dt := []struct { + n shadowBranchNode + }{ + { + n: shadowBranchNode{ + ShadowHash: nil, + EpochMap: [16]types.StateEpoch{}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: nil, + EpochMap: [16]types.StateEpoch{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: &blockRoot0, + EpochMap: [16]types.StateEpoch{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + { + n: shadowBranchNode{ + ShadowHash: &blockRoot1, + EpochMap: [16]types.StateEpoch{}, + }, + }, + } + for _, item := range dt { + buf := rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := decodeShadowBranchNode(enc) + assert.NoError(t, err) + assert.Equal(t, &item.n, rn) + } +} + +func TestRootNode_encodeDecode(t *testing.T) { + dt := []struct { + n rootNode + isEqual bool + }{ + { + n: rootNode{ + Epoch: 10, + TrieRoot: blockRoot0, + ShadowTreeRoot: blockRoot1, + }, + isEqual: true, + }, + { + n: rootNode{}, + isEqual: true, + }, + { + n: rootNode{ + Epoch: 100, + TrieRoot: blockRoot2, + ShadowTreeRoot: common.Hash{}, + }, + isEqual: true, + }, + { + n: rootNode{}, + }, + } + + for _, item := range dt { + item.n.resolveCache() + buf := 
rlp.NewEncoderBuffer(bytes.NewBuffer([]byte{})) + item.n.encode(buf) + enc := buf.ToBytes() + + rn, err := DecodeRootNode(enc) + assert.NoError(t, err) + if !item.isEqual { + assert.NotEqual(t, item.n, rn) + continue + } + assert.Equal(t, &item.n, rn) + } +} diff --git a/trie/stacktrie.go b/trie/stacktrie.go index ec278d390c..0cbc02aa47 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -392,15 +392,15 @@ func (st *StackTrie) hashRec(hasher *hasher) { var nodes rawFullNode for i, child := range st.children { if child == nil { - nodes[i] = nilValueNode + nodes.children[i] = nilValueNode continue } child.hashRec(hasher) if len(child.val) < 32 { - nodes[i] = rawNode(child.val) + nodes.children[i] = rawNode(child.val) } else { - nodes[i] = hashNode(child.val) + nodes.children[i] = hashNode(child.val) } // Release child back to pool. diff --git a/trie/trie.go b/trie/trie.go index 878f4cd3a3..5614eed7f2 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -61,11 +61,21 @@ type LeafCallback func(paths [][]byte, hexpath []byte, leaf []byte, parent commo // Trie is not safe for concurrent use. type Trie struct { db *Database + sndb ShadowNodeStorage // only storage trie using it, account trie needn't the shadow node root node // Keep track of the number leafs which have been inserted since the last // hashing operation. This number will not directly map to the number of // actually unhashed nodes unhashed int + + // fields for shadow tree & state epoch + withShadowNodes bool + currentEpoch types.StateEpoch + + // fields for rootNode + shadowTreeRoot common.Hash + rootEpoch types.StateEpoch + trieRoot common.Hash } // newFlag returns the cache flag value for a newly created node. 
@@ -96,6 +106,49 @@ func New(root common.Hash, db *Database) (*Trie, error) { return trie, nil } +func NewWithShadowNode(curEpoch types.StateEpoch, rootNode *rootNode, db *Database, sndb ShadowNodeStorage) (*Trie, error) { + if db == nil || sndb == nil { + panic("trie.New called without a database") + } + if curEpoch < rootNode.Epoch { + return nil, errors.New("open trie at a wrong epoch") + } + + useShadowTree := false + // only enable after first state expiry's hard fork + if curEpoch > types.StateEpoch0 { + useShadowTree = true + log.Debug("withShadowNodes trie open", "rootNodeHash", rootNode.cachedHash, "RootHash", rootNode.TrieRoot, "ShadowTreeRoot", rootNode.ShadowTreeRoot) + } + + trie := &Trie{ + db: db, + sndb: sndb, + currentEpoch: curEpoch, + withShadowNodes: useShadowTree, + shadowTreeRoot: rootNode.ShadowTreeRoot, + rootEpoch: rootNode.Epoch, + trieRoot: rootNode.TrieRoot, + } + + if rootNode.TrieRoot != (common.Hash{}) && rootNode.TrieRoot != emptyRoot { + // check root if expired, if expired just set as hashNode + if types.EpochExpired(rootNode.Epoch, curEpoch) { + trie.root = hashNode(rootNode.TrieRoot.Bytes()) + return trie, nil + } + root, err := trie.resolveHash(rootNode.TrieRoot[:], nil) + if err != nil { + return nil, err + } + if err = trie.resolveShadowNode(rootNode.Epoch, root, nil); err != nil { + return nil, err + } + trie.root = root + } + return trie, nil +} + // NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at // the key after the given start key. func (t *Trie) NodeIterator(start []byte) NodeIterator { @@ -115,8 +168,24 @@ func (t *Trie) Get(key []byte) []byte { // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. 
-func (t *Trie) TryGet(key []byte) ([]byte, error) { - value, newroot, didResolve, err := t.tryGet(t.root, keybytesToHex(key), 0) +func (t *Trie) TryGet(key []byte) (value []byte, err error) { + var newroot node + var didResolve bool + if t.withShadowNodes { + value, newroot, didResolve, err = t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), false) + } else { + value, newroot, didResolve, err = t.tryGet(t.root, keybytesToHex(key), 0) + } + + if err == nil && didResolve { + t.root = newroot + } + return value, err +} + +func (t *Trie) TryGetAndUpdateEpoch(key []byte) ([]byte, error) { + value, newroot, didResolve, err := t.tryGetWithEpoch(t.root, keybytesToHex(key), 0, t.getRootEpoch(), true) + if err == nil && didResolve { t.root = newroot } @@ -154,6 +223,65 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } value, newnode, _, err := t.tryGet(child, key, pos) return value, newnode, true, err + case *rootNode: + return nil, n, false, nil // TODO(asyukii): temporary fix + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + +func (t *Trie) tryGetWithEpoch(origNode node, key []byte, pos int, epoch types.StateEpoch, updateEpoch bool) (value []byte, newnode node, didResolve bool, err error) { + if t.epochExpired(origNode, epoch) { + return nil, nil, false, NewExpiredNodeError(key[:pos], epoch) + } + + switch n := (origNode).(type) { + case nil: + return nil, nil, false, nil + case valueNode: + return n, n, false, nil + case *shortNode: + if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + // key not found in trie + return nil, n, false, nil + } + + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Val, key, pos+len(n.Key), t.currentEpoch, updateEpoch) + if err == nil && t.renewNode(epoch, didResolve, updateEpoch) { + n = n.copy() + n.Val = newnode + if updateEpoch { + n.setEpoch(t.currentEpoch) + } + didResolve = true + } + return value, n, didResolve, 
err + case *fullNode: + value, newnode, didResolve, err = t.tryGetWithEpoch(n.Children[key[pos]], key, pos+1, n.GetChildEpoch(int(key[pos])), updateEpoch) + if err == nil && t.renewNode(epoch, didResolve, updateEpoch) { + n = n.copy() + n.Children[key[pos]] = newnode + if updateEpoch { + n.setEpoch(t.currentEpoch) + } + if updateEpoch && newnode != nil { + n.UpdateChildEpoch(int(key[pos]), t.currentEpoch) + } + didResolve = true + } + + return value, n, didResolve, err + case hashNode: + child, err := t.resolveHash(n, key[:pos]) + if err != nil { + return nil, n, true, err + } + if err = t.resolveShadowNode(epoch, child, key[:pos]); err != nil { + return nil, nil, false, err + } + + value, newnode, _, err = t.tryGetWithEpoch(child, key, pos, epoch, updateEpoch) + return value, newnode, true, err default: panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) } @@ -267,14 +395,15 @@ func (t *Trie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { func (t *Trie) TryUpdate(key, value []byte) error { t.unhashed++ k := keybytesToHex(key) + rootEpoch := t.getRootEpoch() if len(value) != 0 { - _, n, err := t.insert(t.root, nil, k, valueNode(value)) + _, n, err := t.insert(t.root, nil, k, valueNode(value), rootEpoch) if err != nil { return err } t.root = n } else { - _, n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k, rootEpoch) if err != nil { return err } @@ -283,7 +412,11 @@ func (t *Trie) TryUpdate(key, value []byte) error { return nil } -func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { +func (t *Trie) insert(n node, prefix, key []byte, value node, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch) + } + if len(key) == 0 { if v, ok := n.(valueNode); ok { return !bytes.Equal(v, value.(valueNode)), value, nil @@ -296,42 +429,51 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error 
// If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) - if !dirty || err != nil { + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value, n.epoch) + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } - return true, &shortNode{n.Key, nn, t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: nn, flags: t.newFlag(), epoch: t.currentEpoch}, nil } // Otherwise branch out at the index where they differ. branch := &fullNode{flags: t.newFlag()} var err error - _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val, t.currentEpoch) if err != nil { return false, nil, err } - _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value, t.currentEpoch) if err != nil { return false, nil, err } + if t.withShadowNodes { + branch.setEpoch(t.currentEpoch) + branch.UpdateChildEpoch(int(n.Key[matchlen]), t.currentEpoch) + branch.UpdateChildEpoch(int(key[matchlen]), t.currentEpoch) + } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. 
- return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil + return true, &shortNode{Key: key[:matchlen], Val: branch, flags: t.newFlag(), epoch: t.currentEpoch}, nil case *fullNode: - dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) - if !dirty || err != nil { + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value, n.GetChildEpoch(int(key[0]))) + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn + if t.withShadowNodes { + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } return true, n, nil case nil: - return true, &shortNode{key, value, t.newFlag()}, nil + return true, &shortNode{Key: key, Val: value, flags: t.newFlag(), epoch: t.currentEpoch}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load @@ -341,7 +483,11 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if err != nil { return false, nil, err } - dirty, nn, err := t.insert(rn, prefix, key, value) + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return false, nil, err + } + + dirty, nn, err := t.insert(rn, prefix, key, value, epoch) if !dirty || err != nil { return false, rn, err } @@ -364,7 +510,7 @@ func (t *Trie) Delete(key []byte) { func (t *Trie) TryDelete(key []byte) error { t.unhashed++ k := keybytesToHex(key) - _, n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k, t.getRootEpoch()) if err != nil { return err } @@ -375,7 +521,11 @@ func (t *Trie) TryDelete(key []byte) error { // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. 
-func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { +func (t *Trie) delete(n node, prefix, key []byte, epoch types.StateEpoch) (bool, node, error) { + if t.epochExpired(n, epoch) { + return false, nil, NewExpiredNodeError(prefix, epoch) + } + switch n := n.(type) { case *shortNode: matchlen := prefixLen(key, n.Key) @@ -389,8 +539,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) - if !dirty || err != nil { + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):], n.epoch) + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } switch child := child.(type) { @@ -401,19 +551,25 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. 
- return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil + return true, &shortNode{Key: concat(n.Key, child.Key...), Val: child.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil default: - return true, &shortNode{n.Key, child, t.newFlag()}, nil + return true, &shortNode{Key: n.Key, Val: child, flags: t.newFlag(), epoch: t.currentEpoch}, nil } case *fullNode: - dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) - if !dirty || err != nil { + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:], n.GetChildEpoch(int(key[0]))) + if !t.renewNode(epoch, dirty, true) || err != nil { return false, n, err } n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn + if t.withShadowNodes { + n.setEpoch(t.currentEpoch) + } + if t.withShadowNodes && nn != nil { + n.UpdateChildEpoch(int(key[0]), t.currentEpoch) + } // Because n is a full node, it must've contained at least two children // before the delete operation. If the new child value is non-nil, n still @@ -457,12 +613,12 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } if cnode, ok := cnode.(*shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return true, &shortNode{k, cnode.Val, t.newFlag()}, nil + return true, &shortNode{Key: k, Val: cnode.Val, flags: t.newFlag(), epoch: t.currentEpoch}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil + return true, &shortNode{Key: []byte{byte(pos)}, Val: n.Children[pos], flags: t.newFlag(), epoch: t.currentEpoch}, nil } // n still contains at least two values and cannot be reduced. 
return true, n, nil @@ -481,7 +637,11 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { if err != nil { return false, nil, err } - dirty, nn, err := t.delete(rn, prefix, key) + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return false, nil, err + } + + dirty, nn, err := t.delete(rn, prefix, key, epoch) if !dirty || err != nil { return false, rn, err } @@ -492,6 +652,75 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { } } +// ExpireByPrefix is used to simulate the expiration of a trie by prefix key. +// It is not used in the actual trie implementation. ExpireByPrefix makes sure +// only a child node of a full node is expired, if not an error is returned. +func (t *Trie) ExpireByPrefix(prefixKeyHex []byte) error { + hn, _, err := t.expireByPrefix(t.root, prefixKeyHex) + if prefixKeyHex == nil && hn != nil { + t.root = hn + } + if err != nil { + return err + } + return nil +} + +func (t *Trie) expireByPrefix(n node, prefixKeyHex []byte) (node, bool, error) { + // Loop through prefix key + // When prefix key is empty, generate the hash node of the current node + // Replace current node with the hash node + + // If length of prefix key is empty + if len(prefixKeyHex) == 0 { + hasher := newHasher(false) + defer returnHasherToPool(hasher) + var hn node + _, hn = hasher.proofHash(n) + if _, ok := hn.(hashNode); ok { + return hn, false, nil + } + + return nil, true, nil + } + + switch n := n.(type) { + case *shortNode: + matchLen := prefixLen(prefixKeyHex, n.Key) + hn, didUpdateEpoch, err := t.expireByPrefix(n.Val, prefixKeyHex[matchLen:]) + if err != nil { + return nil, didUpdateEpoch, err + } + + if hn != nil { + return nil, didUpdateEpoch, fmt.Errorf("can only expire child short node") + } + + return nil, didUpdateEpoch, err + case *fullNode: + childIndex := int(prefixKeyHex[0]) + hn, didUpdateEpoch, err := t.expireByPrefix(n.Children[childIndex], prefixKeyHex[1:]) + if err != nil { + return nil, 
didUpdateEpoch, err + } + + // Replace child node with hash node + if hn != nil { + n.Children[prefixKeyHex[0]] = hn + } + + // Update the epoch so that it is expired + if !didUpdateEpoch { + n.UpdateChildEpoch(childIndex, 0) + didUpdateEpoch = true + } + + return nil, didUpdateEpoch, err + default: + return nil, false, fmt.Errorf("invalid node type") + } +} + func concat(s1 []byte, s2 ...byte) []byte { r := make([]byte, len(s1)+len(s2)) copy(r, s1) @@ -508,6 +737,9 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) { func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { hash := common.BytesToHash(n) + if t.db == nil { + return nil, fmt.Errorf("empty trie database") + } if node := t.db.node(hash); node != nil { return node, nil } @@ -519,7 +751,22 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { func (t *Trie) Hash() common.Hash { hash, cached, _ := t.hashRoot() t.root = cached - return common.BytesToHash(hash.(hashNode)) + newRootHash := common.BytesToHash(hash.(hashNode)) + t.trieRoot = newRootHash + if t.withShadowNodes { + newShadowTreeRoot := emptyRoot + shadowTreeRoot, err := t.ShadowHash() + if err != nil { + panic(fmt.Sprintf("trie hash err, when ShadowHash, err %v", err)) + } + // replace shadowTreeRoot for rootNode + if shadowTreeRoot != nil { + newShadowTreeRoot = *shadowTreeRoot + } + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) + return rn.cachedHash + } + return newRootHash } // Commit writes all nodes to the trie's memory database, tracking the internal @@ -533,7 +780,25 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { } // Derive the hash for all dirty nodes first. We hold the assumption // in the following procedure that all nodes are hashed. 
- rootHash := t.Hash() + hash, cached, _ := t.hashRoot() + t.root = cached + newRootHash := common.BytesToHash(hash.(hashNode)) + newShadowTreeRoot := emptyRoot + if t.withShadowNodes { + shadowTreeRoot, err := t.ShadowHash() + if err != nil { + return common.Hash{}, 0, err + } + // replace shadowTreeRoot for rootNode + if shadowTreeRoot != nil { + newShadowTreeRoot = *shadowTreeRoot + } + // commit shadow nodes after ShadowHash in Commit + if err = t.commitShadowNodes(t.root, nil, t.getRootEpoch()); err != nil { + return common.Hash{}, 0, err + } + } + h := newCommitter() defer returnCommitterToPool(h) @@ -541,7 +806,14 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { // up goroutines. This can happen e.g. if we load a trie for reading storage // values, but don't write to it. if _, dirty := t.root.cache(); !dirty { - return rootHash, 0, nil + if t.withShadowNodes { + rootNodeHash, err := t.storeRootNodeTrieDb(h, newRootHash, newShadowTreeRoot) + if err != nil { + return common.Hash{}, 0, err + } + return rootNodeHash, 0, nil + } + return newRootHash, 0, nil } var wg sync.WaitGroup if onleaf != nil { @@ -565,8 +837,21 @@ func (t *Trie) Commit(onleaf LeafCallback) (common.Hash, int, error) { if err != nil { return common.Hash{}, 0, err } + if t.withShadowNodes { + rootNodeHash, err := t.storeRootNodeTrieDb(h, newRootHash, newShadowTreeRoot) + if err != nil { + return common.Hash{}, 0, err + } + // update root & root node + t.rootEpoch = t.getRootEpoch() + t.shadowTreeRoot = newShadowTreeRoot + t.trieRoot = newRootHash + t.root = newRoot + log.Debug("withShadowNodes trie commit", "rootNodeHash", rootNodeHash, "newRootHash", newRootHash, "newShadowTreeRoot", newShadowTreeRoot) + return rootNodeHash, committed, nil + } t.root = newRoot - return rootHash, committed, nil + return newRootHash, committed, nil } // hashRoot calculates the root hash of the given trie @@ -577,9 +862,9 @@ func (t *Trie) hashRoot() (node, node, error) { // If the 
number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) defer returnHasherToPool(h) - hashed, cached := h.hash(t.root, true) + newRootHash, cached := h.hash(t.root, true) t.unhashed = 0 - return hashed, cached, nil + return newRootHash, cached, nil } // Reset drops the referenced root node and cleans all internal state. @@ -591,3 +876,408 @@ func (t *Trie) Reset() { func (t *Trie) Size() int { return estimateSize(t.root) } + +// ReviveTrie attempts to revive a trie from a list of MPTProofNubs. +// ReviveTrie performs full or partial revive and returns a list of successful +// nubs. ReviveTrie does not guarantee that a value will be revived completely, +// if the proof is not fully valid. +func (t *Trie) ReviveTrie(proof []*MPTProofNub) (successNubs []*MPTProofNub) { + successNubs, err := t.TryRevive(proof) + if err != nil { + log.Error(fmt.Sprintf("Failed to revive trie: %v", err)) + } + return successNubs +} + +func (t *Trie) TryRevive(proof []*MPTProofNub) (successNubs []*MPTProofNub, err error) { + + // Revive trie with each proof nub + for _, nub := range proof { + path := []byte{} + rootExpired := types.EpochExpired(t.getRootEpoch(), t.currentEpoch) + newNode, didRevive, err := t.tryRevive(t.root, nub.n1PrefixKey, *nub, t.currentEpoch, path, rootExpired) + if err != nil { + log.Error("tryRevive err", "prefix", nub.n1PrefixKey, "didRevive", didRevive, "err", err) + } + if didRevive && err == nil { + successNubs = append(successNubs, nub) + t.root = newNode + } + } + + // If no nubs were successful, return error + if len(successNubs) == 0 && len(proof) != 0 { + return successNubs, fmt.Errorf("all nubs failed to revive trie") + } + + return successNubs, nil +} + +func (t *Trie) tryRevive(n node, key []byte, nub MPTProofNub, epoch types.StateEpoch, path []byte, isExpired bool) (node, bool, error) { + + // To revive a node, few conditions must be met: + // 1. 
key length must be 0, indicating that the targeted node is reached + // 2. the node must be expired + // 3. the node must be a hash node + // 4. the node hash must match the hash value of nub + if len(key) == 0 { + if !isExpired { + return nil, false, fmt.Errorf("key %v not found", key) + } + + hn, ok := n.(hashNode) + if !ok { + return nil, false, fmt.Errorf("node is not a hash node") + } + + cachedHash, _ := nub.n1.cache() + if !bytes.Equal(cachedHash, hn) { + return nil, false, fmt.Errorf("hash values does not match") + } + + tryUpdateNodeEpoch(nub.n1, t.currentEpoch) + if nub.n2 != nil { + tryUpdateNodeEpoch(nub.n2, t.currentEpoch) + if n1, ok := nub.n1.(*shortNode); ok { // n2 can only be followed by a short node + n1.Val = nub.n2 + } else { + return nil, false, fmt.Errorf("invalid node type") + } + } + return nub.n1, true, nil + } + + if isExpired { // the node is expired but targeted node is not reached + return nil, false, NewExpiredNodeError(path, 0) // Set default value, will change later + } + + switch n := n.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(key[:len(n.Key)], n.Key) { + return nil, false, fmt.Errorf("key %v not found", key) + } + newNode, didRevive, err := t.tryRevive(n.Val, key[len(n.Key):], nub, epoch, append(path, key[:len(n.Key)]...), isExpired) + if didRevive && err == nil { + n = n.copy() + n.Val = newNode + n.setEpoch(t.currentEpoch) + } + return n, didRevive, err + case *fullNode: + childIndex := int(key[0]) + isExpired, _ := n.ChildExpired(nil, childIndex, t.currentEpoch) // TODO (asyukii): t.currentEpoch or t.root.getEpoch()? 
+ newNode, didRevive, err := t.tryRevive(n.Children[childIndex], key[1:], nub, epoch, append(path, key[0]), isExpired) + if didRevive && err == nil { + n = n.copy() + n.Children[childIndex] = newNode + n.setEpoch(t.currentEpoch) + n.UpdateChildEpoch(childIndex, t.currentEpoch) + } + + if e, ok := err.(*ExpiredNodeError); ok { + e.Epoch = n.GetChildEpoch(childIndex) + return n, didRevive, e + } + + return n, didRevive, err + case hashNode: + tn, err := t.resolveHash(n, path) // TODO(asyukii): may need to copy resolved hash node + if err != nil { + return nil, false, err + } + if err = t.resolveShadowNode(epoch, tn, path); err != nil { + return nil, false, err + } + return t.tryRevive(tn, key, nub, epoch, path, isExpired) + case valueNode: + return nil, false, nil + case nil: + return nil, false, nil + default: + panic(fmt.Sprintf("invalid node: %T", n)) + } +} + +func (t *Trie) resolveShadowNode(epoch types.StateEpoch, origin node, prefix []byte) error { + if !t.withShadowNodes { + return nil + } + + if t.sndb == nil { + return errors.New("cannot resolve shadow node") + } + + switch n := origin.(type) { + case *shortNode: + n.setEpoch(epoch) + n.shadowNode.ShadowHash = nil + return t.resolveShadowNode(epoch, n.Val, safeAppendBytes(prefix, n.Key...)) + case *fullNode: + n.setEpoch(epoch) + val, err := t.sndb.Get(string(hexToSuffixCompact(prefix))) + if err != nil { + return err + } + if len(val) == 0 { + // set default epoch map + n.shadowNode.EpochMap = [16]types.StateEpoch{} + n.shadowNode.ShadowHash = nil + } else { + tmp, decErr := decodeShadowBranchNode(val) + if decErr != nil { + return decErr + } + n.shadowNode = *tmp + } + for i := byte(0); i < BranchNodeLength-1; i++ { + if err := t.resolveShadowNode(n.shadowNode.EpochMap[i], n.Children[i], safeAppendBytes(prefix, i)); err != nil { + return err + } + } + return nil + case valueNode, hashNode, nil: + // just skip + return nil + default: + return errors.New("resolveShadowNode unsupported node type") + } +} + 
+func (t *Trie) ShadowHash() (*common.Hash, error) { + if t.root == nil { + return nil, nil + } + h := newHasher(true) + defer returnHasherToPool(h) + sh, _, err := t.shadowHash(t.root, h, nil, t.getRootEpoch()) + return sh, err +} + +// shadowHash calculate node's shadow node hash, recalculate needn't a copy +// epoch param only use in resolve hashNode +func (t *Trie) shadowHash(origin node, h *hasher, prefix []byte, epoch types.StateEpoch) (*common.Hash, types.StateEpoch, error) { + switch n := origin.(type) { + case *shortNode: + var err error + if n.shadowNode.ShadowHash, _, err = t.shadowHash(n.Val, h, append(prefix, n.Key...), n.epoch); err != nil { + return nil, types.StateEpoch0, err + } + return n.shadowNode.ShadowHash, n.epoch, nil + case *fullNode: + epochSelf := n.epoch + epochMap := n.shadowNode.EpochMap + hashList := make([]*common.Hash, 0, BranchNodeLength-1) + for i := byte(0); i < BranchNodeLength-1; i++ { + child := n.Children[i] + if child == nil { + continue + } + // skip expired node. 
+ if epochSelf >= epochMap[i]+2 { + n.shadowNode.EpochMap[i] = 0 + continue + } + + subHash, subEpoch, err := t.shadowHash(child, h, append(prefix, i), epochMap[i]) + if err != nil { + return nil, types.StateEpoch0, err + } + if subHash != nil { + hashList = append(hashList, subHash) + } + n.shadowNode.EpochMap[i] = subEpoch + } + n.shadowNode.ShadowHash = h.shadowNodeHashListToHash(hashList) + return h.shadowBranchNodeToHash(&n.shadowNode), n.epoch, nil + case valueNode: + return nil, epoch, nil + case hashNode: + if t.db == nil || t.sndb == nil { + return nil, types.StateEpoch0, errors.New("ShadowHash db or sndb is nil") + } + // resolve temporary, not add to trie + rn, err := t.resolveHash(n, prefix) + if err != nil { + return nil, types.StateEpoch0, err + } + if err = t.resolveShadowNode(epoch, rn, prefix); err != nil { + return nil, types.StateEpoch0, err + } + return t.shadowHash(rn, h, prefix, epoch) + default: + return nil, types.StateEpoch0, errors.New("cannot get shortNode's child shadow node") + } +} + +// commitShadowNodes commit node's shadow node, must call after shadowHash +func (t *Trie) commitShadowNodes(origin node, prefix []byte, epoch types.StateEpoch) error { + if t.sndb == nil { + return errors.New("ShadowHash sndb is nil") + } + switch n := origin.(type) { + case *shortNode: + if err := t.commitShadowNodes(n.Val, append(prefix, n.Key...), epoch); err != nil { + return err + } + return nil + case *fullNode: + epochSelf := n.epoch + epochMap := n.shadowNode.EpochMap + for i := byte(0); i < BranchNodeLength-1; i++ { + child := n.Children[i] + if child == nil { + continue + } + // skip expired node. 
+ if epochSelf >= epochMap[i]+2 { + continue + } + + if err := t.commitShadowNodes(child, append(prefix, i), epochMap[i]); err != nil { + return err + } + } + + encBuf := rlp.NewEncoderBuffer(nil) + n.shadowNode.encode(encBuf) + if err := t.sndb.Put(string(hexToSuffixCompact(prefix)), encBuf.ToBytes()); err != nil { + return err + } + return nil + case valueNode, hashNode: + return nil + default: + return errors.New("cannot get shortNode's child shadow node") + } +} + +func (t *Trie) storeRootNodeTrieDb(c *committer, newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) + hn, _, err := c.Commit(rn, t.db) + return common.BytesToHash(hn), err +} + +func (t *Trie) storeRootNode(newRootHash, newShadowTreeRoot common.Hash) (common.Hash, error) { + rn := newRootNode(t.getRootEpoch(), newRootHash, newShadowTreeRoot) + if err := t.sndb.Put(ShadowTreeRootNodePath, rn.cachedEnc); err != nil { + return common.Hash{}, err + } + return rn.cachedHash, nil +} + +// getRootEpoch parse root and resolve its epoch, when root is hashNode, +// the epoch should be t.rootEpoch, not default types.StateEpoch0 +func (t *Trie) getRootEpoch() types.StateEpoch { + ret := t.rootEpoch + switch n := t.root.(type) { + case *shortNode: + ret = n.getEpoch() + case *fullNode: + ret = n.getEpoch() + } + + return ret +} + +func (t *Trie) epochExpired(n node, epoch types.StateEpoch) bool { + // when node is nil, skip epoch check + if !t.withShadowNodes || n == nil { + return false + } + return types.EpochExpired(epoch, t.currentEpoch) +} + +// renewNode check if renew node, according to trie node epoch and childDirty, +// childDirty or updateEpoch need copy for prevent reuse trie cache +func (t *Trie) renewNode(epoch types.StateEpoch, childDirty bool, updateEpoch bool) bool { + // when !updateEpoch, it same as !t.withShadowNodes + if !t.withShadowNodes || !updateEpoch { + return childDirty + } + + // when no epoch update, 
same as before + if epoch == t.currentEpoch { + return childDirty + } + + // node need update epoch, just renew + return true +} + +func resolveRootNodeTrieDb(db *Database, root common.Hash) (*rootNode, error) { + expectHash := common.BytesToHash(root[:]) + rn := db.node(root) + n, ok := rn.(*rootNode) + if !ok { + return newEpoch0RootNode(root), nil + } + + if n.cachedHash != expectHash { + return nil, errors.New("found the wrong rootNode") + } + return n, nil +} + +func resolveRootNode(sndb ShadowNodeStorage, root common.Hash) (*rootNode, error) { + expectHash := common.BytesToHash(root[:]) + val, err := sndb.Get(ShadowTreeRootNodePath) + if err != nil { + return nil, err + } + if len(val) == 0 { + return newEpoch0RootNode(root), nil + } + n, err := DecodeRootNode(val) + if err != nil { + return nil, err + } + + if n.cachedHash != expectHash { + return nil, errors.New("found the wrong rootNode") + } + return n, nil +} + +// SUFFIX-COMPACT encoding is used for encoding trie node path in the trie node +// storage key. The main difference with COMPACT encoding is that the key flag +// is put at the end of the key. +// +// e.g. +// - the key [] is encoded as [0x00] +// - the key [0x1, 0x2, 0x3] is encoded as [0x12, 0x31] +// - the key [0x1, 0x2, 0x3, 0x0] is encoded as [0x12, 0x30, 0x00] +// +// The main benefit of this format is the continuous paths can retain the shared +// path prefix after encoding. 
+func hexToSuffixCompact(hex []byte) []byte { + terminator := byte(0) + if hasTerm(hex) { + terminator = 1 + hex = hex[:len(hex)-1] + } + buf := make([]byte, len(hex)/2+1) + buf[len(buf)-1] = terminator << 1 // the flag byte + if len(hex)&1 == 1 { + buf[len(buf)-1] |= 1 // odd flag + buf[len(buf)-1] |= hex[len(hex)-1] << 4 // last nibble is contained in the last byte + hex = hex[:len(hex)-1] + } + decodeNibbles(hex, buf[:len(buf)-1]) + return buf +} + +func tryUpdateNodeEpoch(origin node, epoch types.StateEpoch) { + switch n := origin.(type) { + case *shortNode: + n.setEpoch(epoch) + case *fullNode: + n.setEpoch(epoch) + } +} + +func safeAppendBytes(src []byte, added ...byte) []byte { + ret := make([]byte, len(src)+len(added)) + copy(ret, src) + copy(ret[len(src):], added) + return ret +} diff --git a/trie/trie_test.go b/trie/trie_test.go index 63aed333db..656cae81a9 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -19,6 +19,8 @@ package trie import ( "bytes" "encoding/binary" + + // "encoding/hex" "errors" "fmt" "hash" @@ -38,6 +40,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb/leveldb" "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/rlp" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/sha3" ) @@ -473,6 +476,560 @@ func TestRandom(t *testing.T) { } } +// TestExpireByPrefix tests that the trie is not corrupted after +// expiring a key by prefix. 
+func TestExpireByPrefix(t *testing.T) { + data := map[string]string{ + "abcd": "A", + "abce": "B", + "abde": "C", + "abdf": "D", + "defg": "E", + "defh": "F", + "degh": "G", + "degi": "H", + } + + trie := createCustomTrie(data, 0) + rootHash := trie.Hash() + + for k := range data { + prefixKeys := getPrefixKeysHex(trie, []byte(k)) + for _, prefixKey := range prefixKeys { + trie.ExpireByPrefix(prefixKey) + currHash := trie.Hash() + // Validate root hash + assert.Equal(t, rootHash, currHash, "Root hash mismatch, got %x, expected %x", currHash, rootHash) + + // Reset trie + trie = createCustomTrie(data, 0) + } + } +} + +func createCustomTrie(data map[string]string, epoch types.StateEpoch) *Trie { + trie := new(Trie) + trie.currentEpoch = epoch + for k, v := range data { + trie.Update([]byte(k), []byte(v)) + } + + return trie +} + +type proofList [][]byte + +func (n *proofList) Put(key []byte, value []byte) error { + *n = append(*n, value) + return nil +} + +func (n *proofList) Delete(key []byte) error { + panic("not supported") +} + +func makeRawMPTProofCache(rootKeyHex []byte, proof [][]byte) MPTProofCache { + return MPTProofCache{ + MPTProof: types.MPTProof{ + RootKeyHex: rootKeyHex, + Proof: proof, + }, + } +} + +func getFullNodePrefixKeys(t *Trie, key []byte) [][]byte { + var prefixKeys [][]byte + key = keybytesToHex(key) + tn := t.root + currPath := []byte{} + for len(key) > 0 && tn != nil { + switch n := tn.(type) { + case *shortNode: + if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + // The trie doesn't contain the key. + tn = nil + } else { + tn = n.Val + prefixKeys = append(prefixKeys, currPath) + currPath = append(currPath, n.Key...) 
+ key = key[len(n.Key):] + } + case *fullNode: + tn = n.Children[key[0]] + currPath = append(currPath, key[0]) + key = key[1:] + case hashNode: + var err error + tn, err = t.resolveHash(n, nil) + if err != nil { + return nil + } + default: + return nil + } + } + + // Remove the first item in prefixKeys, which is the empty key + if len(prefixKeys) > 0 { + prefixKeys = prefixKeys[1:] + } + + return prefixKeys +} + +// TestTryRevive tests that a trie can be revived from a proof +func TestTryRevive(t *testing.T) { + + trie, vals := nonRandomTrieWithShadowNodes(500) + + oriRootHash := trie.Hash() + + for _, kv := range vals { + key := kv.k + val := kv.v + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err, "TryRevive failed, key %x, prefixKey %x, val %x", key, prefixKey, val) + + // Verify value exists after revive + v := trie.Get(key) + assert.Equal(t, val, v, "value mismatch, got %x, expected %x. 
key %x, prefixKey %x", v, val, key, prefixKey) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x", currRootHash, oriRootHash) + + // Reset trie + trie, _ = nonRandomTrieWithShadowNodes(500) + } + } +} + +// TestTryReviveCustomData tests that a trie can be revived from a proof +func TestTryReviveCustomData(t *testing.T) { + + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data, 10) + + oriRootHash := trie.Hash() + + for k, v := range data { + key := []byte(k) + val := []byte(v) + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + // Generate proof + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + trie.ExpireByPrefix(prefixKey) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err, "verify proof failed, key %x, prefixKey %x", key, prefixKey) + + // Revive trie + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err, "try revive failed, key %x, prefixKey %x", key, prefixKey) + + // Verify value exists after revive + v := trie.Get(key) + assert.Equal(t, val, v) + + // Verify root hash + currRootHash := trie.Hash() + assert.Equal(t, oriRootHash, currRootHash, "Root hash mismatch, got %x, expected %x, key %x, prefixKey %x", currRootHash, oriRootHash, key, prefixKey) + + // Reset trie + trie = createCustomTrie(data, 10) + } + } +} + +// TestReviveBadProof tests that a trie cannot be revived from a bad proof +func TestReviveBadProof(t *testing.T) { + + dataA := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + "defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + dataB := map[string]string{ + "qwer": 
"A", "qwet": "B", "qwrt": "C", "qwry": "D", + "abcd": "E", "abce": "F", "abde": "G", "abdf": "H", + } + + trieA := createCustomTrie(dataA, 0) + trieB := createCustomTrie(dataB, 0) + + var proofB proofList + + err := trieB.ProveStorageWitness([]byte("abcd"), nil, &proofB) + assert.NoError(t, err) + + // Expire trie A + trieA.ExpireByPrefix(nil) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(nil, proofB) + + // VerifyProof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trieA.TryRevive(proofCache.cacheNubs) + assert.Error(t, err) + + // Verify value does exists after revive + _, err = trieA.TryGet([]byte("abcd")) + assert.Error(t, err) + +} + +// TestReviveBadProofAfterUpdate tests that after reviving a path and +// then update the value, old proof should be invalid +func TestReviveBadProofAfterUpdate(t *testing.T) { + trie, vals := nonRandomTrieWithShadowNodes(500) + for _, kv := range vals { + key := kv.k + prefixKeys := getFullNodePrefixKeys(trie, key) + for _, prefixKey := range prefixKeys { + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix(prefixKey) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, proof) + + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive first + trie.TryRevive(proofCache.cacheNubs) + + // Update value + trie.Update(key, []byte("new value")) + + // Revive again with old proof + trie.TryRevive(proofCache.cacheNubs) + + // Validate trie + resVal, err := trie.TryGet(key) + assert.NoError(t, err) + assert.Equal(t, []byte("new value"), resVal) + } + } +} + +// TestPartialReviveFullProof tests that a path can be revived +// with full proof even when the trie is partially expired +func TestPartialReviveFullProof(t *testing.T) { + data := map[string]string{ + "abcd": "A", "abce": "B", "abde": "C", "abdf": "D", + 
"defg": "E", "defh": "F", "degh": "G", "degi": "H", + } + + trie := createCustomTrie(data, 10) + + // Get proof + var proof proofList + err := trie.ProveStorageWitness([]byte("abcd"), nil, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix([]byte{6, 1}) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(nil, proof) + + // Verify proof + err = proofCache.VerifyProof() + assert.NoError(t, err) + + // Revive trie + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err) + + // Validate trie + resVal, err := trie.TryGet([]byte("abcd")) + assert.NoError(t, err) + assert.Equal(t, []byte("A"), resVal) +} + +// TestReviveValueAtFullNode tests that a value that is at +// full node can be revived properly +func TestReviveValueAtFullNode(t *testing.T) { + hexKeys := [][]byte{ + {6, 1, 6, 2, 6, 3, 6, 4, 16}, + {6, 1, 6, 2, 6, 3, 6, 5, 16}, + {6, 1, 6, 2, 6, 4, 6, 5, 16}, + {6, 1, 6, 2, 6, 4, 6, 6, 16}, + {6, 4, 6, 5, 6, 6, 6, 7, 16}, + {6, 4, 6, 5, 6, 6, 6, 8, 16}, + {6, 4, 6, 5, 6, 7, 6, 8, 16}, + {6, 4, 6, 5, 6, 7, 6, 9, 16}, + {6, 1, 6, 2, 6, 3, 6, 4, 16}, + {6, 1, 6, 2, 16}, // This is the key that has a value at a full node + } + + byteKeys := make([][]byte, len(hexKeys)) + + vals := []string{ + "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", + } + + for i, hexKey := range hexKeys { + hexKey = hexToKeybytes(hexKey) + byteKeys[i] = hexKey + } + + // Insert keys into trie + trie := new(Trie) + trie.currentEpoch = 10 + for i, hexKey := range byteKeys { + trie.Update(hexKey, []byte(vals[i])) + } + + key := byteKeys[9] + val := vals[9] + + prefixKeys := getFullNodePrefixKeys(trie, key) + + for _, prefixKey := range prefixKeys { + var proof proofList + err := trie.ProveStorageWitness(key, prefixKey, &proof) + assert.NoError(t, err) + + // Expire trie + err = trie.ExpireByPrefix(prefixKey) + assert.NoError(t, err) + + // Construct MPTProofCache + proofCache := makeRawMPTProofCache(prefixKey, 
proof) + + err = proofCache.VerifyProof() + assert.NoError(t, err) + + _, err = trie.TryRevive(proofCache.cacheNubs) + assert.NoError(t, err) + + // Validate trie + resVal, err := trie.TryGet(key) + assert.NoError(t, err) + assert.Equal(t, []byte(val), resVal) + } +} + +// TODO add testing trie epoch update & expired +// case1: when meet err, update do not update epoch +// case2: when child is nil, do not check its epoch, default is 0 +// case3: when access/update/delete, update epoch correct, when node is not expired, then update it later +// case3: when access/update/delete, check expired node correct +// case4: root node epoch update and check correct +// case5: when node is expired, do not resolve it, safe for revive + +func TestTrie_ShadowNodeRW_expired(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) + assert.NoError(t, err) + + tr, err := NewSecureWithShadowNodes(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) + tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) + val, err := tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + assert.NoError(t, tr.TryDelete(makeHash("k1").Bytes())) + + // commit + nextRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big1, nextRoot)) + + // reload in epoch2 + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, nextRoot) + assert.NoError(t, err) + tr, err = NewSecureWithShadowNodes(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + val, err = tr.TryGet(makeHash("k2").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v2").Bytes(), val) + + // reload in epoch3, check expired + tr, err = NewSecureWithShadowNodes(types.StateEpoch(3), nextRoot, database, 
storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + _, err = tr.TryGet(makeHash("k2").Bytes()) + assert.Error(t, err) +} + +func TestTrie_ShadowNodeRW_accessStates(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, blockRoot0) + assert.NoError(t, err) + + tr, err := NewSecureWithShadowNodes(types.StateEpoch(1), emptyRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + tr.Update(makeHash("k1").Bytes(), makeHash("v1").Bytes()) + tr.Update(makeHash("k2").Bytes(), makeHash("v2").Bytes()) + val, err := tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + + // commit + nextRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big1, nextRoot)) + + // reload in epoch2, access K1, k2 + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, nextRoot) + assert.NoError(t, err) + tr, err = NewSecureWithShadowNodes(types.StateEpoch(2), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + err = tr.TryUpdateEpoch(makeHash("k1").Bytes()) + assert.NoError(t, err) + err = tr.TryUpdateEpoch(makeHash("k2").Bytes()) + assert.NoError(t, err) + + // commit + nextRoot, _, err = tr.Commit(nil) + assert.NoError(t, err) + assert.NoError(t, storageDB.Commit(common.Big2, nextRoot)) + + // reload in epoch3 + storageDB, err = NewShadowNodeDatabase(tree, common.Big2, nextRoot) + assert.NoError(t, err) + tr, err = NewSecureWithShadowNodes(types.StateEpoch(3), nextRoot, database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + val, err = tr.TryGet(makeHash("k1").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v1").Bytes(), val) + val, err = tr.TryGet(makeHash("k2").Bytes()) + assert.NoError(t, err) + assert.Equal(t, makeHash("v2").Bytes(), val) +} + +func TestTrie_ShadowHash(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + 
storageDB, err := NewShadowNodeDatabase(tree, common.Big0, emptyRoot) + assert.NoError(t, err) + + tr, err := NewWithShadowNode(types.StateEpoch0, newEpoch0RootNode(emptyRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + + batchUpdateTrie(t, tr, []string{"a711355", "450", "a77d337", "100", "a7f9365", "110", "a77d397", "012"}) + sh1, err := tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0xc752578873185d8b97bdf9e59c8178719e30a03515c7a791e779d4823bbb3fa4"), *sh1) + + // commit and shadow hash again + newRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + err = storageDB.Commit(common.Big1, newRoot) + assert.NoError(t, err) + + storageDB, err = NewShadowNodeDatabase(tree, common.Big1, newRoot) + assert.NoError(t, err) + tr, err = NewWithShadowNode(types.StateEpoch(1), newEpoch0RootNode(newRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + sh1, err = tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0xc752578873185d8b97bdf9e59c8178719e30a03515c7a791e779d4823bbb3fa4"), *sh1) + + err = tr.TryUpdate(common.Hex2Bytes("a711355"), common.Hex2Bytes("800")) + assert.NoError(t, err) + sh1, err = tr.ShadowHash() + assert.NoError(t, err) + assert.Equal(t, common.HexToHash("0xa88d96fa4e1b7b4421198f965230b85e153a6453d1f43b97c0ad89feafa73dd6"), *sh1) +} + +func TestTrie_ShadowHash_case2(t *testing.T) { + database, tree := makeStorageTrieDatabase(t) + storageDB, err := NewShadowNodeDatabase(tree, common.Big0, emptyRoot) + assert.NoError(t, err) + + tr, err := NewWithShadowNode(10, newRootNode(10, emptyRoot, emptyRoot), database, storageDB.OpenStorage(contract1)) + assert.NoError(t, err) + + batchUpdateTrie(t, tr, []string{"223dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "450", "224dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "100", "233dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "110", 
"253dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288", "012"}) + + // commit and shadow hash again + newRoot, _, err := tr.Commit(nil) + assert.NoError(t, err) + err = storageDB.Commit(common.Big1, newRoot) + assert.NoError(t, err) + + storageDB, err = NewShadowNodeDatabase(tree, common.Big2, newRoot) + assert.NoError(t, err) + + sndb := storageDB.OpenStorage(contract1) + rn := tr.db.node(newRoot) + // enc, err := sndb.Get(ShadowTreeRootNodePath) + // assert.NoError(t, err) + // r1, err := decodeRootNode(enc) + // assert.NoError(t, err) + tr, err = NewWithShadowNode(11, rn.(*rootNode), database, sndb) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("223dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("224dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("233dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) + _, err = tr.TryGet(common.Hex2Bytes("253dffac48c9ce11eb8dd110a36c55aa7f51fd1ab98b4c9b8ebe4decfd72f2288")) + assert.NoError(t, err) +} + +func batchUpdateTrie(t *testing.T, tr *Trie, kvs []string) { + if len(kvs)%2 != 0 { + panic("wrong kvs") + } + for i := 0; i < len(kvs); i += 2 { + err := tr.TryUpdate(common.Hex2Bytes(kvs[i]), common.Hex2Bytes(kvs[i+1])) + assert.NoError(t, err) + } +} + +func makeStorageTrieDatabase(t *testing.T) (*Database, *ShadowNodeSnapTree) { + diskdb := memorydb.New() + database := NewDatabase(diskdb) + tree, err := NewShadowNodeSnapTree(diskdb, true) + assert.NoError(t, err) + return database, tree +} + func BenchmarkGet(b *testing.B) { benchGet(b, false) } func BenchmarkGetDB(b *testing.B) { benchGet(b, true) } func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) } @@ -893,6 +1450,14 @@ func TestCommitSequenceSmallRoot(t *testing.T) { } } +func TestSafeAppendBytes(t *testing.T) { + 
assert.Equal(t, safeAppendBytes(nil, []byte{1}...), append([]byte(nil), 1)) + assert.Equal(t, safeAppendBytes(nil, []byte{1, 2, 3, 5}...), append([]byte(nil), []byte{1, 2, 3, 5}...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}, []byte{5, 4, 3, 2, 1}...), append([]byte{1, 2, 3, 5}, []byte{5, 4, 3, 2, 1}...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}, []byte(nil)...), append([]byte{1, 2, 3, 5}, []byte(nil)...)) + assert.Equal(t, safeAppendBytes([]byte{1, 2, 3, 5}), append([]byte{1, 2, 3, 5})) +} + // BenchmarkCommitAfterHashFixedSize benchmarks the Commit (after Hash) of a fixed number of updates to a trie. // This benchmark is meant to capture the difference on efficiency of small versus large changes. Typically, // storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple @@ -1048,7 +1613,7 @@ func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [] h := trie.Hash() trie.Commit(nil) b.StartTimer() - trie.db.Dereference(h) + trie.db.Dereference(h, 0) // TODO(asyukii): set epoch to 0 temporary, might need to fix b.StopTimer() }