diff --git a/trie/committer.go b/trie/committer.go
index cf43e12fe..584288e62 100644
--- a/trie/committer.go
+++ b/trie/committer.go
@@ -60,9 +60,9 @@ var committerPool = sync.Pool{
}
// newCommitter creates a new committer or picks one from the pool.
-func newCommitter(owner common.Hash, tracer *tracer, collectLeaf bool) *committer {
+func newCommitter(nodes *NodeSet, tracer *tracer, collectLeaf bool) *committer {
return &committer{
- nodes: NewNodeSet(owner),
+ nodes: nodes,
tracer: tracer,
collectLeaf: collectLeaf,
}
@@ -74,20 +74,6 @@ func (c *committer) Commit(n node) (hashNode, *NodeSet, error) {
if err != nil {
return nil, nil, err
}
- // Some nodes can be deleted from trie which can't be captured by committer
- // itself. Iterate all deleted nodes tracked by tracer and marked them as
- // deleted only if they are present in database previously.
- for _, path := range c.tracer.deleteList() {
- // There are a few possibilities for this scenario(the node is deleted
- // but not present in database previously), for example the node was
- // embedded in the parent and now deleted from the trie. In this case
- // it's noop from database's perspective.
- val := c.tracer.getPrev(path)
- if len(val) == 0 {
- continue
- }
- c.nodes.markDeleted(path, val)
- }
return h.(hashNode), c.nodes, nil
}
@@ -119,12 +105,6 @@ func (c *committer) commit(path []byte, n node) (node, error) {
if hn, ok := hashedNode.(hashNode); ok {
return hn, nil
}
- // The short node now is embedded in its parent. Mark the node as
- // deleted if it's present in database previously. It's equivalent
- // as deletion from database's perspective.
- if prev := c.tracer.getPrev(path); len(prev) != 0 {
- c.nodes.markDeleted(path, prev)
- }
return collapsed, nil
case *fullNode:
hashedKids, err := c.commitChildren(path, cn)
@@ -138,12 +118,6 @@ func (c *committer) commit(path []byte, n node) (node, error) {
if hn, ok := hashedNode.(hashNode); ok {
return hn, nil
}
- // The short node now is embedded in its parent. Mark the node as
- // deleted if it's present in database previously. It's equivalent
- // as deletion from database's perspective.
- if prev := c.tracer.getPrev(path); len(prev) != 0 {
- c.nodes.markDeleted(path, prev)
- }
return collapsed, nil
case hashNode:
return cn, nil
@@ -196,6 +170,13 @@ func (c *committer) store(path []byte, n node) node {
// usually is leaf node). But small value(less than 32bytes) is not
// our target(leaves in account trie only).
if hash == nil {
+ // The node is embedded in its parent, in other words, this node
+ // will not be stored in the database independently, mark it as
+ // deleted only if the node was existent in database before.
+ prev, ok := c.tracer.accessList[string(path)]
+ if ok {
+ c.nodes.addNode(path, &nodeWithPrev{&memoryNode{}, prev})
+ }
return n
}
// We have the hash already, estimate the RLP encoding-size of the node.
@@ -203,15 +184,18 @@ func (c *committer) store(path []byte, n node) node {
var (
size = estimateSize(n)
nhash = common.BytesToHash(hash)
- mnode = &memoryNode{
- hash: nhash,
- node: simplifyNode(n),
- size: uint16(size),
+ node = &nodeWithPrev{
+ &memoryNode{
+ nhash,
+ uint16(size),
+ simplifyNode(n),
+ },
+ c.tracer.accessList[string(path)],
}
)
// Collect the dirty node to nodeset for return.
- c.nodes.markUpdated(path, mnode, c.tracer.getPrev(path))
+ c.nodes.addNode(path, node)
// Collect the corresponding leaf node if it's required. We don't check
// full node since it's impossible to store value in fullNode. The key
// length of leaves should be exactly same.
diff --git a/trie/database.go b/trie/database.go
index e0225eb18..6b1bc7a62 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -782,17 +782,31 @@ func (db *Database) Update(nodes *MergedNodeSet) error {
defer db.lock.Unlock()
// Insert dirty nodes into the database. In the same tree, it must be
// ensured that children are inserted first, then parent so that children
- // can be linked with their parent correctly. The order of writing between
- // different tries(account trie, storage tries) is not required.
- for owner, subset := range nodes.sets {
- for _, path := range subset.updates.order {
- n, ok := subset.updates.nodes[path]
- if !ok {
- return fmt.Errorf("missing node %x %v", owner, path)
+ // can be linked with their parent correctly.
+ //
+ // Note, the storage tries must be flushed before the account trie to
+ // retain the invariant that children go into the dirty cache first.
+ var order []common.Hash
+ for owner := range nodes.sets {
+ if owner == (common.Hash{}) {
+ continue
+ }
+ order = append(order, owner)
+ }
+ if _, ok := nodes.sets[common.Hash{}]; ok {
+ order = append(order, common.Hash{})
+ }
+ for _, owner := range order {
+ subset := nodes.sets[owner]
+ subset.forEachWithOrder(func(path string, n *memoryNode) {
+ if n.isDeleted() {
+ return // ignore deletion
}
db.insert(n.hash, int(n.size), n.node)
- }
+ })
}
+ // Link up the account trie and storage trie if the node points
+ // to an account trie leaf.
if set, present := nodes.sets[common.Hash{}]; present {
for _, leaf := range set.leaves {
// Looping node leaf, then reference the leaf node to the root node
diff --git a/trie/nodeset.go b/trie/nodeset.go
index a94535069..6b99dbebc 100644
--- a/trie/nodeset.go
+++ b/trie/nodeset.go
@@ -19,6 +19,7 @@ package trie
import (
"fmt"
"reflect"
+ "sort"
"strings"
"github.com/ethereum/go-ethereum/common"
@@ -42,8 +43,13 @@ var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
// memorySize returns the total memory size used by this node.
// nolint:unused
-func (n *memoryNode) memorySize(key int) int {
- return int(n.size) + memoryNodeSize + key
+func (n *memoryNode) memorySize(pathlen int) int {
+ return int(n.size) + memoryNodeSize + pathlen
+}
+
+// isDeleted returns the indicator if the node is marked as deleted.
+func (n *memoryNode) isDeleted() bool {
+ return n.hash == (common.Hash{})
}
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
@@ -89,21 +95,19 @@ func (n *nodeWithPrev) memorySize(key int) int {
return n.memoryNode.memorySize(key) + len(n.prev)
}
-// nodesWithOrder represents a collection of dirty nodes which includes
-// newly-inserted and updated nodes. The modification order of all nodes
-// is represented by order list.
-type nodesWithOrder struct {
- order []string // the path list of dirty nodes, sort by insertion order
- nodes map[string]*nodeWithPrev // the map of dirty nodes, keyed by node path
-}
-
// NodeSet contains all dirty nodes collected during the commit operation
// Each node is keyed by path. It's not the thread-safe to use.
type NodeSet struct {
- owner common.Hash // the identifier of the trie
- updates *nodesWithOrder // the set of updated nodes(newly inserted, updated)
- deletes map[string][]byte // the map of deleted nodes, keyed by node
- leaves []*leaf // the list of dirty leaves
+ owner common.Hash // the identifier of the trie
+ leaves []*leaf // the list of dirty leaves
+ updates int // the count of updated and inserted nodes
+ deletes int // the count of deleted nodes
+
+ // The set of all dirty nodes. Dirty nodes include newly inserted nodes,
+ // deleted nodes and updated nodes. The original value of the newly
+ // inserted node must be nil, and the original value of the other two
+ // types must be non-nil.
+ nodes map[string]*nodeWithPrev
}
// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
@@ -112,35 +116,32 @@ type NodeSet struct {
func NewNodeSet(owner common.Hash) *NodeSet {
return &NodeSet{
owner: owner,
- updates: &nodesWithOrder{
- nodes: make(map[string]*nodeWithPrev),
- },
- deletes: make(map[string][]byte),
+ nodes: make(map[string]*nodeWithPrev),
}
}
-// NewNodeSetWithDeletion initializes the nodeset with provided deletion set.
-func NewNodeSetWithDeletion(owner common.Hash, paths [][]byte, prev [][]byte) *NodeSet {
- set := NewNodeSet(owner)
- for i, path := range paths {
- set.markDeleted(path, prev[i])
+// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
+// right to left, nodes with the longest path will be iterated first.
+func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
+ var paths sort.StringSlice
+ for path := range set.nodes {
+ paths = append(paths, path)
}
- return set
-}
-
-// markUpdated marks the node as dirty(newly-inserted or updated) with provided
-// node path, node object along with its previous value.
-func (set *NodeSet) markUpdated(path []byte, node *memoryNode, prev []byte) {
- set.updates.order = append(set.updates.order, string(path))
- set.updates.nodes[string(path)] = &nodeWithPrev{
- memoryNode: node,
- prev: prev,
+ // Bottom-up, longest path first
+ sort.Sort(sort.Reverse(paths))
+ for _, path := range paths {
+ callback(path, set.nodes[path].unwrap())
}
}
-// markDeleted marks the node as deleted with provided path and previous value.
-func (set *NodeSet) markDeleted(path []byte, prev []byte) {
- set.deletes[string(path)] = prev
+// addNode adds the provided dirty node into set.
+func (set *NodeSet) addNode(path []byte, n *nodeWithPrev) {
+ if n.isDeleted() {
+ set.deletes += 1
+ } else {
+ set.updates += 1
+ }
+ set.nodes[string(path)] = n
}
// addLeaf collects the provided leaf node into set.
@@ -150,13 +151,13 @@ func (set *NodeSet) addLeaf(leaf *leaf) {
// Size returns the number of updated and deleted nodes contained in the set.
func (set *NodeSet) Size() (int, int) {
- return len(set.updates.order), len(set.deletes)
+ return set.updates, set.deletes
}
// Hashes returns the hashes of all updated nodes.
func (set *NodeSet) Hashes() []common.Hash {
var ret []common.Hash
- for _, node := range set.updates.nodes {
+ for _, node := range set.nodes {
ret = append(ret, node.hash)
}
return ret
@@ -166,19 +167,22 @@ func (set *NodeSet) Hashes() []common.Hash {
func (set *NodeSet) Summary() string {
var out = new(strings.Builder)
fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
- if set.updates != nil {
- for _, key := range set.updates.order {
- updated := set.updates.nodes[key]
- if updated.prev != nil {
- fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", key, updated.hash, updated.prev)
- } else {
- fmt.Fprintf(out, " [+]: %x -> %v\n", key, updated.hash)
+ if set.nodes != nil {
+ for path, n := range set.nodes {
+ // Deletion
+ if n.isDeleted() {
+ fmt.Fprintf(out, " [-]: %x prev: %x\n", path, n.prev)
+ continue
}
+ // Insertion
+ if len(n.prev) == 0 {
+ fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
+ continue
+ }
+ // Update
+ fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, n.prev)
}
}
- for k, n := range set.deletes {
- fmt.Fprintf(out, " [-]: %x -> %x\n", k, n)
- }
for _, n := range set.leaves {
fmt.Fprintf(out, "[leaf]: %v\n", n)
}
diff --git a/trie/proof.go b/trie/proof.go
index c58997197..29c6aa2c5 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -562,7 +562,7 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
}
// Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one.
- tr := &Trie{root: root, reader: newEmptyReader()}
+ tr := &Trie{root: root, reader: newEmptyReader(), tracer: newTracer()}
if empty {
tr.root = nil
}
diff --git a/trie/utils.go b/trie/tracer.go
similarity index 53%
rename from trie/utils.go
rename to trie/tracer.go
index d1cd3bdd2..cd5ebb85a 100644
--- a/trie/utils.go
+++ b/trie/tracer.go
@@ -36,147 +36,89 @@ package trie
// Note tracer is not thread-safe, callers should be responsible for handling
// the concurrency issues by themselves.
type tracer struct {
- insert map[string]struct{}
- delete map[string]struct{}
- origin map[string][]byte
+ inserts map[string]struct{}
+ deletes map[string]struct{}
+ accessList map[string][]byte
}
// newTracer initlializes tride node diff tracer.
func newTracer() *tracer {
return &tracer{
- insert: make(map[string]struct{}),
- delete: make(map[string]struct{}),
- origin: make(map[string][]byte),
+ inserts: make(map[string]struct{}),
+ deletes: make(map[string]struct{}),
+ accessList: make(map[string][]byte),
}
}
// onRead tracks the newly loaded trie node and caches the rlp-encoded blob internally.
// Don't change the value outside of function since it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return
- }
- t.origin[string(path)] = val
+ t.accessList[string(path)] = val
}
// onInsert tracks the newly inserted trie node. If it's already
// in the delete set(resurrected node), then just wipe it from
// the deletion set as it's untouched.
func (t *tracer) onInsert(path []byte) {
- // Tracer isn't used right now, remove this check latter.
- if t == nil {
- return
- }
// If the path is in the delete set, then it's a resurrected node, then wipe it.
- if _, present := t.delete[string(path)]; present {
- delete(t.delete, string(path))
+ if _, present := t.deletes[string(path)]; present {
+ delete(t.deletes, string(path))
return
}
- t.insert[string(path)] = struct{}{}
+ t.inserts[string(path)] = struct{}{}
}
// OnDelete tracks the newly deleted trie node. If it's already
// in the addition set, then just wipe it from the addtion set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
- // Tracer isn't used right now, remove this check latter.
- if t == nil {
+ if _, present := t.inserts[string(path)]; present {
+ delete(t.inserts, string(path))
return
}
- if _, present := t.insert[string(path)]; present {
- delete(t.insert, string(path))
- return
- }
- t.delete[string(path)] = struct{}{}
-}
-
-// insertList returns the tracked inserted trie nodes in list format.
-func (t *tracer) insertList() [][]byte {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return nil
- }
- var ret [][]byte
- for path := range t.insert {
- ret = append(ret, []byte(path))
- }
- return ret
-}
-
-// deleteList returns the tracked deleted trie nodes in list format.
-func (t *tracer) deleteList() [][]byte {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return nil
- }
- var ret [][]byte
- for path := range t.delete {
- ret = append(ret, []byte(path))
- }
- return ret
-}
-
-// prevList returns the tracked node blobs in list format.
-func (t *tracer) prevList() ([][]byte, [][]byte) {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return nil, nil
- }
- var (
- paths [][]byte
- blobs [][]byte
- )
- for path, blob := range t.origin {
- paths = append(paths, []byte(path))
- blobs = append(blobs, blob)
- }
- return paths, blobs
-}
-
-// getPrev returns the cached original value of the specified node.
-func (t *tracer) getPrev(path []byte) []byte {
- // Don't panic on uninitialized tracer, it's possible in testing.
- if t == nil {
- return nil
- }
- return t.origin[string(path)]
+ t.deletes[string(path)] = struct{}{}
}
// reset clears the content tracked by tracer.
func (t *tracer) reset() {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return
- }
- t.insert = make(map[string]struct{})
- t.delete = make(map[string]struct{})
- t.origin = make(map[string][]byte)
+ t.inserts = make(map[string]struct{})
+ t.deletes = make(map[string]struct{})
+ t.accessList = make(map[string][]byte)
}
// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
- // Tracer isn't used right now, remove this check later.
- if t == nil {
- return nil
- }
var (
- insert = make(map[string]struct{})
- delete = make(map[string]struct{})
- origin = make(map[string][]byte)
+ inserts = make(map[string]struct{})
+ deletes = make(map[string]struct{})
+ accessList = make(map[string][]byte)
)
- for key := range t.insert {
- insert[key] = struct{}{}
+ for key := range t.inserts {
+ inserts[key] = struct{}{}
}
- for key := range t.delete {
- delete[key] = struct{}{}
+ for key := range t.deletes {
+ deletes[key] = struct{}{}
}
- for key, val := range t.origin {
- origin[key] = val
+ for key, val := range t.accessList {
+ accessList[key] = val
}
return &tracer{
- insert: insert,
- delete: delete,
- origin: origin,
+ inserts: inserts,
+ deletes: deletes,
+ accessList: accessList,
+ }
+}
+
+// markDeletions puts all tracked deletions into the provided nodeset.
+func (t *tracer) markDeletions(set *NodeSet) {
+ for path := range t.deletes {
+ // It's possible a few deleted nodes were embedded
+ // in their parent before, the deletions can be no
+ // effect by deleting nothing, filter them out.
+ prev, ok := t.accessList[path]
+ if !ok {
+ continue
+ }
+ set.addNode([]byte(path), &nodeWithPrev{&memoryNode{}, prev})
}
}
diff --git a/trie/tracer_test.go b/trie/tracer_test.go
new file mode 100644
index 000000000..f8511a5e6
--- /dev/null
+++ b/trie/tracer_test.go
@@ -0,0 +1,371 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+)
+
+var (
+ tiny = []struct{ k, v string }{
+ {"k1", "v1"},
+ {"k2", "v2"},
+ {"k3", "v3"},
+ }
+ nonAligned = []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ standard = []struct{ k, v string }{
+ {string(randBytes(32)), "verb"},
+ {string(randBytes(32)), "wookiedoo"},
+ {string(randBytes(32)), "stallion"},
+ {string(randBytes(32)), "horse"},
+ {string(randBytes(32)), "coin"},
+ {string(randBytes(32)), "puppy"},
+ {string(randBytes(32)), "myothernodedata"},
+ }
+)
+
+func TestTrieTracer(t *testing.T) {
+ testTrieTracer(t, tiny)
+ testTrieTracer(t, nonAligned)
+ testTrieTracer(t, standard)
+}
+
+// Tests if the trie diffs are tracked correctly. Tracer should capture
+// all non-leaf dirty nodes, no matter the node is embedded or not.
+func testTrieTracer(t *testing.T, vals []struct{ k, v string }) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie := NewEmpty(db)
+
+ // Determine all new nodes are tracked
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ insertSet := copySet(trie.tracer.inserts) // copy before commit
+ deleteSet := copySet(trie.tracer.deletes) // copy before commit
+ root, nodes, _ := trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ seen := setKeys(iterNodes(db, root))
+ if !compareSet(insertSet, seen) {
+ t.Fatal("Unexpected insertion set")
+ }
+ if !compareSet(deleteSet, nil) {
+ t.Fatal("Unexpected deletion set")
+ }
+
+ // Determine all deletions are tracked
+ trie, _ = New(TrieID(root), db)
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ insertSet, deleteSet = copySet(trie.tracer.inserts), copySet(trie.tracer.deletes)
+ if !compareSet(insertSet, nil) {
+ t.Fatal("Unexpected insertion set")
+ }
+ if !compareSet(deleteSet, seen) {
+ t.Fatal("Unexpected deletion set")
+ }
+}
+
+// Test that after inserting a new batch of nodes and deleting them immediately,
+// the trie tracer should be cleared normally as no operation happened.
+func TestTrieTracerNoop(t *testing.T) {
+ testTrieTracerNoop(t, tiny)
+ testTrieTracerNoop(t, nonAligned)
+ testTrieTracerNoop(t, standard)
+}
+
+func testTrieTracerNoop(t *testing.T, vals []struct{ k, v string }) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ if len(trie.tracer.inserts) != 0 {
+ t.Fatal("Unexpected insertion set")
+ }
+ if len(trie.tracer.deletes) != 0 {
+ t.Fatal("Unexpected deletion set")
+ }
+}
+
+// Tests if the accessList is correctly tracked.
+func TestAccessList(t *testing.T) {
+ testAccessList(t, tiny)
+ testAccessList(t, nonAligned)
+ testAccessList(t, standard)
+}
+
+func testAccessList(t *testing.T, vals []struct{ k, v string }) {
+ var (
+ db = NewDatabase(rawdb.NewMemoryDatabase())
+ trie = NewEmpty(db)
+ orig = trie.Copy()
+ )
+ // Create trie from scratch
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ root, nodes, _ := trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, nodes); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+
+ // Update trie
+ trie, _ = New(TrieID(root), db)
+ orig = trie.Copy()
+ for _, val := range vals {
+ trie.Update([]byte(val.k), randBytes(32))
+ }
+ root, nodes, _ = trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, nodes); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+
+ // Add more new nodes
+ trie, _ = New(TrieID(root), db)
+ orig = trie.Copy()
+ var keys []string
+ for i := 0; i < 30; i++ {
+ key := randBytes(32)
+ keys = append(keys, string(key))
+ trie.Update(key, randBytes(32))
+ }
+ root, nodes, _ = trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, nodes); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+
+ // Partial deletions
+ trie, _ = New(TrieID(root), db)
+ orig = trie.Copy()
+ for _, key := range keys {
+ trie.Update([]byte(key), nil)
+ }
+ root, nodes, _ = trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, nodes); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+
+ // Delete all
+ trie, _ = New(TrieID(root), db)
+ orig = trie.Copy()
+ for _, val := range vals {
+ trie.Update([]byte(val.k), nil)
+ }
+ root, nodes, _ = trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, nodes); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+}
+
+// Tests origin values won't be tracked in Iterator or Prover
+func TestAccessListLeak(t *testing.T) {
+ var (
+ db = NewDatabase(rawdb.NewMemoryDatabase())
+ trie = NewEmpty(db)
+ )
+ // Create trie from scratch
+ for _, val := range standard {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ root, nodes, _ := trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+
+ var cases = []struct {
+ op func(tr *Trie)
+ }{
+ {
+ func(tr *Trie) {
+ it := tr.NodeIterator(nil)
+ for it.Next(true) {
+ }
+ },
+ },
+ {
+ func(tr *Trie) {
+ it := NewIterator(tr.NodeIterator(nil))
+ for it.Next() {
+ }
+ },
+ },
+ {
+ func(tr *Trie) {
+ for _, val := range standard {
+ tr.Prove([]byte(val.k), 0, rawdb.NewMemoryDatabase())
+ }
+ },
+ },
+ }
+ for _, c := range cases {
+ trie, _ = New(TrieID(root), db)
+ n1 := len(trie.tracer.accessList)
+ c.op(trie)
+ n2 := len(trie.tracer.accessList)
+
+ if n1 != n2 {
+ t.Fatalf("AccessList is leaked, prev %d after %d", n1, n2)
+ }
+ }
+}
+
+// Tests whether the original tree node is correctly deleted after being embedded
+// in its parent due to the smaller size of the original tree node.
+func TestTinyTree(t *testing.T) {
+ var (
+ db = NewDatabase(rawdb.NewMemoryDatabase())
+ trie = NewEmpty(db)
+ )
+ for _, val := range tiny {
+ trie.Update([]byte(val.k), randBytes(32))
+ }
+ root, set, _ := trie.Commit(false)
+ db.Update(NewWithNodeSet(set))
+
+ trie, _ = New(TrieID(root), db)
+ orig := trie.Copy()
+ for _, val := range tiny {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ root, set, _ = trie.Commit(false)
+ db.Update(NewWithNodeSet(set))
+
+ trie, _ = New(TrieID(root), db)
+ if err := verifyAccessList(orig, trie, set); err != nil {
+ t.Fatalf("Invalid accessList %v", err)
+ }
+}
+
+func compareSet(setA, setB map[string]struct{}) bool {
+ if len(setA) != len(setB) {
+ return false
+ }
+ for key := range setA {
+ if _, ok := setB[key]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+func forNodes(tr *Trie) map[string][]byte {
+ var (
+ it = tr.NodeIterator(nil)
+ nodes = make(map[string][]byte)
+ )
+ for it.Next(true) {
+ if it.Leaf() {
+ continue
+ }
+ blob := it.NodeBlob()
+ nodes[string(it.Path())] = common.CopyBytes(blob)
+ }
+ return nodes
+}
+
+func iterNodes(db *Database, root common.Hash) map[string][]byte {
+ tr, _ := New(TrieID(root), db)
+ return forNodes(tr)
+}
+
+func forHashedNodes(tr *Trie) map[string][]byte {
+ var (
+ it = tr.NodeIterator(nil)
+ nodes = make(map[string][]byte)
+ )
+ for it.Next(true) {
+ if it.Hash() == (common.Hash{}) {
+ continue
+ }
+ blob := it.NodeBlob()
+ nodes[string(it.Path())] = common.CopyBytes(blob)
+ }
+ return nodes
+}
+
+// diffTries return the diff and shared nodes between 2 tries
+func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
+ var (
+ nodesA = forHashedNodes(trieA)
+ nodesB = forHashedNodes(trieB)
+ inA = make(map[string][]byte) // hashed nodes in trie a but not b
+ inB = make(map[string][]byte) // hashed nodes in trie b but not a
+ both = make(map[string][]byte) // hashed nodes in both tries but different value
+ )
+ for path, blobA := range nodesA {
+ if blobB, ok := nodesB[path]; ok {
+ if bytes.Equal(blobA, blobB) {
+ continue
+ }
+ both[path] = blobA
+ continue
+ }
+ inA[path] = blobA
+ }
+ for path, blobB := range nodesB {
+ if _, ok := nodesA[path]; ok {
+ continue
+ }
+ inB[path] = blobB
+ }
+ return inA, inB, both
+}
+
+func setKeys(set map[string][]byte) map[string]struct{} {
+ keys := make(map[string]struct{})
+ for k := range set {
+ keys[k] = struct{}{}
+ }
+ return keys
+}
+
+func copySet(set map[string]struct{}) map[string]struct{} {
+ copied := make(map[string]struct{})
+ for k := range set {
+ copied[k] = struct{}{}
+ }
+ return copied
+}
diff --git a/trie/trie.go b/trie/trie.go
index b1c3d7136..596cc31d1 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -77,7 +77,7 @@ func New(id *ID, db NodeReader) (*Trie, error) {
trie := &Trie{
owner: id.Owner,
reader: reader,
- //tracer: newTracer(),
+ tracer: newTracer(),
}
if id.Root != (common.Hash{}) && id.Root != emptyRoot {
rootnode, err := trie.resolveAndTrack(id.Root[:], nil)
@@ -571,7 +571,7 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
// Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash {
- hash, cached, _ := t.hashRoot()
+ hash, cached := t.hashRoot()
t.root = cached
return common.BytesToHash(hash.(hashNode))
}
@@ -584,9 +584,11 @@ func (t *Trie) Hash() common.Hash {
// be created with new root and updated trie database for following usage
func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet, error) {
defer t.tracer.reset()
+ nodes := NewNodeSet(t.owner)
+ t.tracer.markDeletions(nodes)
if t.root == nil {
- return emptyRoot, nil, nil
+ return emptyRoot, nodes, nil
}
// Derive the hash for all dirty nodes first. We hold the assumption
// in the following procedure that all nodes are hashed.
@@ -601,7 +603,7 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet, error) {
t.root = hashedNode
return rootHash, nil, nil
}
- h := newCommitter(t.owner, t.tracer, collectLeaf)
+ h := newCommitter(nodes, t.tracer, collectLeaf)
newRoot, nodes, err := h.Commit(t.root)
if err != nil {
return common.Hash{}, nil, err
@@ -612,16 +614,16 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet, error) {
}
// hashRoot calculates the root hash of the given trie
-func (t *Trie) hashRoot() (node, node, error) {
+func (t *Trie) hashRoot() (node, node) {
if t.root == nil {
- return hashNode(emptyRoot.Bytes()), nil, nil
+ return hashNode(emptyRoot.Bytes()), nil
}
// If the number of changes is below 100, we let one thread handle it
h := newHasher(t.unhashed >= 100)
defer returnHasherToPool(h)
hashed, cached := h.hash(t.root, true)
t.unhashed = 0
- return hashed, cached, nil
+ return hashed, cached
}
// Reset drops the referenced root node and cleans all internal state.
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 02efa6104..499f0574d 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -420,6 +420,49 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(steps)
}
+// verifyAccessList verifies the access list of the new trie against the old trie.
+func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error {
+ deletes, inserts, updates := diffTries(old, new)
+
+ // Check insertion set
+ for path := range inserts {
+ n, ok := set.nodes[path]
+ if !ok || n.isDeleted() {
+ return errors.New("expect new node")
+ }
+ if len(n.prev) > 0 {
+ return errors.New("unexpected origin value")
+ }
+ }
+ // Check deletion set
+ for path, blob := range deletes {
+ n, ok := set.nodes[path]
+ if !ok || !n.isDeleted() {
+ return errors.New("expect deleted node")
+ }
+ if len(n.prev) == 0 {
+ return errors.New("expect origin value")
+ }
+ if !bytes.Equal(n.prev, blob) {
+ return errors.New("invalid origin value")
+ }
+ }
+ // Check update set
+ for path, blob := range updates {
+ n, ok := set.nodes[path]
+ if !ok || n.isDeleted() {
+ return errors.New("expect updated node")
+ }
+ if len(n.prev) == 0 {
+ return errors.New("expect origin value")
+ }
+ if !bytes.Equal(n.prev, blob) {
+ return errors.New("invalid origin value")
+ }
+ }
+ return nil
+}
+
func runRandTest(rt randTest) bool {
var (
triedb = NewDatabase(rawdb.NewMemoryDatabase())
@@ -468,24 +511,6 @@ func runRandTest(rt randTest) bool {
rt[i].err = err
return false
}
- // Validity the returned nodeset
- if nodes != nil {
- for path, node := range nodes.updates.nodes {
- blob, _, _ := origTrie.TryGetNode(hexToCompact([]byte(path)))
- got := node.prev
- if !bytes.Equal(blob, got) {
- rt[i].err = fmt.Errorf("prevalue mismatch for 0x%x, got 0x%x want 0x%x", path, got, blob)
- panic(rt[i].err)
- }
- }
- for path, prev := range nodes.deletes {
- blob, _, _ := origTrie.TryGetNode(hexToCompact([]byte(path)))
- if !bytes.Equal(blob, prev) {
- rt[i].err = fmt.Errorf("prevalue mismatch for 0x%x, got 0x%x want 0x%x", path, prev, blob)
- return false
- }
- }
- }
if nodes != nil {
triedb.Update(NewWithNodeSet(nodes))
}
@@ -494,8 +519,13 @@ func runRandTest(rt randTest) bool {
rt[i].err = err
return false
}
+ if nodes != nil {
+ if err := verifyAccessList(origTrie, newtr, nodes); err != nil {
+ rt[i].err = err
+ return false
+ }
+ }
tr = newtr
-
// Enable node tracing. Resolve the root node again explicitly
// since it's not captured at the beginning.
tr.tracer = newTracer()
diff --git a/trie/utils_test.go b/trie/utils_test.go
deleted file mode 100644
index 011d93967..000000000
--- a/trie/utils_test.go
+++ /dev/null
@@ -1,242 +0,0 @@
-// Copyright 2022 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see .
-
-package trie
-
-import (
- "bytes"
- "testing"
-
- "github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/core/rawdb"
-)
-
-// Tests if the trie diffs are tracked correctly.
-func TestTrieTracer(t *testing.T) {
- db := NewDatabase(rawdb.NewMemoryDatabase())
- trie := NewEmpty(db)
- trie.tracer = newTracer()
-
- // Insert a batch of entries, all the nodes should be marked as inserted
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"dog", "puppy"},
- {"somethingveryoddindeedthis is", "myothernodedata"},
- }
- for _, val := range vals {
- trie.Update([]byte(val.k), []byte(val.v))
- }
- trie.Hash()
-
- seen := make(map[string]struct{})
- it := trie.NodeIterator(nil)
- for it.Next(true) {
- if it.Leaf() {
- continue
- }
- seen[string(it.Path())] = struct{}{}
- }
- inserted := trie.tracer.insertList()
- if len(inserted) != len(seen) {
- t.Fatalf("Unexpected inserted node tracked want %d got %d", len(seen), len(inserted))
- }
- for _, k := range inserted {
- _, ok := seen[string(k)]
- if !ok {
- t.Fatalf("Unexpected inserted node")
- }
- }
- deleted := trie.tracer.deleteList()
- if len(deleted) != 0 {
- t.Fatalf("Unexpected deleted node tracked %d", len(deleted))
- }
-
- // Commit the changes
- root, nodes, _ := trie.Commit(false)
- db.Update(NewWithNodeSet(nodes))
- trie, _ = New(TrieID(root), db)
- trie.tracer = newTracer()
-
- // Delete all the elements, check deletion set
- for _, val := range vals {
- trie.Delete([]byte(val.k))
- }
- trie.Hash()
-
- inserted = trie.tracer.insertList()
- if len(inserted) != 0 {
- t.Fatalf("Unexpected inserted node tracked %d", len(inserted))
- }
- deleted = trie.tracer.deleteList()
- if len(deleted) != len(seen) {
- t.Fatalf("Unexpected deleted node tracked want %d got %d", len(seen), len(deleted))
- }
- for _, k := range deleted {
- _, ok := seen[string(k)]
- if !ok {
- t.Fatalf("Unexpected inserted node")
- }
- }
-}
-
-func TestTrieTracerNoop(t *testing.T) {
- trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
- trie.tracer = newTracer()
-
- // Insert a batch of entries, all the nodes should be marked as inserted
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"dog", "puppy"},
- {"somethingveryoddindeedthis is", "myothernodedata"},
- }
- for _, val := range vals {
- trie.Update([]byte(val.k), []byte(val.v))
- }
- for _, val := range vals {
- trie.Delete([]byte(val.k))
- }
- if len(trie.tracer.insertList()) != 0 {
- t.Fatalf("Unexpected inserted node tracked %d", len(trie.tracer.insertList()))
- }
- if len(trie.tracer.deleteList()) != 0 {
- t.Fatalf("Unexpected deleted node tracked %d", len(trie.tracer.deleteList()))
- }
-}
-func TestTrieTracePrevValue(t *testing.T) {
- db := NewDatabase(rawdb.NewMemoryDatabase())
- trie := NewEmpty(db)
- trie.tracer = newTracer()
-
- paths, blobs := trie.tracer.prevList()
- if len(paths) != 0 || len(blobs) != 0 {
- t.Fatalf("Nothing should be tracked")
- }
- // Insert a batch of entries, all the nodes should be marked as inserted
- vals := []struct{ k, v string }{
- {"do", "verb"},
- {"ether", "wookiedoo"},
- {"horse", "stallion"},
- {"shaman", "horse"},
- {"doge", "coin"},
- {"dog", "puppy"},
- {"somethingveryoddindeedthis is", "myothernodedata"},
- }
- for _, val := range vals {
- trie.Update([]byte(val.k), []byte(val.v))
- }
- paths, blobs = trie.tracer.prevList()
- if len(paths) != 0 || len(blobs) != 0 {
- t.Fatalf("Nothing should be tracked")
- }
-
- // Commit the changes and re-create with new root
- root, nodes, _ := trie.Commit(false)
- if err := db.Update(NewWithNodeSet(nodes)); err != nil {
- t.Fatal(err)
- }
- trie, _ = New(TrieID(root), db)
- trie.tracer = newTracer()
- trie.resolveAndTrack(root.Bytes(), nil)
-
- // Load all nodes in trie
- for _, val := range vals {
- trie.TryGet([]byte(val.k))
- }
-
- // Ensure all nodes are tracked by tracer with correct prev-values
- iter := trie.NodeIterator(nil)
- seen := make(map[string][]byte)
- for iter.Next(true) {
- // Embedded nodes are ignored since they are not present in
- // database.
- if iter.Hash() == (common.Hash{}) {
- continue
- }
- blob := iter.NodeBlob()
- seen[string(iter.Path())] = common.CopyBytes(blob)
- }
-
- paths, blobs = trie.tracer.prevList()
- if len(paths) != len(seen) || len(blobs) != len(seen) {
- t.Fatalf("Unexpected tracked values")
- }
- for i, path := range paths {
- blob := blobs[i]
- prev, ok := seen[string(path)]
- if !ok {
- t.Fatalf("Missing node %v", path)
- }
- if !bytes.Equal(blob, prev) {
- t.Fatalf("Unexpected value path: %v, want: %v, got: %v", path, prev, blob)
- }
- }
-
- // Re-open the trie and iterate the trie, ensure nothing will be tracked.
- // Iterator will not link any loaded nodes to trie.
- trie, _ = New(TrieID(root), db)
- trie.tracer = newTracer()
-
- iter = trie.NodeIterator(nil)
- for iter.Next(true) {
- }
- paths, blobs = trie.tracer.prevList()
- if len(paths) != 0 || len(blobs) != 0 {
- t.Fatalf("Nothing should be tracked")
- }
-
- // Re-open the trie and generate proof for entries, ensure nothing will
- // be tracked. Prover will not link any loaded nodes to trie.
- trie, _ = New(TrieID(root), db)
- trie.tracer = newTracer()
- for _, val := range vals {
- trie.Prove([]byte(val.k), 0, rawdb.NewMemoryDatabase())
- }
- paths, blobs = trie.tracer.prevList()
- if len(paths) != 0 || len(blobs) != 0 {
- t.Fatalf("Nothing should be tracked")
- }
-
- // Delete entries from trie, ensure all previous values are correct.
- trie, _ = New(TrieID(root), db)
- trie.tracer = newTracer()
- trie.resolveAndTrack(root.Bytes(), nil)
-
- for _, val := range vals {
- trie.TryDelete([]byte(val.k))
- }
- paths, blobs = trie.tracer.prevList()
- if len(paths) != len(seen) || len(blobs) != len(seen) {
- t.Fatalf("Unexpected tracked values")
- }
- for i, path := range paths {
- blob := blobs[i]
- prev, ok := seen[string(path)]
- if !ok {
- t.Fatalf("Missing node %v", path)
- }
- if !bytes.Equal(blob, prev) {
- t.Fatalf("Unexpected value path: %v, want: %v, got: %v", path, prev, blob)
- }
- }
-}