diff --git a/lib/trie/version.go b/lib/trie/version.go new file mode 100644 index 0000000000..23527890c8 --- /dev/null +++ b/lib/trie/version.go @@ -0,0 +1,44 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "errors" + "fmt" + "strings" +) + +// Version is the state trie version which dictates how a +// Merkle root should be constructed. It is defined in +// https://spec.polkadot.network/#defn-state-version +type Version uint8 + +const ( + // V0 is the state trie version 0 where the values of the keys are + // inserted into the trie directly. + // TODO set to iota once CI passes + V0 Version = 1 +) + +func (v Version) String() string { + switch v { + case V0: + return "v0" + default: + panic(fmt.Sprintf("unknown version %d", v)) + } +} + +var ErrParseVersion = errors.New("parsing version failed") + +// ParseVersion parses a state trie version string. +func ParseVersion(s string) (version Version, err error) { + switch { + case strings.EqualFold(s, V0.String()): + return V0, nil + default: + return version, fmt.Errorf("%w: %q must be %s", + ErrParseVersion, s, V0) + } +} diff --git a/lib/trie/version_test.go b/lib/trie/version_test.go new file mode 100644 index 0000000000..ab2ac03ebe --- /dev/null +++ b/lib/trie/version_test.go @@ -0,0 +1,86 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Version_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + version Version + versionString string + panicMessage string + }{ + "v0": { + version: V0, + versionString: "v0", + }, + "invalid": { + version: Version(99), + panicMessage: "unknown version 99", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + if testCase.panicMessage != "" { + assert.PanicsWithValue(t, testCase.panicMessage, func() { + _ = testCase.version.String() + }) + return + } + + versionString := testCase.version.String() + assert.Equal(t, testCase.versionString, versionString) + }) + } +} + +func Test_ParseVersion(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + s string + version Version + errWrapped error + errMessage string + }{ + "v0": { + s: "v0", + version: V0, + }, + "V0": { + s: "V0", + version: V0, + }, + "invalid": { + s: "xyz", + errWrapped: ErrParseVersion, + errMessage: "parsing version failed: \"xyz\" must be v0", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + version, err := ParseVersion(testCase.s) + + assert.Equal(t, testCase.version, version) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +}