Skip to content

Commit

Permalink
chore(trie): merge leaf and branch in single node struct (#2504)
Browse files Browse the repository at this point in the history
- Removed type assertions and `Node` interface
- Leaf and branch structs merged in one common `Node` struct
- Removed trivial methods that were used to avoid type switches such as `GetHash()`
- Change `Children` field from an array to a slice so it can be nil for leaves
- Deduce node type from the children slice (a nil children slice means the node is a leaf)

Co-authored-by: Eclésio Junior <[email protected]>
  • Loading branch information
qdm12 and EclesioMeloJunior authored Jun 6, 2022
1 parent f2cdfea commit 6d22c23
Show file tree
Hide file tree
Showing 46 changed files with 2,188 additions and 3,291 deletions.
7 changes: 6 additions & 1 deletion dot/state/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) {
"0",
))

testChildTrie := trie.NewTrie(node.NewLeaf([]byte{1, 2}, []byte{3, 4}, true, 0))
trieRoot := &node.Node{
Key: []byte{1, 2},
Value: []byte{3, 4},
Dirty: true,
}
testChildTrie := trie.NewTrie(trieRoot)

testChildTrie.Put([]byte("keyInsidechild"), []byte("voila"))

Expand Down
4 changes: 2 additions & 2 deletions dot/state/tries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ func Test_Tries_get(t *testing.T) {
"found in map": {
tries: &Tries{
rootToTrie: map[common.Hash]*trie.Trie{
{1, 2, 3}: trie.NewTrie(&node.Leaf{
{1, 2, 3}: trie.NewTrie(&node.Node{
Key: []byte{1, 2, 3},
}),
},
},
root: common.Hash{1, 2, 3},
trie: trie.NewTrie(&node.Leaf{
trie: trie.NewTrie(&node.Node{
Key: []byte{1, 2, 3},
}),
},
Expand Down
9 changes: 3 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,18 @@ require (
github.com/nanobox-io/golang-scribble v0.0.0-20190309225732-aa3e7c118975
github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416
github.com/perlin-network/life v0.0.0-20191203030451-05c0e0f7eaea
github.com/prometheus/client_golang v1.12.2
github.com/prometheus/client_model v0.2.0
github.com/qdm12/gotree v0.2.0
github.com/stretchr/testify v1.7.1
github.com/urfave/cli v1.22.9
github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
golang.org/x/text v0.3.7
google.golang.org/protobuf v1.28.0
)

require (
github.com/prometheus/client_golang v1.12.2
github.com/prometheus/client_model v0.2.0
)

require (
github.com/ChainSafe/log15 v1.0.0 // indirect
github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 // indirect
Expand Down Expand Up @@ -181,7 +179,6 @@ require (
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/appengine v1.6.6 // indirect
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
Expand Down
84 changes: 0 additions & 84 deletions internal/trie/node/branch.go

This file was deleted.

194 changes: 66 additions & 128 deletions internal/trie/node/branch_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,125 +10,17 @@ import (
"io"
"runtime"

"github.com/ChainSafe/gossamer/internal/trie/codec"
"github.com/ChainSafe/gossamer/internal/trie/pools"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"
)

// ScaleEncodeHash hashes the node (blake2b sum on encoded value)
// and then SCALE encodes it. This is used to encode children
// nodes of branches.
func (b *Branch) ScaleEncodeHash() (encoding []byte, err error) {
	// Borrow a digest buffer from the pool and return it when done.
	digestBuffer := pools.DigestBuffers.Get().(*bytes.Buffer)
	digestBuffer.Reset()
	defer pools.DigestBuffers.Put(digestBuffer)

	if err := b.hash(digestBuffer); err != nil {
		return nil, fmt.Errorf("cannot hash branch: %w", err)
	}

	encoding, err = scale.Marshal(digestBuffer.Bytes())
	if err != nil {
		return nil, fmt.Errorf("cannot scale encode hashed branch: %w", err)
	}
	return encoding, nil
}

// hash writes the hash digest of the encoded branch to the digest
// buffer given. If the encoding is less than 32 bytes, the encoding
// itself is written instead of its hash, per the trie specification.
func (b *Branch) hash(digestBuffer io.Writer) (err error) {
	encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
	encodingBuffer.Reset()
	defer pools.EncodingBuffers.Put(encodingBuffer)

	err = b.Encode(encodingBuffer)
	if err != nil {
		// This is a *Branch* method; the previous message wrongly said "leaf".
		return fmt.Errorf("cannot encode branch: %w", err)
	}

	// if length of encoded branch is less than 32 bytes, do not hash
	if encodingBuffer.Len() < 32 {
		_, err = digestBuffer.Write(encodingBuffer.Bytes())
		if err != nil {
			return fmt.Errorf("cannot write encoded branch to buffer: %w", err)
		}
		return nil
	}

	// otherwise, hash encoded node
	hasher := pools.Hashers.Get().(hash.Hash)
	hasher.Reset()
	defer pools.Hashers.Put(hasher)

	// Note: using the sync.Pool's buffer is useful here.
	_, err = hasher.Write(encodingBuffer.Bytes())
	if err != nil {
		return fmt.Errorf("cannot hash encoded node: %w", err)
	}

	_, err = digestBuffer.Write(hasher.Sum(nil))
	if err != nil {
		return fmt.Errorf("cannot write hash sum of branch to buffer: %w", err)
	}
	return nil
}

// Encode encodes a branch with the encoding specified at the top of this package
// to the buffer given.
func (b *Branch) Encode(buffer Buffer) (err error) {
	// Clean nodes with a cached encoding can be written out directly.
	if !b.Dirty && b.Encoding != nil {
		if _, err := buffer.Write(b.Encoding); err != nil {
			return fmt.Errorf("cannot write stored encoding to buffer: %w", err)
		}
		return nil
	}

	if err := b.encodeHeader(buffer); err != nil {
		return fmt.Errorf("cannot encode header: %w", err)
	}

	// Key nibbles are converted to their little-endian byte form.
	if _, err := buffer.Write(codec.NibblesToKeyLE(b.Key)); err != nil {
		return fmt.Errorf("cannot write encoded key to buffer: %w", err)
	}

	// One bit per child slot, indicating which children are set.
	if _, err := buffer.Write(common.Uint16ToBytes(b.ChildrenBitmap())); err != nil {
		return fmt.Errorf("cannot write children bitmap to buffer: %w", err)
	}

	if b.Value != nil {
		encodedValue, err := scale.Marshal(b.Value)
		if err != nil {
			return fmt.Errorf("cannot scale encode value: %w", err)
		}

		if _, err := buffer.Write(encodedValue); err != nil {
			return fmt.Errorf("cannot write encoded value to buffer: %w", err)
		}
	}

	if err := encodeChildrenOpportunisticParallel(b.Children, buffer); err != nil {
		return fmt.Errorf("cannot encode children of branch: %w", err)
	}

	return nil
}

// encodingAsyncResult carries the outcome of encoding one child node
// in a worker goroutine; index is the child's position in the children
// slice so results can be written back to the parent buffer in order.
type encodingAsyncResult struct {
index int
buffer *bytes.Buffer
err error
}

func runEncodeChild(child Node, index int,
func runEncodeChild(child *Node, index int,
results chan<- encodingAsyncResult, rateLimit <-chan struct{}) {
buffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
buffer.Reset()
Expand Down Expand Up @@ -158,13 +50,13 @@ var parallelEncodingRateLimit = make(chan struct{}, parallelLimit)
// goroutines IF they are less than the parallelLimit number of goroutines already
// running. This is designed to limit the total number of goroutines in order to
// avoid using too much memory on the stack.
func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (err error) {
func encodeChildrenOpportunisticParallel(children []*Node, buffer io.Writer) (err error) {
// Buffered channels since children might be encoded in this
// goroutine or another one.
resultsCh := make(chan encodingAsyncResult, ChildrenCapacity)

for i, child := range children {
if isNodeNil(child) || child.Type() == LeafType {
if child == nil || child.Type() == Leaf {
runEncodeChild(child, i, resultsCh, nil)
continue
}
Expand Down Expand Up @@ -223,7 +115,7 @@ func encodeChildrenOpportunisticParallel(children [16]Node, buffer io.Writer) (e
return err
}

func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) {
func encodeChildrenSequentially(children []*Node, buffer io.Writer) (err error) {
for i, child := range children {
err = encodeChild(child, buffer)
if err != nil {
Expand All @@ -233,32 +125,78 @@ func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error)
return nil
}

// isNodeNil reports whether the node interface given is nil,
// including the case where it holds a typed nil pointer
// (a non-nil interface wrapping a nil *Branch or *Leaf).
func isNodeNil(n Node) (isNil bool) {
	switch typed := n.(type) {
	case *Branch:
		return typed == nil
	case *Leaf:
		return typed == nil
	default:
		return n == nil
	}
}

func encodeChild(child Node, buffer io.Writer) (err error) {
if isNodeNil(child) {
func encodeChild(child *Node, buffer io.Writer) (err error) {
if child == nil {
return nil
}

scaleEncodedChild, err := child.ScaleEncodeHash()
scaleEncodedChildHash, err := scaleEncodeHash(child)
if err != nil {
return fmt.Errorf("failed to hash and scale encode child: %w", err)
}

_, err = buffer.Write(scaleEncodedChild)
_, err = buffer.Write(scaleEncodedChildHash)
if err != nil {
return fmt.Errorf("failed to write child to buffer: %w", err)
}

return nil
}

// scaleEncodeHash hashes the node (blake2b sum on encoded value)
// and then SCALE encodes it. This is used to encode children
// nodes of branches.
func scaleEncodeHash(node *Node) (encoding []byte, err error) {
	// Reuse a pooled digest buffer to avoid per-call allocations.
	digestBuffer := pools.DigestBuffers.Get().(*bytes.Buffer)
	digestBuffer.Reset()
	defer pools.DigestBuffers.Put(digestBuffer)

	if err := hashNode(node, digestBuffer); err != nil {
		return nil, fmt.Errorf("cannot hash %s: %w", node.Type(), err)
	}

	encoding, err = scale.Marshal(digestBuffer.Bytes())
	if err != nil {
		return nil, fmt.Errorf("cannot scale encode hashed %s: %w", node.Type(), err)
	}

	return encoding, nil
}

// hashNode writes the hash digest of the node's encoding to the
// digest writer given. Encodings shorter than 32 bytes are written
// as-is instead of being hashed, per the trie specification.
func hashNode(node *Node, digestWriter io.Writer) (err error) {
	encBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
	encBuffer.Reset()
	defer pools.EncodingBuffers.Put(encBuffer)

	if err := node.Encode(encBuffer); err != nil {
		return fmt.Errorf("cannot encode %s: %w", node.Type(), err)
	}

	encodedBytes := encBuffer.Bytes()

	// Small encodings (< 32 bytes) are inlined without hashing.
	if len(encodedBytes) < 32 {
		if _, err := digestWriter.Write(encodedBytes); err != nil {
			return fmt.Errorf("cannot write encoded %s to buffer: %w", node.Type(), err)
		}
		return nil
	}

	// Otherwise hash the encoding with a pooled hasher.
	hasher := pools.Hashers.Get().(hash.Hash)
	hasher.Reset()
	defer pools.Hashers.Put(hasher)

	// Note: using the sync.Pool's buffer is useful here.
	if _, err := hasher.Write(encodedBytes); err != nil {
		return fmt.Errorf("cannot hash encoding of %s: %w", node.Type(), err)
	}

	if _, err := digestWriter.Write(hasher.Sum(nil)); err != nil {
		return fmt.Errorf("cannot write hash sum of %s to buffer: %w", node.Type(), err)
	}
	return nil
}
Loading

0 comments on commit 6d22c23

Please sign in to comment.