From d616e4945c99c319f6920078afc5e9ec43cff445 Mon Sep 17 00:00:00 2001 From: kdeme Date: Thu, 21 Oct 2021 16:21:14 +0200 Subject: [PATCH 1/5] Copy ssz code from nimbus-eth2 and adjust to stand on its own --- ssz_serialization.nim | 250 ++++++++ ssz_serialization.nimble | 13 +- ssz_serialization/bitseqs.nim | 333 +++++++++++ ssz_serialization/codec.nim | 254 ++++++++ ssz_serialization/dynamic_navigator.nim | 163 ++++++ ssz_serialization/merkleization.nim | 734 ++++++++++++++++++++++++ ssz_serialization/navigator.nim | 143 +++++ ssz_serialization/types.nim | 570 ++++++++++++++++++ tests/test_all.nim | 3 + tests/test_ssz_roundtrip.nim | 14 + tests/test_ssz_serialization.nim | 359 ++++++++++++ 11 files changed, 2832 insertions(+), 4 deletions(-) create mode 100644 ssz_serialization.nim create mode 100644 ssz_serialization/bitseqs.nim create mode 100644 ssz_serialization/codec.nim create mode 100644 ssz_serialization/dynamic_navigator.nim create mode 100644 ssz_serialization/merkleization.nim create mode 100644 ssz_serialization/navigator.nim create mode 100644 ssz_serialization/types.nim create mode 100644 tests/test_ssz_roundtrip.nim create mode 100644 tests/test_ssz_serialization.nim diff --git a/ssz_serialization.nim b/ssz_serialization.nim new file mode 100644 index 0000000..4f2e8c8 --- /dev/null +++ b/ssz_serialization.nim @@ -0,0 +1,250 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
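A quick round-trip sketch of the API this module provides. `Deposit` is a made-up example type; `SSZ.encode` and `SSZ.decode` are the standard helpers that the `serializationFormat SSZ` declaration below brings in from nim-serialization:

```nim
import ssz_serialization

type Deposit = object
  amount: uint64
  slashed: bool

# encode to the canonical SSZ byte string...
let bytes = SSZ.encode(Deposit(amount: 32'u64, slashed: false))
# ...and decode it back into a value
let restored = SSZ.decode(bytes, Deposit)
assert restored.amount == 32'u64
```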
+ +{.push raises: [Defect].} +{.pragma: raisesssz, raises: [Defect, MalformedSszError, SszSizeMismatchError].} + +## SSZ serialization for core SSZ types, as specified in: +# https://github.com/ethereum/consensus-specs/blob/v1.0.1/ssz/simple-serialize.md#serialization + +import + std/typetraits, + stew/[endians2, leb128, objects], + serialization, serialization/testing/tracing, + ./ssz_serialization/[codec, bitseqs, types] + +export + serialization, codec, types, bitseqs + +type + SszReader* = object + stream: InputStream + + SszWriter* = object + stream: OutputStream + + SizePrefixed*[T] = distinct T + SszMaxSizeExceeded* = object of SerializationError + + VarSizedWriterCtx = object + fixedParts: WriteCursor + offset: int + + FixedSizedWriterCtx = object + +serializationFormat SSZ + +SSZ.setReader SszReader +SSZ.setWriter SszWriter, PreferredOutput = seq[byte] + +template sizePrefixed*[TT](x: TT): untyped = + type T = TT + SizePrefixed[T](x) + +proc init*(T: type SszReader, + stream: InputStream): T = + T(stream: stream) + +proc writeFixedSized(s: var (OutputStream|WriteCursor), x: auto) {.raises: [Defect, IOError].} = + mixin toSszType + + when x is byte: + s.write x + elif x is bool: + s.write byte(ord(x)) + elif x is UintN: + when cpuEndian == bigEndian: + s.write toBytesLE(x) + else: + s.writeMemCopy x + elif x is array: + when x[0] is byte: + trs "APPENDING FIXED SIZE BYTES", x + s.write x + else: + for elem in x: + trs "WRITING FIXED SIZE ARRAY ELEMENT" + s.writeFixedSized toSszType(elem) + elif x is tuple|object: + enumInstanceSerializedFields(x, fieldName, field): + trs "WRITING FIXED SIZE FIELD", fieldName + s.writeFixedSized toSszType(field) + else: + unsupported x.type + +template writeOffset(cursor: var WriteCursor, offset: int) = + write cursor, toBytesLE(uint32 offset) + +template supports*(_: type SSZ, T: type): bool = + mixin toSszType + anonConst compiles(fixedPortionSize toSszType(declval T)) + +func init*(T: type SszWriter, stream: OutputStream): T = + result.stream = stream + +proc writeVarSizeType(w: var SszWriter, value: auto) {.gcsafe, raises: [Defect, IOError].} + +proc beginRecord*(w: var SszWriter, TT: type): auto = + type T = TT + when isFixedSize(T): + FixedSizedWriterCtx() + else: + const offset = when T is array|HashArray: len(T) * offsetSize + else: fixedPortionSize(T) + VarSizedWriterCtx(offset: offset, + fixedParts: w.stream.delayFixedSizeWrite(offset)) + +template writeField*(w: var SszWriter, + ctx: var auto, + fieldName: string, + field: auto) = + mixin toSszType + when ctx is FixedSizedWriterCtx: + writeFixedSized(w.stream, toSszType(field)) + else: + type FieldType = type toSszType(field) + + when isFixedSize(FieldType): + writeFixedSized(ctx.fixedParts, toSszType(field)) + else: + trs "WRITING OFFSET ", ctx.offset, " FOR ", fieldName + writeOffset(ctx.fixedParts, ctx.offset) + let initPos = w.stream.pos + trs "WRITING VAR SIZE VALUE OF TYPE ", name(FieldType) + when FieldType is BitList: + trs "BIT SEQ ", bytes(field) + writeVarSizeType(w, toSszType(field)) + ctx.offset += w.stream.pos - initPos + +template endRecord*(w: var SszWriter, ctx: var auto) = + when ctx is VarSizedWriterCtx: + finalize ctx.fixedParts + +proc writeSeq[T](w: var SszWriter, value: seq[T]) + {.raises: [Defect, IOError].} = + # Please note that `writeSeq` exists in order to reduce the code bloat + # produced from generic instantiations of the unique `List[N, T]` types. 
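+  # (In practice this means that every `List` instantiation, regardless of
+  # its `maxLen` parameter, shares the single generic expansion of this
+  # `seq[T]` body.)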
+ when isFixedSize(T): + trs "WRITING LIST WITH FIXED SIZE ELEMENTS" + for elem in value: + w.stream.writeFixedSized toSszType(elem) + trs "DONE" + else: + trs "WRITING LIST WITH VAR SIZE ELEMENTS" + var offset = value.len * offsetSize + var cursor = w.stream.delayFixedSizeWrite offset + for elem in value: + cursor.writeFixedSized uint32(offset) + let initPos = w.stream.pos + w.writeVarSizeType toSszType(elem) + offset += w.stream.pos - initPos + finalize cursor + trs "DONE" + +proc writeVarSizeType(w: var SszWriter, value: auto) {.raises: [Defect, IOError].} = + trs "STARTING VAR SIZE TYPE" + + when value is HashArray|HashList: + writeVarSizeType(w, value.data) + elif value is SingleMemberUnion: + doAssert value.selector == 0'u8 + w.writeValue 0'u8 + w.writeValue value.value + elif value is List: + # We reduce code bloat by forwarding all `List` types to a general `seq[T]` proc. + writeSeq(w, asSeq value) + elif value is BitList: + # ATTENTION! We can reuse `writeSeq` only as long as our BitList type is implemented + # to internally match the binary representation of SSZ BitLists in memory. + writeSeq(w, bytes value) + elif value is object|tuple|array: + trs "WRITING OBJECT OR ARRAY" + var ctx = beginRecord(w, type value) + enumerateSubFields(value, field): + writeField w, ctx, astToStr(field), field + endRecord w, ctx + else: + unsupported type(value) + +proc writeValue*(w: var SszWriter, x: auto) {.gcsafe, raises: [Defect, IOError].} = + mixin toSszType + type T = type toSszType(x) + + when isFixedSize(T): + w.stream.writeFixedSized toSszType(x) + else: + w.writeVarSizeType toSszType(x) + +func sszSize*(value: auto): int {.gcsafe, raises: [Defect].} + +func sszSizeForVarSizeList[T](value: openArray[T]): int = + result = len(value) * offsetSize + for elem in value: + result += sszSize(toSszType elem) + +func sszSize*(value: auto): int {.gcsafe, raises: [Defect].} = + mixin toSszType + type T = type toSszType(value) + + when isFixedSize(T): + anonConst fixedPortionSize(T) + + elif T is array|List|HashList|HashArray: + type E = ElemType(T) + when isFixedSize(E): + len(value) * anonConst(fixedPortionSize(E)) + elif T is HashArray: + sszSizeForVarSizeList(value.data) + elif T is array: + sszSizeForVarSizeList(value) + else: + sszSizeForVarSizeList(asSeq value) + + elif T is BitList: + return len(bytes(value)) + + elif T is SingleMemberUnion: + sszSize(toSszType value.value) + 1 + + elif T is object|tuple: + result = anonConst fixedPortionSize(T) + enumInstanceSerializedFields(value, _{.used.}, field): + type FieldType = type toSszType(field) + when not isFixedSize(FieldType): + result += sszSize(toSszType field) + + else: + unsupported T + +proc writeValue*[T](w: var SszWriter, x: SizePrefixed[T]) {.raises: [Defect, IOError].} = + var cursor = w.stream.delayVarSizeWrite(Leb128.maxLen(uint64)) + let initPos = w.stream.pos + w.writeValue T(x) + let length = toBytes(uint64(w.stream.pos - initPos), Leb128) + cursor.finalWrite length.toOpenArray() + +proc readValue*(r: var SszReader, val: var auto) {. + raises: [Defect, MalformedSszError, SszSizeMismatchError, IOError].} = + mixin readSszBytes + type T = type val + when isFixedSize(T): + const minimalSize = fixedPortionSize(T) + if r.stream.readable(minimalSize): + readSszBytes(r.stream.read(minimalSize), val) + else: + raise newException(MalformedSszError, "SSZ input of insufficient size") + else: + # TODO(zah) Read the fixed portion first and precisely measure the + # size of the dynamic portion to consume the right number of bytes. 
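+      # Offset layout refresher (a hand-worked sketch): a list of two
+      # variable-size elements of 2 and 1 bytes serializes as one uint32
+      # offset per element, followed by the payloads:
+      #   08 00 00 00   offset of element 0 (2 offsets * 4 bytes = 8)
+      #   0a 00 00 00   offset of element 1 (8 + 2 = 10)
+      #   aa bb  cc     the element payloads (11 bytes in total)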
+ readSszBytes(r.stream.read(r.stream.len.get), val) + +proc readSszBytes*[T](data: openArray[byte], val: var T) {. + raises: [Defect, MalformedSszError, SszSizeMismatchError].} = + # Overload `readSszBytes` to perform custom operations on T after + # deserialization + mixin readSszValue + readSszValue(data, val) diff --git a/ssz_serialization.nimble b/ssz_serialization.nimble index 15b5be5..1441092 100644 --- a/ssz_serialization.nimble +++ b/ssz_serialization.nimble @@ -9,7 +9,12 @@ skipDirs = @["tests"] requires "nim >= 1.2.0", "serialization", - "stew" + "json_serialization", + "stew", + "stint", + "nimcrypto", + "blscurve", + "unittest2" proc test(env, path: string) = # Compilation language is controlled by TEST_LANG @@ -20,8 +25,8 @@ proc test(env, path: string) = if not dirExists "build": mkDir "build" exec "nim " & lang & " " & env & - " -r --hints:off --warnings:off " & path + " -r --hints:off --warnings:on " & path task test, "Run all tests": - test "--threads:off", "tests/test_all" - test "--threads:on", "tests/test_all" + test "--threads:off -d:PREFER_BLST_SHA256=false", "tests/test_all" + test "--threads:on -d:PREFER_BLST_SHA256=false", "tests/test_all" diff --git a/ssz_serialization/bitseqs.nim b/ssz_serialization/bitseqs.nim new file mode 100644 index 0000000..02173a5 --- /dev/null +++ b/ssz_serialization/bitseqs.nim @@ -0,0 +1,333 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + stew/[bitops2, endians2, byteutils, ptrops], + json_serialization + +export json_serialization + +type + Bytes = seq[byte] + + BitSeq* = distinct Bytes + ## The current design of BitSeq tries to follow precisely + ## the bitwise representation of the SSZ bitlists. + ## This is a relatively compact representation, but as + ## evident from the code below, many of the operations + ## are not trivial. 
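+    ## A worked example of the representation: the three bits [1, 0, 1]
+    ## occupy a single byte 0b0000_1101 - bits 0..2 carry the data (LSB
+    ## first) and the highest set bit, at position 3, is the marker that
+    ## encodes the length, so `bitsLen(@[0b0000_1101'u8]) == 3`.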
+ + BitArray*[bits: static int] = object + bytes*: array[(bits + 7) div 8, byte] + +func bitsLen*(bytes: openArray[byte]): int = + let + bytesCount = bytes.len + lastByte = bytes[bytesCount - 1] + markerPos = log2trunc(lastByte) + + bytesCount * 8 - (8 - markerPos) + +template len*(s: BitSeq): int = + bitsLen(Bytes s) + +template len*(a: BitArray): int = + a.bits + +func add*(s: var BitSeq, value: bool) = + let + lastBytePos = s.Bytes.len - 1 + lastByte = s.Bytes[lastBytePos] + + if (lastByte and byte(128)) == 0: + # There is at least one leading zero, so we have enough + # room to store the new bit + let markerPos = log2trunc(lastByte) + s.Bytes[lastBytePos].changeBit markerPos, value + s.Bytes[lastBytePos].setBit markerPos + 1 + else: + s.Bytes[lastBytePos].changeBit 7, value + s.Bytes.add byte(1) + +func toBytesLE(x: uint): array[sizeof(x), byte] = + # stew/endians2 supports explicitly sized uints only + when sizeof(uint) == 4: + static: doAssert sizeof(uint) == sizeof(uint32) + toBytesLE(x.uint32) + elif sizeof(uint) == 8: + static: doAssert sizeof(uint) == sizeof(uint64) + toBytesLE(x.uint64) + else: + static: doAssert false, "requires a 32-bit or 64-bit platform" + +func loadLEBytes(WordType: type, bytes: openArray[byte]): WordType = + # TODO: this is a temporary proc until the endians API is improved + var shift = 0 + for b in bytes: + result = result or (WordType(b) shl shift) + shift += 8 + +func storeLEBytes(value: SomeUnsignedInt, dst: var openArray[byte]) = + doAssert dst.len <= sizeof(value) + let bytesLE = toBytesLE(value) + copyMem(addr dst[0], unsafeAddr bytesLE[0], dst.len) + +template loopOverWords(lhs, rhs: BitSeq, + lhsIsVar, rhsIsVar: static bool, + WordType: type, + lhsBits, rhsBits, body: untyped) = + const hasRhs = astToStr(lhs) != astToStr(rhs) + + let bytesCount = len Bytes(lhs) + when hasRhs: doAssert len(Bytes(rhs)) == bytesCount + + var fullWordsCount = bytesCount div sizeof(WordType) + let lastWordSize = bytesCount mod sizeof(WordType) + + block: + var lhsWord: WordType + when hasRhs: + var rhsWord: WordType + var firstByteOfLastWord, lastByteOfLastWord: int + + # TODO: Returning a `var` value from an iterator is always safe due to + # the way inlining works, but currently the compiler reports an error + # when a local variable escapes. 
We work around it with this location + obfuscation through pointers: + template lhsBits: auto = (addr(lhsWord))[] + + when hasRhs: + template rhsBits: auto = (addr(rhsWord))[] + + template lastWordBytes(bitseq): auto = + Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord) + + template initLastWords = + lhsWord = loadLEBytes(WordType, lastWordBytes(lhs)) + when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs)) + + if lastWordSize == 0: + firstByteOfLastWord = bytesCount - sizeof(WordType) + lastByteOfLastWord = bytesCount - 1 + dec fullWordsCount + else: + firstByteOfLastWord = bytesCount - lastWordSize + lastByteOfLastWord = bytesCount - 1 + + initLastWords() + let markerPos = log2trunc(lhsWord) + when hasRhs: doAssert log2trunc(rhsWord) == markerPos + + lhsWord.clearBit markerPos + when hasRhs: rhsWord.clearBit markerPos + + body + + when lhsIsVar or rhsIsVar: + let + markerBit = uint(1 shl markerPos) + mask = markerBit - 1'u + + when lhsIsVar: + let lhsEndResult = (lhsWord and mask) or markerBit + storeLEBytes(lhsEndResult, lastWordBytes(lhs)) + + when rhsIsVar: + let rhsEndResult = (rhsWord and mask) or markerBit + storeLEBytes(rhsEndResult, lastWordBytes(rhs)) + + var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0]) + let lhsEndAddr = offset(lhsCurrAddr, fullWordsCount) + when hasRhs: + var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0]) + + while lhsCurrAddr < lhsEndAddr: + template lhsBits: auto = lhsCurrAddr[] + when hasRhs: + template rhsBits: auto = rhsCurrAddr[] + + body + + lhsCurrAddr = offset(lhsCurrAddr, 1) + when hasRhs: rhsCurrAddr = offset(rhsCurrAddr, 1) + +iterator words*(x: var BitSeq): var uint = + loopOverWords(x, x, true, false, uint, word, wordB): + yield word + +iterator words*(x: BitSeq): uint = + loopOverWords(x, x, false, false, uint, word, word): + yield word + +iterator words*(a, b: BitSeq): (uint, uint) = + loopOverWords(a, b, false, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) = + loopOverWords(a, b, true, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a, b: var BitSeq): (var uint, var uint) = + loopOverWords(a, b, true, true, uint, wordA, wordB): + yield (wordA, wordB) + +func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} = + doAssert pos < s.len + s.Bytes.getBit pos + +func `[]=`*(s: var BitSeq, pos: Natural, value: bool) {.inline.} = + doAssert pos < s.len + s.Bytes.changeBit pos, value + +func setBit*(s: var BitSeq, pos: Natural) {.inline.} = + doAssert pos < s.len + setBit s.Bytes, pos + +func clearBit*(s: var BitSeq, pos: Natural) {.inline.} = + doAssert pos < s.len + clearBit s.Bytes, pos + +func init*(T: type BitSeq, len: int): T = + result = BitSeq newSeq[byte](1 + len div 8) + Bytes(result).setBit len + +func init*(T: type BitArray): T = + # The default zero-initialization is fine + discard + +template `[]`*(a: BitArray, pos: Natural): bool = + getBit a.bytes, pos + +template `[]=`*(a: var BitArray, pos: Natural, value: bool) = + changeBit a.bytes, pos, value + +template setBit*(a: var BitArray, pos: Natural) = + setBit a.bytes, pos + +template clearBit*(a: var BitArray, pos: Natural) = + clearBit a.bytes, pos + +# TODO: Submit this to the standard library as `cmp` +# At the moment, it doesn't work quite well because Nim selects +# the generic cmp[T] from the system module instead of choosing +# the openArray overload +func compareArrays[T](a, b: openArray[T]): int = + result = cmp(a.len, b.len) + if result
!= 0: return + + for i in 0 ..< a.len: + result = cmp(a[i], b[i]) + if result != 0: return + +template cmp*(a, b: BitSeq): int = + compareArrays(Bytes a, Bytes b) + +template `==`*(a, b: BitSeq): bool = + cmp(a, b) == 0 + +func `$`*(a: BitSeq | BitArray): string = + let length = a.len + result = newStringOfCap(2 + length) + result.add "0b" + for i in countdown(length - 1, 0): + result.add if a[i]: '1' else: '0' + +func incl*(tgt: var BitSeq, src: BitSeq) = + # Update `tgt` to include the bits of `src`, as if applying `or` to each bit + doAssert tgt.len == src.len + for tgtWord, srcWord in words(tgt, src): + tgtWord = tgtWord or srcWord + +func overlaps*(a, b: BitSeq): bool = + for wa, wb in words(a, b): + if (wa and wb) != 0: + return true + +func countOverlap*(a, b: BitSeq): int = + var res = 0 + for wa, wb in words(a, b): + res += countOnes(wa and wb) + res + +func isSubsetOf*(a, b: BitSeq): bool = + let alen = a.len + doAssert b.len == alen + for i in 0 ..< alen: + if a[i] and not b[i]: + return false + true + +func isZeros*(x: BitSeq): bool = + for w in words(x): + if w != 0: return false + return true + +func isZeros*(x: BitArray): bool = + x == default(type(x)) + +func countOnes*(x: BitSeq): int = + # Count the number of set bits + var res = 0 + for w in words(x): + res += w.countOnes() + res + +func clear*(x: var BitSeq) = + for w in words(x): + w = 0 + +func countZeros*(x: BitSeq): int = + x.len() - x.countOnes() + +template bytes*(x: BitSeq): untyped = + seq[byte](x) + +iterator items*(x: BitArray): bool = + for i in 0 ..< x.bits: + yield x[i] diff --git a/ssz_serialization/codec.nim b/ssz_serialization/codec.nim new file mode 100644 --- /dev/null +++ b/ssz_serialization/codec.nim @@ -0,0 +1,254 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} +{.pragma: raisesssz, raises: [Defect, MalformedSszError, SszSizeMismatchError].} +func fromSszBytes*(T: type bool, data: openArray[byte]): T {.raisesssz.} = + if data.len != 1 or data[0] > byte(1): + raise newException(MalformedSszError, "invalid boolean value") + data[0] == 1 + +func fromSszBytes*(T: type Digest, data: openArray[byte]): T {.raisesssz.} = + if data.len != sizeof(result.data): + raiseIncorrectSize T + copyMem(result.data.addr, unsafeAddr data[0], sizeof(result.data)) + +template fromSszBytes*(T: type BitSeq, bytes: openArray[byte]): auto = + BitSeq @bytes + +proc `[]`[T, U, V](s: openArray[T], x: HSlice[U, V]) {.error: + "Please don't use openArray's [] as it allocates a result sequence".} + +template checkForForbiddenBits(ResulType: type, + input: openArray[byte], + expectedBits: static int64) = + ## This checks if the input contains any bits set above the maximum + ## size allowed. We only need to check the last byte to verify this: + const bitsInLastByte = (expectedBits mod 8) + when bitsInLastByte != 0: + # As an example, if there are 3 bits expected in the last byte, + # we calculate a bitmask equal to 11111000.
If the input has any + # raised bits in range of the bitmask, this would be a violation + # of the size of the BitArray: + const forbiddenBitsMask = byte(byte(0xff) shl bitsInLastByte) + + if (input[^1] and forbiddenBitsMask) != 0: + raiseIncorrectSize ResulType + +func readSszValue*[T](input: openArray[byte], + val: var T) {.raisesssz.} = + mixin fromSszBytes, toSszType + + template readOffsetUnchecked(n: int): uint32 {.used.}= + fromSszBytes(uint32, input.toOpenArray(n, n + offsetSize - 1)) + + template readOffset(n: int): int {.used.} = + let offset = readOffsetUnchecked(n) + if offset > input.len.uint32: + raise newException(MalformedSszError, "SSZ list element offset points past the end of the input") + int(offset) + + when val is BitList: + if input.len == 0: + raise newException(MalformedSszError, "Invalid empty SSZ BitList value") + + # Since our BitLists have an in-memory representation that precisely + # matches their SSZ encoding, we can deserialize them as regular Lists: + const maxExpectedSize = (val.maxLen div 8) + 1 + type MatchingListType = List[byte, maxExpectedSize] + + when false: + # TODO: Nim doesn't like this simple type coercion, + # we'll rely on `cast` for now (see below) + readSszValue(input, MatchingListType val) + else: + static: + # As a sanity check, we verify that the coercion is accepted by the compiler: + doAssert MatchingListType(val) is MatchingListType + readSszValue(input, cast[ptr MatchingListType](addr val)[]) + + let resultBytesCount = len bytes(val) + + if bytes(val)[resultBytesCount - 1] == 0: + raise newException(MalformedSszError, "SSZ BitList is not properly terminated") + + if resultBytesCount == maxExpectedSize: + checkForForbiddenBits(T, input, val.maxLen + 1) + + elif val is HashList | HashArray: + readSszValue(input, val.data) + val.resetCache() + + elif val is List|array: + type E = type val[0] + + when E is byte: + val.setOutputSize input.len + if input.len > 0: + copyMem(addr val[0], unsafeAddr input[0], input.len) + + elif isFixedSize(E): + const elemSize = fixedPortionSize(E) + if input.len mod elemSize != 0: + var ex = new SszSizeMismatchError + ex.deserializedType = cstring typetraits.name(T) + ex.actualSszSize = input.len + ex.elementSize = elemSize + raise ex + val.setOutputSize input.len div elemSize + for i in 0 ..< val.len: + let offset = i * elemSize + readSszValue(input.toOpenArray(offset, offset + elemSize - 1), val[i]) + + else: + if input.len == 0: + # This is an empty list. + # The default initialization of the return value is fine. 
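+        # (Worked example for the non-empty case below, matching the
+        # writer-side sketch in ssz_serialization.nim: given the 11-byte
+        # input 08 00 00 00  0a 00 00 00  aa bb  cc, readOffset(0) = 8,
+        # hence resultLen = 8 div offsetSize = 2, element 0 spans bytes
+        # 8..9 and element 1 spans bytes 10..10.)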
+ val.setOutputSize 0 + return + elif input.len < offsetSize: + raise newException(MalformedSszError, "SSZ input of insufficient size") + + var offset = readOffset 0 + let resultLen = offset div offsetSize + + if resultLen == 0: + # If there are too many elements, other constraints detect problems + # (not monotonically increasing, past end of input, or last element + # not matching up with its nextOffset properly) + raise newException(MalformedSszError, "SSZ list incorrectly encoded of zero length") + + val.setOutputSize resultLen + for i in 1 ..< resultLen: + let nextOffset = readOffset(i * offsetSize) + if nextOffset <= offset: + raise newException(MalformedSszError, "SSZ list element offsets are not monotonically increasing") + else: + readSszValue(input.toOpenArray(offset, nextOffset - 1), val[i - 1]) + offset = nextOffset + + readSszValue(input.toOpenArray(offset, input.len - 1), val[resultLen - 1]) + + elif val is SingleMemberUnion: + readSszValue(input.toOpenArray(0, 0), val.selector) + if val.selector != 0'u8: + raise newException(MalformedSszError, "SingleMemberUnion selector must be 0") + readSszValue(input.toOpenArray(1, input.len - 1), val.value) + + elif val is UintN|bool: + val = fromSszBytes(T, input) + + elif val is BitArray: + if sizeof(val) != input.len: + raiseIncorrectSize(T) + checkForForbiddenBits(T, input, val.bits) + copyMem(addr val.bytes[0], unsafeAddr input[0], input.len) + + elif val is object|tuple: + let inputLen = uint32 input.len + const minimallyExpectedSize = uint32 fixedPortionSize(T) + + if inputLen < minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ input of insufficient size") + + enumInstanceSerializedFields(val, fieldName, field): + const boundingOffsets = getFieldBoundingOffsets(T, fieldName) + + # type FieldType = type field # buggy + # For some reason, Nim gets confused about the alias here. This could be a + # generics caching issue caused by the use of distinct types. Such an + # issue is very scary in general. + # The bug can be seen with the two List[uint64, N] types that exist in + # the spec, with different N. 
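+      # Bounding-offset sketch (hand-worked; `Example` is hypothetical):
+      #   type Example = object
+      #     a: uint64          # fixed, occupies bytes 0..7
+      #     b: List[byte, 16]  # variable, bytes 8..11 hold its offset
+      # fixedPortionSize(Example) is 12, so a valid encoding of
+      # Example(a: 7, b: [0xaa, 0xbb]) reads
+      #   07 00 00 00 00 00 00 00  0c 00 00 00  aa bb
+      # and the checks below insist that the first offset equals 12.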
+ + type SszType = type toSszType(declval type(field)) + + when isFixedSize(SszType): + const + startOffset = boundingOffsets[0] + endOffset = boundingOffsets[1] + else: + let + startOffset = readOffsetUnchecked(boundingOffsets[0]) + endOffset = if boundingOffsets[1] == -1: inputLen + else: readOffsetUnchecked(boundingOffsets[1]) + + when boundingOffsets.isFirstOffset: + if startOffset != minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ object dynamic portion starts at invalid offset") + + if startOffset > endOffset: + raise newException(MalformedSszError, "SSZ field offsets are not monotonically increasing") + elif endOffset > inputLen: + raise newException(MalformedSszError, "SSZ field offset points past the end of the input") + elif startOffset < minimallyExpectedSize: + raise newException(MalformedSszError, "SSZ field offset points outside bounding offsets") + + # TODO The extra type escaping here is a work-around for a Nim issue: + when type(field) is type(SszType): + readSszValue( + input.toOpenArray(int(startOffset), int(endOffset - 1)), + field) + else: + field = fromSszBytes( + type(field), + input.toOpenArray(int(startOffset), int(endOffset - 1))) + + else: + unsupported T + +# Identity conversions for core SSZ types + +template toSszType*(v: auto): auto = + ## toSszType converts a given value into one of the primitive types supported + ## by SSZ - to add support for a custom type (for example a `distinct` type), + ## add an overload for `toSszType` which converts it to one of the `SszType` + ## types, as well as a `fromSszBytes`. + type T = type(v) + when T is SszType: + when T is Digest: + v.data + else: + v + else: + unsupported T diff --git a/ssz_serialization/dynamic_navigator.nim b/ssz_serialization/dynamic_navigator.nim new file mode 100644 index 0000000..8b13326 --- /dev/null +++ b/ssz_serialization/dynamic_navigator.nim @@ -0,0 +1,163 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} +{.pragma: raisesssz, raises: [Defect, IOError, MalformedSszError, SszSizeMismatchError].} + +import + std/[strutils, parseutils], + stew/objects, faststreams/outputs, json_serialization/writer, + ./codec, ./types, ./navigator + +export + codec, navigator, types + +type + ObjKind = enum + Record + Indexable + LeafValue + + FieldInfo = ref object + name: string + fieldType: TypeInfo + navigator: proc (m: MemRange): MemRange {. gcsafe + noSideEffect + raisesssz } + TypeInfo = ref object + case kind: ObjKind + of Record: + fields: seq[FieldInfo] + of Indexable: + elemType: TypeInfo + navigator: proc (m: MemRange, idx: int): MemRange {. 
gcsafe + noSideEffect + raisesssz } + else: + discard + + jsonPrinter: proc (m: MemRange, + outStream: OutputStream, + pretty: bool) {.gcsafe, raisesssz.} + + DynamicSszNavigator* = object + m: MemRange + typ: TypeInfo + +proc jsonPrinterImpl[T](m: MemRange, outStream: OutputStream, pretty: bool) {.raisesssz.} = + var typedNavigator = sszMount(m, T) + var jsonWriter = Json.Writer.init(outStream, pretty) + # TODO: it should be possible to serialize the navigator object + # without dereferencing it (to avoid the intermediate value). + writeValue(jsonWriter, typedNavigator[]) + +func findField(fields: seq[FieldInfo], name: string): FieldInfo = + # TODO: Replace this with a binary search? + # Will it buy us anything when there are only few fields? + for field in fields: + if field.name == name: + return field + +func indexableNavigatorImpl[T](m: MemRange, idx: int): MemRange {.raisesssz.} = + var typedNavigator = sszMount(m, T) + getMemRange(typedNavigator[idx]) + +func fieldNavigatorImpl[RecordType; FieldType; + fieldName: static string](m: MemRange): MemRange {.raisesssz.} = + # TODO: Make sure this doesn't fail with a Defect when + # navigating to an inactive field in a case object. + var typedNavigator = sszMount(m, RecordType) + getMemRange navigateToField(typedNavigator, fieldName, FieldType) + +func genTypeInfo(T: type): TypeInfo {.gcsafe.} + +proc typeInfo*(T: type): TypeInfo = + let res {.global.} = genTypeInfo(T) + + # TODO This will be safer if the RTTI object use only manually + # managed memory, but the `fields` sequence right now make + # things harder. We'll need to switch to a different seq type. + {.gcsafe, noSideEffect.}: res + +func genTypeInfo(T: type): TypeInfo = + mixin toSszType, enumAllSerializedFields + type SszType = type toSszType(declval T) + result = when type(SszType) isnot T: + TypeInfo(kind: LeafValue) + elif T is object: + var fields: seq[FieldInfo] + enumAllSerializedFields(T): + fields.add FieldInfo(name: fieldName, + fieldType: typeInfo(FieldType), + navigator: fieldNavigatorImpl[T, FieldType, fieldName]) + TypeInfo(kind: Record, fields: fields) + elif T is seq|array: + TypeInfo(kind: Indexable, + elemType: typeInfo(ElemType(T)), + navigator: indexableNavigatorImpl[T]) + else: + TypeInfo(kind: LeafValue) + + result.jsonPrinter = jsonPrinterImpl[T] + +func `[]`*(n: DynamicSszNavigator, idx: int): DynamicSszNavigator {.raisesssz.} = + doAssert n.typ.kind == Indexable + DynamicSszNavigator(m: n.typ.navigator(n.m, idx), typ: n.typ.elemType) + +func navigate*(n: DynamicSszNavigator, path: string): DynamicSszNavigator {. 
+ raises: [Defect, KeyError, IOError, MalformedSszError, SszSizeMismatchError, ValueError] .} = + case n.typ.kind + of Record: + let fieldInfo = n.typ.fields.findField(path) + if fieldInfo == nil: + raise newException(KeyError, "Unrecognized field name: " & path) + return DynamicSszNavigator(m: fieldInfo.navigator(n.m), + typ: fieldInfo.fieldType) + of Indexable: + var idx: int + let consumed = parseInt(path, idx) + if consumed == 0 or idx < 0: + raise newException(KeyError, "Indexing should be done with natural numbers") + return n[idx] + else: + doAssert false, "Navigation should be terminated once you reach a leaf value" + +template navigatePathImpl(nav, iterablePathFragments: untyped) = + result = nav + for pathFragment in iterablePathFragments: + if pathFragment.len == 0: + continue + result = result.navigate(pathFragment) + if result.typ.kind == LeafValue: + return + +func navigatePath*(n: DynamicSszNavigator, path: string): DynamicSszNavigator {. + raises: [Defect, IOError, ValueError, MalformedSszError, SszSizeMismatchError] .} = + navigatePathImpl n, split(path, '/') + +func navigatePath*(n: DynamicSszNavigator, path: openArray[string]): DynamicSszNavigator {. + raises: [Defect, IOError, ValueError, MalformedSszError, SszSizeMismatchError] .} = + navigatePathImpl n, path + +func init*(T: type DynamicSszNavigator, + bytes: openArray[byte], Navigated: type): T = + T(m: MemRange(startAddr: unsafeAddr bytes[0], length: bytes.len), + typ: typeInfo(Navigated)) + +proc writeJson*(n: DynamicSszNavigator, outStream: OutputStream, pretty = true) {.raisesssz.} = + n.typ.jsonPrinter(n.m, outStream, pretty) + +func toJson*(n: DynamicSszNavigator, pretty = true): string {.raisesssz.} = + var outStream = memoryOutput() + {.noSideEffect.}: + # We are assuming that there are no side-effects here + # because we are using a `memoryOutput`. The computed + # side-effects are coming from the fact that the dynamic + # dispatch mechanisms used in faststreams may be reading + # from a file or a network device. + writeJson(n, outStream, pretty) + outStream.getOutput(string) diff --git a/ssz_serialization/merkleization.nim b/ssz_serialization/merkleization.nim new file mode 100644 index 0000000..9cf220c --- /dev/null +++ b/ssz_serialization/merkleization.nim @@ -0,0 +1,734 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms.
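A usage sketch for the dynamic navigator defined in the file above, before moving on to merkleization. `Pair` is a made-up type; `init`, `navigatePath` and `toJson` are the procs declared in dynamic_navigator.nim. Note that the navigator stores only a pointer into `bytes`, so the buffer must outlive it:

```nim
import ssz_serialization, ssz_serialization/dynamic_navigator

type Pair = object
  x: uint64
  y: uint64

let bytes = SSZ.encode(Pair(x: 1, y: 2))
let nav = DynamicSszNavigator.init(bytes, Pair)
echo nav.navigatePath("y").toJson(pretty = false)  # prints: 2
```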
+ +# This module contains the parts necessary to create a merkle hash from the core +# SSZ types outlined in the spec: +# https://github.com/ethereum/consensus-specs/blob/v1.0.1/ssz/simple-serialize.md#merkleization + +{.push raises: [Defect].} + +import + stew/[bitops2, endians2, ptrops], + stew/ranges/ptr_arith, nimcrypto/[hash, sha2], + serialization/testing/tracing, + "."/[bitseqs, codec, types] + +const PREFER_BLST_SHA256* {.booldefine.} = true + +when PREFER_BLST_SHA256: + import blscurve + when BLS_BACKEND == BLST: + const USE_BLST_SHA256 = true + else: + const USE_BLST_SHA256 = false +else: + const USE_BLST_SHA256 = false + +export + codec, bitseqs, types + +when hasSerializationTracing: + import stew/byteutils, typetraits + +const + zero64 = default array[64, byte] + bitsPerChunk = bytesPerChunk * 8 + +func binaryTreeHeight*(totalElements: Limit): int = + bitWidth nextPow2(uint64 totalElements) + +type + # TODO Figure out what would be the right type for this. + # It probably fits in uint16 for all practical purposes. + GeneralizedIndex* = uint32 + + SszMerkleizerImpl = object + combinedChunks: ptr UncheckedArray[Digest] + totalChunks: uint64 + topIndex: int + + SszMerkleizer*[limit: static[Limit]] = object + combinedChunks: ref array[binaryTreeHeight limit, Digest] + impl: SszMerkleizerImpl + +template chunks*(m: SszMerkleizerImpl): openArray[Digest] = + m.combinedChunks.toOpenArray(0, m.topIndex) + +template getChunkCount*(m: SszMerkleizer): uint64 = + m.impl.totalChunks + +template getCombinedChunks*(m: SszMerkleizer): openArray[Digest] = + toOpenArray(m.impl.combinedChunks, 0, m.impl.topIndex) + +when USE_BLST_SHA256: + export blscurve.update + type DigestCtx* = BLST_SHA256_CTX +else: + type DigestCtx* = sha2.sha256 + +template computeDigest*(body: untyped): Digest = + ## This little helper will init the hash function and return the sliced + ## hash: + ## let hashOfData = withHash: h.update(data) + when nimvm: + # In SSZ, computeZeroHashes require compile-time SHA256 + block: + var h {.inject.}: sha256 + init(h) + body + finish(h) + else: + when USE_BLST_SHA256: + block: + var h {.inject, noInit.}: DigestCtx + init(h) + body + var res {.noInit.}: Digest + finalize(res.data, h) + res + else: + block: + var h {.inject, noInit.}: DigestCtx + init(h) + body + finish(h) + +func digest(a, b: openArray[byte]): Digest = + result = computeDigest: + trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b) + trs toHex(a) + trs toHex(b) + + h.update a + h.update b + trs "HASH RESULT ", result + +func digest(a, b, c: openArray[byte]): Digest = + result = computeDigest: + trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b), " ", toHex(c) + + h.update a + h.update b + h.update c + trs "HASH RESULT ", result + +func mergeBranches(existing: Digest, newData: openArray[byte]): Digest = + trs "MERGING BRANCHES OPEN ARRAY" + + let paddingBytes = bytesPerChunk - newData.len + digest(existing.data, newData, zero64.toOpenArray(0, paddingBytes - 1)) + +template mergeBranches(existing: Digest, newData: array[32, byte]): Digest = + trs "MERGING BRANCHES ARRAY" + digest(existing.data, newData) + +template mergeBranches(a, b: Digest): Digest = + trs "MERGING BRANCHES DIGEST" + digest(a.data, b.data) + +func computeZeroHashes: array[sizeof(Limit) * 8, Digest] = + result[0] = Digest() + for i in 1 .. 
result.high: + result[i] = mergeBranches(result[i - 1], result[i - 1]) + +const zeroHashes* = computeZeroHashes() + +func addChunk*(merkleizer: var SszMerkleizerImpl, data: openArray[byte]) = + doAssert data.len > 0 and data.len <= bytesPerChunk + + if getBitLE(merkleizer.totalChunks, 0): + var hash = mergeBranches(merkleizer.combinedChunks[0], data) + + for i in 1 .. merkleizer.topIndex: + trs "ITERATING" + if getBitLE(merkleizer.totalChunks, i): + trs "CALLING MERGE BRANCHES" + hash = mergeBranches(merkleizer.combinedChunks[i], hash) + else: + trs "WRITING FRESH CHUNK AT ", i, " = ", hash + merkleizer.combinedChunks[i] = hash + break + else: + let paddingBytes = bytesPerChunk - data.len + + merkleizer.combinedChunks[0].data[0.. 0 and merkleizer.topIndex > 0 + + let proofHeight = merkleizer.topIndex + 1 + result = newSeq[Digest](chunks.len * proofHeight) + + if chunks.len == 1: + merkleizer.addChunkAndGenMerkleProof(chunks[0], result) + return + + let newTotalChunks = merkleizer.totalChunks + chunks.len.uint64 + + var + # A perfect binary tree will take either `chunks.len * 2` values if the + # number of elements in the base layer is odd and `chunks.len * 2 - 1` + # otherwise. Each row may also need a single extra element at most if + # it must be combined with the existing values in the Merkleizer: + merkleTree = newSeqOfCap[Digest](chunks.len + merkleizer.topIndex) + inRowIdx = merkleizer.totalChunks + postUpdateInRowIdx = newTotalChunks + zeroMixed = false + + template writeResult(chunkIdx, level: int, chunk: Digest) = + result[chunkIdx * proofHeight + level] = chunk + + # We'll start by generating the first row of the merkle tree. + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. + writeResult(0, 0, merkleizer.combinedChunks[0]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[0], chunks[0]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[0] = chunks[^1] + + while currPairEnd < chunks.len: + writeResult(currPairEnd - 1, 0, chunks[currPairEnd]) + writeResult(currPairEnd, 0, chunks[currPairEnd - 1]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + chunks[currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < chunks.len: + zeroMixed = true + writeResult(currPairEnd - 1, 0, zeroHashes[0]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + zeroHashes[0]) + var + level = 0 + baseChunksPerElement = 1 + treeRowStart = 0 + rowLen = merkleTree.len + + template writeProofs(rowChunkIdx: int, hash: Digest) = + let + startAbsIdx = (inRowIdx.int + rowChunkIdx) * baseChunksPerElement + endAbsIdx = startAbsIdx + baseChunksPerElement + startResIdx = max(startAbsIdx - merkleizer.totalChunks.int, 0) + endResIdx = min(endAbsIdx - merkleizer.totalChunks.int, chunks.len) + + for resultPos in startResIdx ..< endResIdx: + writeResult(resultPos, level, hash) + + if rowLen > 1: + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + postUpdateInRowIdx = postUpdateInRowIdx div 2 + + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. 
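+            # (Hand-checked example of the invariant used here: after feeding
+            # two zero chunks to a createMerkleizer(4) instance,
+            # combinedChunks[1] == mergeBranches(zeroHashes[0], zeroHashes[0])
+            # == zeroHashes[1] - exactly the pending sibling hash that the
+            # next odd-indexed chunk will be merged with.)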
+ writeProofs(0, merkleizer.combinedChunks[level]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[level], + merkleTree[treeRowStart]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], level + 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[level] = merkleTree[treeRowStart + rowLen - + ord(zeroMixed) - 1] + while currPairEnd < rowLen: + writeProofs(currPairEnd - 1, merkleTree[treeRowStart + currPairEnd]) + writeProofs(currPairEnd, merkleTree[treeRowStart + currPairEnd - 1]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + merkleTree[treeRowStart + currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < rowLen: + zeroMixed = true + writeProofs(currPairEnd - 1, zeroHashes[level]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + zeroHashes[level]) + + treeRowStart += rowLen + rowLen = merkleTree.len - treeRowStart + + if rowLen == 1: + break + + doAssert rowLen == 1 + + if (inRowIdx and 2) != 0: + merkleizer.completeStartedChunk( + mergeBranches(merkleizer.combinedChunks[level + 1], merkleTree[^1]), + level + 2) + + if (not zeroMixed) and (postUpdateInRowIdx and 2) != 0: + merkleizer.combinedChunks[level + 1] = merkleTree[^1] + + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + + let hash = if getBitLE(merkleizer.totalChunks, level): + merkleizer.combinedChunks[level] + else: + zeroHashes[level] + + writeProofs(0, hash) + + merkleizer.totalChunks = newTotalChunks + +proc init*(S: type SszMerkleizer): S = + new result.combinedChunks + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(result.limit) - 1, + totalChunks: 0) + +proc init*(S: type SszMerkleizer, + combinedChunks: openArray[Digest], + totalChunks: uint64): S = + new result.combinedChunks + result.combinedChunks[][0 ..< combinedChunks.len] = combinedChunks + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(result.limit) - 1, + totalChunks: totalChunks) + +proc copy*[L: static[Limit]](cloned: SszMerkleizer[L]): SszMerkleizer[L] = + new result.combinedChunks + result.combinedChunks[] = cloned.combinedChunks[] + result.impl = SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]]( + addr result.combinedChunks[][0]), + topIndex: binaryTreeHeight(L) - 1, + totalChunks: cloned.totalChunks) + +template addChunksAndGenMerkleProofs*( + merkleizer: var SszMerkleizer, + chunks: openArray[Digest]): seq[Digest] = + addChunksAndGenMerkleProofs(merkleizer.impl, chunks) + +template addChunk*(merkleizer: var SszMerkleizer, data: openArray[byte]) = + addChunk(merkleizer.impl, data) + +template totalChunks*(merkleizer: SszMerkleizer): uint64 = + merkleizer.impl.totalChunks + +template getFinalHash*(merkleizer: SszMerkleizer): Digest = + merkleizer.impl.getFinalHash + +template createMerkleizer*(totalElements: static Limit): SszMerkleizerImpl = + trs "CREATING A MERKLEIZER FOR ", totalElements + + const treeHeight = binaryTreeHeight totalElements + var combinedChunks {.noInit.}: array[treeHeight, Digest] + + SszMerkleizerImpl( + combinedChunks: cast[ptr UncheckedArray[Digest]](addr combinedChunks), + topIndex: treeHeight - 1, + totalChunks: 0) + +func getFinalHash*(merkleizer: SszMerkleizerImpl): Digest = + if merkleizer.totalChunks == 0: + 
return zeroHashes[merkleizer.topIndex] + + let + bottomHashIdx = firstOne(merkleizer.totalChunks) - 1 + submittedChunksHeight = bitWidth(merkleizer.totalChunks - 1) + topHashIdx = merkleizer.topIndex + + trs "BOTTOM HASH ", bottomHashIdx + trs "SUBMITTED HEIGHT ", submittedChunksHeight + trs "TOP HASH IDX ", topHashIdx + + if bottomHashIdx != submittedChunksHeight: + # Our tree is not finished. We must complete the work in progress + # branches and then extend the tree to the right height. + result = mergeBranches(merkleizer.combinedChunks[bottomHashIdx], + zeroHashes[bottomHashIdx]) + + for i in bottomHashIdx + 1 ..< topHashIdx: + if getBitLE(merkleizer.totalChunks, i): + result = mergeBranches(merkleizer.combinedChunks[i], result) + trs "COMBINED" + else: + result = mergeBranches(result, zeroHashes[i]) + trs "COMBINED WITH ZERO" + + elif bottomHashIdx == topHashIdx: + # We have a perfect tree (chunks == 2**n) at just the right height! + result = merkleizer.combinedChunks[bottomHashIdx] + else: + # We have a perfect tree of user chunks, but we have more work to + # do - we must extend it to reach the desired height + result = mergeBranches(merkleizer.combinedChunks[bottomHashIdx], + zeroHashes[bottomHashIdx]) + + for i in bottomHashIdx + 1 ..< topHashIdx: + result = mergeBranches(result, zeroHashes[i]) + +func mixInLength*(root: Digest, length: int): Digest = + var dataLen: array[32, byte] + dataLen[0..<8] = uint64(length).toBytesLE() + mergeBranches(root, dataLen) + +func hash_tree_root*(x: auto): Digest {.gcsafe, raises: [Defect].} + +template merkleizeFields(totalElements: static Limit, body: untyped): Digest = + var merkleizer {.inject.} = createMerkleizer(totalElements) + + template addField(field) = + let hash = hash_tree_root(field) + trs "MERKLEIZING FIELD ", astToStr(field), " = ", hash + addChunk(merkleizer, hash.data) + trs "CHUNK ADDED" + + body + + getFinalHash(merkleizer) + +template writeBytesLE(chunk: var array[bytesPerChunk, byte], atParam: int, + val: UintN) = + let at = atParam + chunk[at ..< at + sizeof(val)] = toBytesLE(val) + +func chunkedHashTreeRootForBasicTypes[T](merkleizer: var SszMerkleizerImpl, + arr: openArray[T]): Digest = + static: + doAssert T is BasicType + doAssert bytesPerChunk mod sizeof(T) == 0 + + if arr.len == 0: + return getFinalHash(merkleizer) + + when sizeof(T) == 1 or cpuEndian == littleEndian: + var + remainingBytes = when sizeof(T) == 1: arr.len + else: arr.len * sizeof(T) + pos = cast[ptr byte](unsafeAddr arr[0]) + + while remainingBytes >= bytesPerChunk: + merkleizer.addChunk(makeOpenArray(pos, bytesPerChunk)) + pos = offset(pos, bytesPerChunk) + remainingBytes -= bytesPerChunk + + if remainingBytes > 0: + merkleizer.addChunk(makeOpenArray(pos, remainingBytes)) + + else: + const valuesPerChunk = bytesPerChunk div sizeof(T) + + var writtenValues = 0 + + var chunk: array[bytesPerChunk, byte] + while writtenValues < arr.len - valuesPerChunk: + for i in 0 ..< valuesPerChunk: + chunk.writeBytesLE(i * sizeof(T), arr[writtenValues + i]) + merkleizer.addChunk chunk + inc writtenValues, valuesPerChunk + + let remainingValues = arr.len - writtenValues + if remainingValues > 0: + var lastChunk: array[bytesPerChunk, byte] + for i in 0 ..< remainingValues: + lastChunk.writeBytesLE(i * sizeof(T), arr[writtenValues + i]) + merkleizer.addChunk lastChunk + + getFinalHash(merkleizer) + +func bitListHashTreeRoot(merkleizer: var SszMerkleizerImpl, x: BitSeq): Digest = + # TODO: Switch to a simpler BitList representation and + # replace this with 
`chunkedHashTreeRoot` + var + totalBytes = bytes(x).len + lastCorrectedByte = bytes(x)[^1] + + if lastCorrectedByte == byte(1): + if totalBytes == 1: + # This is an empty bit list. + # It should be hashed as a tree containing all zeros: + return mergeBranches(zeroHashes[merkleizer.topIndex], + zeroHashes[0]) # this is the mixed length + + totalBytes -= 1 + lastCorrectedByte = bytes(x)[^2] + else: + let markerPos = log2trunc(lastCorrectedByte) + lastCorrectedByte.clearBit(markerPos) + + var + bytesInLastChunk = totalBytes mod bytesPerChunk + fullChunks = totalBytes div bytesPerChunk + + if bytesInLastChunk == 0: + fullChunks -= 1 + bytesInLastChunk = 32 + + for i in 0 ..< fullChunks: + let + chunkStartPos = i * bytesPerChunk + chunkEndPos = chunkStartPos + bytesPerChunk - 1 + + merkleizer.addChunk bytes(x).toOpenArray(chunkStartPos, chunkEndPos) + + var + lastChunk: array[bytesPerChunk, byte] + chunkStartPos = fullChunks * bytesPerChunk + + for i in 0 .. bytesInLastChunk - 2: + lastChunk[i] = bytes(x)[chunkStartPos + i] + + lastChunk[bytesInLastChunk - 1] = lastCorrectedByte + + merkleizer.addChunk lastChunk.toOpenArray(0, bytesInLastChunk - 1) + let contentsHash = merkleizer.getFinalHash + mixInLength contentsHash, x.len + +func maxChunksCount(T: type, maxLen: Limit): Limit = + when T is BitList|BitArray: + (maxLen + bitsPerChunk - 1) div bitsPerChunk + elif T is array|List: + maxChunkIdx(ElemType(T), maxLen) + else: + unsupported T # This should never happen + +func hashTreeRootAux[T](x: T): Digest = + when T is bool|char: + result.data[0] = byte(x) + elif T is UintN: + when cpuEndian == bigEndian: + result.data[0..= byteLen: + zeroHashes[1] + else: + let + nbytes = min(byteLen - byteIdx, 64) + padding = 64 - nbytes + + digest( + toOpenArray(bytes, int(byteIdx), int(byteIdx + nbytes - 1)), + toOpenArray(zero64, 0, int(padding - 1))) + else: + if chunkIdx + 1 > x.data.len(): + zeroHashes[x.maxDepth] + elif chunkIdx + 1 == x.data.len(): + mergeBranches( + hash_tree_root(x.data[chunkIdx]), + Digest()) + else: + mergeBranches( + hash_tree_root(x.data[chunkIdx]), + hash_tree_root(x.data[chunkIdx + 1])) + +template mergedHash(x: HashList|HashArray, vIdxParam: int64): Digest = + # The merged hash of the data at `vIdx` and `vIdx + 1` + let vIdx = vIdxParam + if vIdx >= x.maxChunks: + let dataIdx = vIdx - x.maxChunks + mergedDataHash(x, dataIdx) + else: + mergeBranches( + hashTreeRootCached(x, vIdx), + hashTreeRootCached(x, vIdx + 1)) + +func hashTreeRootCached*(x: HashList, vIdx: int64): Digest = + doAssert vIdx >= 1, "Only valid for flat merkle tree indices" + + let + layer = layer(vIdx) + idxInLayer = vIdx - (1'i64 shl layer) + layerIdx = idxInlayer + x.indices[layer] + + trs "GETTING ", vIdx, " ", layerIdx, " ", layer, " ", x.indices.len + + doAssert layer < x.maxDepth + if layerIdx >= x.indices[layer + 1]: + trs "ZERO ", x.indices[layer], " ", x.indices[layer + 1] + zeroHashes[x.maxDepth - layer] + else: + if not isCached(x.hashes[layerIdx]): + # TODO oops. so much for maintaining non-mutability. + let px = unsafeAddr x + + trs "REFRESHING ", vIdx, " ", layerIdx, " ", layer + + px[].hashes[layerIdx] = mergedHash(x, vIdx * 2) + else: + trs "CACHED ", layerIdx + + x.hashes[layerIdx] + +func hashTreeRootCached*(x: HashArray, vIdx: int): Digest = + doAssert vIdx >= 1, "Only valid for flat merkle tree indices" + + if not isCached(x.hashes[vIdx]): + # TODO oops. so much for maintaining non-mutability. 
+ let px = unsafeAddr x + + px[].hashes[vIdx] = mergedHash(x, vIdx * 2) + + return x.hashes[vIdx] + +func hashTreeRootCached*(x: HashArray): Digest = + hashTreeRootCached(x, 1) # Array does not use idx 0 + +func hashTreeRootCached*(x: HashList): Digest = + if x.data.len == 0: + mergeBranches( + zeroHashes[x.maxDepth], + zeroHashes[0]) # mixInLength with 0! + else: + if not isCached(x.hashes[0]): + # TODO oops. so much for maintaining non-mutability. + let px = unsafeAddr x + px[].hashes[0] = mixInLength(hashTreeRootCached(x, 1), x.data.len) + + x.hashes[0] + +func hash_tree_root*(x: auto): Digest = + trs "STARTING HASH TREE ROOT FOR TYPE ", name(type(x)) + mixin toSszType + + result = + when x is HashArray|HashList: + hashTreeRootCached(x) + elif x is List|BitList: + hashTreeRootList(x) + else: + hashTreeRootAux toSszType(x) + + trs "HASH TREE ROOT FOR ", name(type x), " = ", "0x", $result diff --git a/ssz_serialization/navigator.nim b/ssz_serialization/navigator.nim new file mode 100644 index 0000000..414afdd --- /dev/null +++ b/ssz_serialization/navigator.nim @@ -0,0 +1,143 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} +{.pragma: raisesssz, raises: [Defect, MalformedSszError, SszSizeMismatchError].} + +import + stew/[ptrops, objects], stew/ranges/ptr_arith, + ./codec, ./types + +export codec, types + +type + MemRange* = object + startAddr*: ptr byte + length*: int + + SszNavigator*[T] = object + m: MemRange + +func sszMount*(data: openArray[byte], T: type): SszNavigator[T] = + let startAddr = unsafeAddr data[0] + SszNavigator[T](m: MemRange(startAddr: startAddr, length: data.len)) + +func sszMount*(data: openArray[char], T: type): SszNavigator[T] = + let startAddr = cast[ptr byte](unsafeAddr data[0]) + SszNavigator[T](m: MemRange(startAddr: startAddr, length: data.len)) + +template sszMount*(data: MemRange, T: type): SszNavigator[T] = + SszNavigator[T](m: data) + +template getMemRange*(n: SszNavigator): MemRange = + # Please note that this accessor was created intentionally. + # We don't want to expose the `m` field, because the navigated + # type may have a field by that name. We want any dot-field + # access to be redirected to the navigated type. + # For this reason, this template should always be used with + # the function call syntax `getMemRange(n)`.
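+  # (Usage sketch for the navigator as a whole, with a hypothetical
+  # object type `T` that has a field `foo`:
+  #   sszMount(bytes, T).foo[]
+  # decodes just that one field without deserializing the rest.)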
+ n.m + +template checkBounds(m: MemRange, offset: int) = + if offset > m.length: + raise newException(MalformedSszError, "Malformed SSZ") + +template toOpenArray(m: MemRange): auto = + makeOpenArray(m.startAddr, m.length) + +func navigateToField*[T](n: SszNavigator[T], + fieldName: static string, + FieldType: type): SszNavigator[FieldType] {.raisesssz.} = + mixin toSszType + type SszFieldType = type toSszType(declval FieldType) + + const boundingOffsets = getFieldBoundingOffsets(T, fieldName) + checkBounds(n.m, boundingOffsets[1]) + + when isFixedSize(SszFieldType): + SszNavigator[FieldType](m: MemRange( + startAddr: offset(n.m.startAddr, boundingOffsets[0]), + length: boundingOffsets[1] - boundingOffsets[0])) + else: + template readOffset(off): int = + int fromSszBytes(uint32, makeOpenArray(offset(n.m.startAddr, off), + sizeof(uint32))) + let + startOffset = readOffset boundingOffsets[0] + endOffset = when boundingOffsets[1] == -1: n.m.length + else: readOffset boundingOffsets[1] + + if endOffset < startOffset or endOffset > n.m.length: + raise newException(MalformedSszError, "Incorrect offset values") + + SszNavigator[FieldType](m: MemRange( + startAddr: offset(n.m.startAddr, startOffset), + length: endOffset - startOffset)) + +template `.`*[T](n: SszNavigator[T], field: untyped): auto = + type RecType = T + type FieldType = type(default(RecType).field) + navigateToField(n, astToStr(field), FieldType) + +func indexVarSizeList(m: MemRange, idx: int): MemRange {.raisesssz.} = + template readOffset(pos): int = + int fromSszBytes(uint32, makeOpenArray(offset(m.startAddr, pos), offsetSize)) + + let offsetPos = offsetSize * idx + checkBounds(m, offsetPos + offsetSize) + + let firstOffset = readOffset 0 + let listLen = firstOffset div offsetSize + + if idx >= listLen: + # TODO: Use a RangeError here? + # This would require the user to check the `len` upfront + raise newException(MalformedSszError, "Indexing past the end") + + let elemPos = readOffset offsetPos + checkBounds(m, elemPos) + + let endPos = if idx < listLen - 1: + let nextOffsetPos = offsetPos + offsetSize + # TODO. Is there a way to remove this bounds check? 
+ checkBounds(m, nextOffsetPos + offsetSize) + readOffset(offsetPos + nextOffsetPos) + else: + m.length + + MemRange(startAddr: m.startAddr.offset(elemPos), length: endPos - elemPos) + +template indexList(n, idx, T: untyped): untyped = + type R = T + mixin toSszType + type ElemType = type toSszType(declval R) + when isFixedSize(ElemType): + const elemSize = fixedPortionSize(ElemType) + let elemPos = idx * elemSize + checkBounds(n.m, elemPos + elemSize) + SszNavigator[R](m: MemRange(startAddr: offset(n.m.startAddr, elemPos), + length: elemSize)) + else: + SszNavigator[R](m: indexVarSizeList(n.m, idx)) + +template `[]`*[T](n: SszNavigator[seq[T]], idx: int): SszNavigator[T] = + indexList n, idx, T + +template `[]`*[R, T](n: SszNavigator[array[R, T]], idx: int): SszNavigator[T] = + indexList(n, idx, T) + +func `[]`*[T](n: SszNavigator[T]): T {.raisesssz.} = + mixin toSszType, fromSszBytes + type SszRepr = type toSszType(declval T) + when type(SszRepr) is type(T) or T is List: + readSszValue(toOpenArray(n.m), result) + else: + fromSszBytes(T, toOpenArray(n.m)) + +converter derefNavigator*[T](n: SszNavigator[T]): T {.raisesssz.} = + n[] + diff --git a/ssz_serialization/types.nim b/ssz_serialization/types.nim new file mode 100644 index 0000000..0e2dc47 --- /dev/null +++ b/ssz_serialization/types.nim @@ -0,0 +1,570 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.push raises: [Defect].} + +import + std/[tables, typetraits, strformat], + stew/shims/macros, stew/[byteutils, bitops2, objects], stint, nimcrypto/hash, + serialization/[object_serialization, errors], + json_serialization, + "."/[bitseqs] + +export stint, bitseqs, json_serialization + +const + offsetSize* = 4 + bytesPerChunk* = 32 + +type + UintN* = SomeUnsignedInt|UInt128|UInt256 + BasicType* = bool|UintN + + Limit* = int64 + + Digest* = MDigest[32 * 8] + +# A few index types from here onwards: +# * dataIdx - leaf index starting from 0 to maximum length of collection +# * chunkIdx - leaf data index after chunking starting from 0 +# * vIdx - virtual index in merkle tree - the root is found at index 1, its +# two children at 2, 3 then 4, 5, 6, 7 etc + +func nextPow2Int64(x: int64): int64 = + # TODO the nextPow2 in bitops2 works with uint64 - there's a bug in the nim + # compiler preventing it to be used - it seems that a conversion to + # uint64 cannot be done with the static maxLen :( + var v = x - 1 + + # round down, make sure all bits are 1 below the threshold, then add 1 + v = v or v shr 1 + v = v or v shr 2 + v = v or v shr 4 + when bitsof(x) > 8: + v = v or v shr 8 + when bitsof(x) > 16: + v = v or v shr 16 + when bitsof(x) > 32: + v = v or v shr 32 + + v + 1 + +template dataPerChunk(T: type): int = + # How many data items fit in a chunk + when T is BasicType: + bytesPerChunk div sizeof(T) + else: + 1 + +template chunkIdx*(T: type, dataIdx: int64): int64 = + # Given a data index, which chunk does it belong to? + dataIdx div dataPerChunk(T) + +template maxChunkIdx*(T: type, maxLen: Limit): int64 = + # Given a number of data items, how many chunks are needed? 
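+  # Hand-worked examples (with bytesPerChunk = 32):
+  #   dataPerChunk(uint64) = 32 div 8 = 4 items per chunk
+  #   chunkIdx(uint64, 11) = 11 div 4 = 2, i.e. the 12th item sits in chunk 2
+  #   maxChunkIdx(uint64, 100) = nextPow2Int64((100 + 4 - 1) div 4)
+  #                            = nextPow2Int64(25) = 32 chunks in the tree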
+  # TODO compiler bug:
+  # beacon_chain/ssz/types.nim(75, 53) Error: cannot generate code for: maxLen
+  # nextPow2(chunkIdx(T, maxLen + dataPerChunk(T) - 1).uint64).int64
+  nextPow2Int64(chunkIdx(T, maxLen.int64 + dataPerChunk(T) - 1))
+
+template layer*(vIdx: int64): int =
+  ## Layer 0 = layer at which the root hash is
+  ## We place the root hash at index 1 which simplifies the math and leaves
+  ## index 0 for the mixed-in-length
+  log2trunc(vIdx.uint64).int
+
+func hashListIndicesLen(maxChunkIdx: int64): int =
+  # TODO: This exists only to work-around a compilation issue when the complex
+  # expression is used directly in the HashList array size definition below
+  int(layer(maxChunkIdx)) + 1
+
+type
+  List*[T; maxLen: static Limit] = distinct seq[T]
+  BitList*[maxLen: static Limit] = distinct BitSeq
+
+  SingleMemberUnion*[T] = object
+    selector*: uint8
+    value*: T
+
+  HashArray*[maxLen: static Limit; T] = object
+    ## Array implementation that caches the hash of each chunk of data - see
+    ## also HashList for more details.
+    data*: array[maxLen, T]
+    hashes* {.dontSerialize.}: array[maxChunkIdx(T, maxLen), Digest]
+
+  HashList*[T; maxLen: static Limit] = object
+    ## List implementation that caches the hash of each chunk of data as well
+    ## as the combined hash of each level of the merkle tree using a flattened
+    ## list of hashes.
+    ##
+    ## The merkle tree of a list is formed by imagining a virtual buffer of
+    ## `maxLen` length which is zero-filled where there is no data. Then,
+    ## a merkle tree of hashes is formed as usual - at each level of the tree,
+    ## iff the hash is combined from two zero-filled chunks, the hash is not
+    ## stored in the `hashes` list - instead, `indices` keeps track of where in
+    ## the list each level starts. When the length of `data` changes, the
+    ## `hashes` and `indices` structures must be updated accordingly using
+    ## `growHashes`.
+    ##
+    ## All mutating operators (those that take `var HashList`) will
+    ## automatically invalidate the cache for the relevant chunks - the leaf
+    ## and all intermediate chunk hashes up to the root. When large changes
+    ## are made to `data`, it might be more efficient to batch the updates,
+    ## then reset the cache using `resetCache` instead.
+
+    data*: List[T, maxLen]
+    hashes* {.dontSerialize.}: seq[Digest] ## \
+      ## Flattened tree store that skips "empty" branches of the tree - the
+      ## starting index in this sequence of each "level" in the tree is found
+      ## in `indices`.
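+      ##
+      ## For example, a HashList[uint64, 32] has 8 leaf chunks (4 items per
+      ## 32-byte chunk) and thus a tree depth of 3; `indices` then holds 4
+      ## entries - the start of each layer within this flat sequence.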
+    indices* {.dontSerialize.}: array[hashListIndicesLen(maxChunkIdx(T, maxLen)), int64] ##\
+      ## Holds the starting index in the hashes list for each level of the tree
+
+  # Note for readers:
+  # We use `array` for `Vector` and
+  #        `BitArray` for `BitVector`
+
+  SszError* = object of SerializationError
+
+  MalformedSszError* = object of SszError
+
+  SszSizeMismatchError* = object of SszError
+    deserializedType*: cstring
+    actualSszSize*: int
+    elementSize*: int
+
+  # These are supported by the SSZ library - anything that's not covered here
+  # needs to overload toSszType and fromSszBytes
+  SszType* =
+    BasicType | array | HashArray | List | HashList | BitArray | BitList |
+    object | tuple
+
+template asSeq*(x: List): auto = distinctBase(x)
+
+template init*[T, N](L: type List[T, N], x: seq[T]): auto =
+  List[T, N](x)
+
+template `$`*(x: List): auto = $(distinctBase x)
+template len*(x: List): auto = len(distinctBase x)
+template low*(x: List): auto = low(distinctBase x)
+template high*(x: List): auto = high(distinctBase x)
+template `[]`*(x: List, idx: auto): untyped = distinctBase(x)[idx]
+template `[]=`*(x: var List, idx: auto, val: auto) = distinctBase(x)[idx] = val
+template `==`*(a, b: List): bool = distinctBase(a) == distinctBase(b)
+
+template `&`*(a, b: List): auto = (type(a)(distinctBase(a) & distinctBase(b)))
+
+template items* (x: List): untyped = items(distinctBase x)
+template pairs* (x: List): untyped = pairs(distinctBase x)
+template mitems*(x: var List): untyped = mitems(distinctBase x)
+template mpairs*(x: var List): untyped = mpairs(distinctBase x)
+
+proc add*(x: var List, val: auto): bool =
+  if x.len < x.maxLen:
+    add(distinctBase x, val)
+    true
+  else:
+    false
+
+proc setLen*(x: var List, newLen: int): bool =
+  if newLen <= x.maxLen:
+    setLen(distinctBase x, newLen)
+    true
+  else:
+    false
+
+template init*(L: type BitList, x: seq[byte], N: static Limit): auto =
+  BitList[N](data: x)
+
+template init*[N](L: type BitList[N], x: seq[byte]): auto =
+  L(data: x)
+
+template init*(T: type BitList, len: int): auto = T init(BitSeq, len)
+template len*(x: BitList): auto = len(BitSeq(x))
+template bytes*(x: BitList): auto = seq[byte](x)
+template `[]`*(x: BitList, idx: auto): auto = BitSeq(x)[idx]
+template `[]=`*(x: var BitList, idx: auto, val: bool) = BitSeq(x)[idx] = val
+template `==`*(a, b: BitList): bool = BitSeq(a) == BitSeq(b)
+template setBit*(x: var BitList, idx: Natural) = setBit(BitSeq(x), idx)
+template clearBit*(x: var BitList, idx: Natural) = clearBit(BitSeq(x), idx)
+template overlaps*(a, b: BitList): bool = overlaps(BitSeq(a), BitSeq(b))
+template incl*(a: var BitList, b: BitList) = incl(BitSeq(a), BitSeq(b))
+template isSubsetOf*(a, b: BitList): bool = isSubsetOf(BitSeq(a), BitSeq(b))
+template isZeros*(x: BitList): bool = isZeros(BitSeq(x))
+template countOnes*(x: BitList): int = countOnes(BitSeq(x))
+template countZeros*(x: BitList): int = countZeros(BitSeq(x))
+template countOverlap*(x, y: BitList): int = countOverlap(BitSeq(x), BitSeq(y))
+template `$`*(a: BitList): string = $(BitSeq(a))
+
+iterator items*(x: BitList): bool =
+  for i in 0 ..< x.len:
+    yield x[i]
+
+template isCached*(v: Digest): bool =
+  ## An entry is "in the cache" when its first 8 bytes are non-zero -
+  ## conveniently, Nim zero-initializes values, so fresh entries start out
+  ## uncached. A genuine hash that happens to begin with 8 zero bytes is
+  ## merely recomputed - a harmless false negative.
+
+  # Checking and resetting the cache status are hotspots - profile before
+  # touching!
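+  # The first 8 bytes double as the "cached" flag: reading them as a single
+  # uint64 and comparing with zero is cheaper than scanning all 32 bytes and
+  # works on any byte order, since we only ever compare against 0.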
+  cast[ptr uint64](unsafeAddr v.data[0])[] != 0 # endian safe
+
+template clearCache*(v: var Digest) =
+  cast[ptr uint64](addr v.data[0])[] = 0 # endian safe
+
+template maxChunks*(a: HashList|HashArray): int64 =
+  ## Number of chunks in the layer where the data lives
+  maxChunkIdx(a.T, a.maxLen)
+
+template maxDepth*(a: HashList|HashArray): int =
+  ## Layer where data is
+  static: doAssert a.maxChunks <= high(int64) div 2
+  layer(nextPow2(a.maxChunks.uint64).int64)
+
+template chunkIdx(a: HashList|HashArray, dataIdx: int64): int64 =
+  chunkIdx(a.T, dataIdx)
+
+proc clearCaches*(a: var HashArray, dataIdx: auto) =
+  ## Clear all cache entries after data at dataIdx has been modified
+  var idx = 1 shl (a.maxDepth - 1) + (chunkIdx(a, dataIdx) shr 1)
+  while idx != 0:
+    clearCache(a.hashes[idx])
+    idx = idx shr 1
+
+func nodesAtLayer*(layer, depth, leaves: int): int =
+  ## Given a number of leaves, how many nodes do you need at a given layer
+  ## in a binary tree structure?
+  let leavesPerNode = 1'i64 shl (depth - layer)
+  int((leaves + leavesPerNode - 1) div leavesPerNode)
+
+func cacheNodes*(depth, leaves: int): int =
+  ## Total number of nodes needed to cache a tree of a given depth with
+  ## `leaves` items in it - chunks that are zero-filled have well-known hash
+  ## trees and don't need to be stored in the tree.
+  var res = 0
+  for i in 0..<depth:
+    res += nodesAtLayer(i, depth, leaves)
+  res
+
+proc clearCaches*(a: var HashList, dataIdx: int64) =
+  ## Clear each level of the merkle tree on the path from the chunk holding
+  ## `dataIdx` up to the root.
+  if a.hashes.len == 0:
+    return
+
+  var
+    idx = 1'i64 shl (a.maxDepth - 1) + (chunkIdx(a, dataIdx) shr 1)
+    layer = a.maxDepth - 1
+  while idx > 0:
+    let
+      idxInLayer = idx - (1'i64 shl layer)
+      layerIdx = idxInLayer + a.indices[layer]
+    if layerIdx < a.indices[layer + 1]:
+      # Only clear cache when we're actually storing it - ie it hasn't been
+      # skipped by the "combined zero hash" optimization
+      clearCache(a.hashes[layerIdx])
+
+    idx = idx shr 1
+    layer = layer - 1
+
+  clearCache(a.hashes[0])
+
+proc clearCache*(a: var HashList) =
+  # Clear the full merkle tree, in anticipation of a complete rewrite of the
+  # contents
+  for c in a.hashes.mitems(): clearCache(c)
+
+proc growHashes*(a: var HashList) =
+  ## Ensure that the hash cache is big enough for the data in the list - must
+  ## be called whenever `data` grows.
+  let
+    leaves = int(
+      chunkIdx(a, a.data.len() + dataPerChunk(a.T) - 1))
+    newSize = 1 + cacheNodes(a.maxDepth, leaves)
+
+  if a.hashes.len >= newSize:
+    return
+
+  var
+    newHashes = newSeq[Digest](newSize)
+    newIndices = default(type a.indices)
+
+  if a.hashes.len != newSize:
+    newIndices[0] = nodesAtLayer(0, a.maxDepth, leaves)
+    for i in 1..a.maxDepth:
+      newIndices[i] = newIndices[i - 1] + nodesAtLayer(i - 1, a.maxDepth, leaves)
+
+  for i in 1..<a.maxDepth:
+    for j in 0..<(a.indices[i + 1] - a.indices[i]):
+      newHashes[newIndices[i] + j] = a.hashes[a.indices[i] + j]
+
+  swap(a.hashes, newHashes)
+  a.indices = newIndices
+
+proc resetCache*(a: var HashList) =
+  ## Perform a full reset of the hash cache, for example after the data has
+  ## been rewritten "manually" without going through the exported operators
+  a.hashes.setLen(0)
+  a.indices = default(type a.indices)
+  a.growHashes()
+
+proc add*(x: var HashList, val: auto): bool =
+  if add(x.data, val):
+    x.growHashes()
+    clearCaches(x, x.data.len() - 1)
+    true
+  else:
+    false
+
+proc addDefault*(x: var HashList): ptr x.T =
+  if x.data.len >= x.maxLen:
+    return nil
+
+  distinctBase(x.data).setLen(x.data.len + 1)
+  x.growHashes()
+  clearCaches(x, x.data.len() - 1)
+  addr x.data[^1]
+
+template init*[T, N](L: type HashList[T, N], x: seq[T]): auto =
+  var tmp = HashList[T, N](data: List[T, N].init(x))
+  tmp.growHashes()
+  tmp
+
+template len*(x: HashList|HashArray): auto = len(x.data)
+template low*(x: HashList|HashArray): auto = low(x.data)
+template high*(x: HashList|HashArray): auto = high(x.data)
+template `[]`*(x: HashList|HashArray, idx: auto): auto = x.data[idx]
+
+proc `[]`*(a: var HashArray, b: auto): var a.T =
+  # Access item and clear cache - use asSeq when only reading!
+  clearCaches(a, b.Limit)
+  a.data[b]
+
+proc `[]=`*(a: var HashArray, b: auto, c: auto) =
+  clearCaches(a, b.Limit)
+  a.data[b] = c
+
+proc `[]`*(x: var HashList, idx: auto): var x.T =
+  # Access item and clear cache - use asSeq when only reading!
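+  # Returning `var x.T` lets the caller mutate the element in place, so the
+  # cached hashes on the path from this chunk to the root are invalidated
+  # up front.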
+  clearCaches(x, idx.int64)
+  x.data[idx]
+
+proc `[]=`*(x: var HashList, idx: auto, val: auto) =
+  clearCaches(x, idx.int64)
+  x.data[idx] = val
+
+template `==`*(a, b: HashList|HashArray): bool = a.data == b.data
+template asSeq*(x: HashList): auto = asSeq(x.data)
+template `$`*(x: HashList): auto = $(x.data)
+
+template items* (x: HashList|HashArray): untyped = items(x.data)
+template pairs* (x: HashList|HashArray): untyped = pairs(x.data)
+
+template swap*(a, b: var HashList) =
+  swap(a.data, b.data)
+  swap(a.hashes, b.hashes)
+  swap(a.indices, b.indices)
+
+template clear*(a: var HashList) =
+  if not a.data.setLen(0):
+    raiseAssert "length 0 should always succeed"
+  a.hashes.setLen(0)
+  a.indices = default(type a.indices)
+
+template fill*(a: var HashArray, c: auto) =
+  mixin fill
+  fill(a.data, c)
+template sum*[maxLen; T](a: var HashArray[maxLen, T]): T =
+  mixin sum
+  sum(a.data)
+
+macro unsupported*(T: typed): untyped =
+  # TODO: {.fatal.} breaks compilation even in `compiles()` context,
+  # so we use this macro instead. It's also much better at figuring
+  # out the actual type that was used in the instantiation.
+  # File both problems as issues.
+  when T is enum:
+    error "Nim `enum` types map poorly to SSZ and make it easy to introduce security issues because of spurious Defects"
+  else:
+    error "SSZ serialization of the type " & humaneTypeName(T) & " is not supported, overload toSszType and fromSszBytes"
+
+template ElemType*(T0: type HashArray): untyped =
+  T0.T
+
+template ElemType*(T0: type HashList): untyped =
+  T0.T
+
+template ElemType*(T: type array): untyped =
+  type(default(T)[low(T)])
+
+template ElemType*(T: type seq): untyped =
+  type(default(T)[0])
+
+template ElemType*(T0: type List): untyped =
+  T0.T
+
+func isFixedSize*(T0: type): bool {.compileTime.} =
+  mixin toSszType, enumAllSerializedFields
+
+  type T = type toSszType(declval T0)
+
+  when T is BasicType:
+    return true
+  elif T is array|HashArray:
+    return isFixedSize(ElemType(T))
+  elif T is object|tuple:
+    enumAllSerializedFields(T):
+      when not isFixedSize(FieldType):
+        return false
+    return true
+
+func fixedPortionSize*(T0: type): int {.compileTime.} =
+  mixin enumAllSerializedFields, toSszType
+
+  type T = type toSszType(declval T0)
+
+  when T is BasicType: sizeof(T)
+  elif T is array|HashArray:
+    type E = ElemType(T)
+    when isFixedSize(E): int(len(T)) * fixedPortionSize(E)
+    else: int(len(T)) * offsetSize
+  elif T is object|tuple:
+    enumAllSerializedFields(T):
+      when isFixedSize(FieldType):
+        result += fixedPortionSize(FieldType)
+      else:
+        result += offsetSize
+  else:
+    unsupported T0
+
+# TODO This should have been an iterator, but the VM can't compile the
+# code due to "too many registers required".
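+# For example, for a record with fields `a: uint32` and `b: List[byte, 16]`,
+# this returns @[("a", 0, 4, ""), ("b", 4, 0, "")] - the variable-size `b`
+# occupies a 4-byte offset slot at position 4 and reports a fixedSize of 0.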
+proc fieldInfos*(RecordType: type): seq[tuple[name: string,
+                                              offset: int,
+                                              fixedSize: int,
+                                              branchKey: string]] =
+  mixin enumAllSerializedFields
+
+  var
+    offsetInBranch = {"": 0}.toTable
+    nestedUnder = initTable[string, string]()
+
+  enumAllSerializedFields(RecordType):
+    const
+      isFixed = isFixedSize(FieldType)
+      fixedSize = when isFixed: fixedPortionSize(FieldType)
+                  else: 0
+      branchKey = when fieldCaseDiscriminator.len == 0: ""
+                  else: fieldCaseDiscriminator & ":" & $fieldCaseBranches
+      fieldSize = when isFixed: fixedSize
+                  else: offsetSize
+
+    nestedUnder[fieldName] = branchKey
+
+    var fieldOffset: int
+    offsetInBranch.withValue(branchKey, val):
+      fieldOffset = val[]
+      val[] += fieldSize
+    do:
+      try:
+        let parentBranch = nestedUnder.getOrDefault(fieldCaseDiscriminator, "")
+        fieldOffset = offsetInBranch[parentBranch]
+        offsetInBranch[branchKey] = fieldOffset + fieldSize
+      except KeyError as e:
+        raiseAssert e.msg
+
+    result.add((fieldName, fieldOffset, fixedSize, branchKey))
+
+func getFieldBoundingOffsetsImpl(RecordType: type,
+                                 fieldName: static string):
+     tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool] {.compileTime.} =
+  result = (-1, -1, false)
+  var fieldBranchKey: string
+  var isFirstOffset = true
+
+  for f in fieldInfos(RecordType):
+    if fieldName == f.name:
+      result[0] = f.offset
+      if f.fixedSize > 0:
+        result[1] = result[0] + f.fixedSize
+        return
+      else:
+        fieldBranchKey = f.branchKey
+      result.isFirstOffset = isFirstOffset
+
+    elif result[0] != -1 and
+         f.fixedSize == 0 and
+         f.branchKey == fieldBranchKey:
+      # We have found the next variable sized field
+      result[1] = f.offset
+      return
+
+    if f.fixedSize == 0:
+      isFirstOffset = false
+
+func getFieldBoundingOffsets*(RecordType: type,
+                              fieldName: static string):
+     tuple[fieldOffset, nextFieldOffset: int, isFirstOffset: bool] {.compileTime.} =
+  ## Returns the start and end offsets of a field.
+  ##
+  ## For fixed-size fields, the start offset points to the first
+  ## byte of the field and the end offset points to 1 byte past the
+  ## end of the field.
+  ##
+  ## For variable-size fields, the returned offsets point to the
+  ## statically known positions of the 32-bit offset values written
+  ## within the SSZ object. You must read the 32-bit values stored
+  ## at these locations in order to obtain the actual offsets.
+  ##
+  ## For variable-size fields, the end offset may be -1 when the
+  ## designated field is the last variable sized field within the
+  ## object. Then the SSZ object boundary known at run-time marks
+  ## the end of the variable-size field.
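+  ##
+  ## For example, for a record with fields `a: uint32`, `b: List[byte, 16]`
+  ## and `c: uint8`, "a" spans bytes [0, 4) and "c" spans [8, 9), while "b"
+  ## yields offsets (4, -1): the 32-bit offset for "b" is stored at byte 4
+  ## and, "b" being the last variable-size field, its end is the object
+  ## boundary.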
+ type T = RecordType + anonConst getFieldBoundingOffsetsImpl(T, fieldName) + +template enumerateSubFields*(holder, fieldVar, body: untyped) = + when holder is array|HashArray: + for fieldVar in holder: body + else: + enumInstanceSerializedFields(holder, _{.used.}, fieldVar): body + +method formatMsg*( + err: ref SszSizeMismatchError, + filename: string): string {.gcsafe, raises: [Defect].} = + try: + &"SSZ size mismatch, element {err.elementSize}, actual {err.actualSszSize}, type {err.deserializedType}, file {filename}" + except CatchableError: + "SSZ size mismatch" + +template readValue*(reader: var JsonReader, value: var List) = + value = type(value)(readValue(reader, seq[type value[0]])) + +template writeValue*(writer: var JsonWriter, value: List) = + writeValue(writer, asSeq value) + +proc writeValue*(writer: var JsonWriter, value: HashList) + {.raises: [IOError, SerializationError, Defect].} = + writeValue(writer, value.data) + +proc readValue*(reader: var JsonReader, value: var HashList) + {.raises: [IOError, SerializationError, Defect].} = + value.resetCache() + readValue(reader, value.data) + +template readValue*(reader: var JsonReader, value: var BitList) = + type T = type(value) + value = T readValue(reader, BitSeq) + +template writeValue*(writer: var JsonWriter, value: BitList) = + writeValue(writer, BitSeq value) diff --git a/tests/test_all.nim b/tests/test_all.nim index e69de29..075d2b8 100644 --- a/tests/test_all.nim +++ b/tests/test_all.nim @@ -0,0 +1,3 @@ +import + ./test_ssz_roundtrip, + ./test_ssz_serialization diff --git a/tests/test_ssz_roundtrip.nim b/tests/test_ssz_roundtrip.nim new file mode 100644 index 0000000..39c3f0c --- /dev/null +++ b/tests/test_ssz_roundtrip.nim @@ -0,0 +1,14 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + serialization/testing/generic_suite, + ../ssz_serialization + +executeRoundTripTests SSZ diff --git a/tests/test_ssz_serialization.nim b/tests/test_ssz_serialization.nim new file mode 100644 index 0000000..6daac21 --- /dev/null +++ b/tests/test_ssz_serialization.nim @@ -0,0 +1,359 @@ +# ssz_serialization +# Copyright (c) 2018-2021 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +{.used.} + +import + std/typetraits, + unittest2, stew/byteutils, + ../ssz_serialization, + ../ssz_serialization/[merkleization, navigator, dynamic_navigator] + +# TODO: Move to types? 
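+# Rendering digests as hex strings keeps the test expectations below
+# readable, e.g. `check $root == "5248..."`.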
+func `$`*(x: Digest): string = + x.data.toHex() + +type + SomeEnum = enum + A, B, C + + Simple = object + flag: bool + # ignored {.dontSerialize.}: string + data: array[256, bool] + data2: HashArray[256, bool] + + NonFixed = object + data: HashList[uint64, 1024] + +template reject(stmt) = + doAssert(not compiles(stmt)) + +static: + doAssert isFixedSize(bool) == true + + doAssert fixedPortionSize(array[10, bool]) == 10 + doAssert fixedPortionSize(array[SomeEnum, uint64]) == 24 + doAssert fixedPortionSize(array[3..5, List[byte, 256]]) == 12 + + doAssert isFixedSize(array[20, bool]) == true + doAssert isFixedSize(Simple) == true + doAssert isFixedSize(List[bool, 128]) == false + + doAssert isFixedSize(NonFixed) == false + + reject fixedPortionSize(int) + +type + ObjWithFields = object + f0: uint8 + f1: uint32 + f2: array[20, byte] + f3: Digest + +static: + doAssert fixedPortionSize(ObjWithFields) == + 1 + 4 + sizeof(array[20, byte]) + (256 div 8) + +type + Foo = object + bar: Bar + + BarList = List[uint64, 128] + + Bar = object + b: BarList + baz: Baz + + Baz = object + i: uint64 + +func toDigest[N: static int](x: array[N, byte]): Digest = + result.data[0 .. N-1] = x + +suite "SSZ navigator": + test "simple object fields": + var foo = Foo(bar: Bar(b: BarList @[1'u64, 2, 3], baz: Baz(i: 10'u64))) + let encoded = SSZ.encode(foo) + + check SSZ.decode(encoded, Foo) == foo + + let mountedFoo = sszMount(encoded, Foo) + check mountedFoo.bar.b[] == BarList @[1'u64, 2, 3] + + let mountedBar = mountedFoo.bar + check mountedBar.baz.i == 10'u64 + + test "lists with max size": + let a = [byte 0x01, 0x02, 0x03].toDigest + let b = [byte 0x04, 0x05, 0x06].toDigest + let c = [byte 0x07, 0x08, 0x09].toDigest + + var xx: List[uint64, 16] + check: + not xx.setLen(17) + xx.setLen(16) + + var leaves = HashList[Digest, 1'i64 shl 3]() + check: + leaves.add a + leaves.add b + leaves.add c + let root = hash_tree_root(leaves) + check $root == "5248085b588fab1dd1e03f3cd62201602b12e6560665935964f46e805977e8c5" + + while leaves.len < 1 shl 3: + check: + leaves.add c + hash_tree_root(leaves) == hash_tree_root(leaves.data) + + leaves = default(type leaves) + + while leaves.len < (1 shl 3) - 1: + check: + leaves.add c + leaves.add c + hash_tree_root(leaves) == hash_tree_root(leaves.data) + + leaves = default(type leaves) + + while leaves.len < (1 shl 3) - 2: + check: + leaves.add c + leaves.add c + leaves.add c + hash_tree_root(leaves) == hash_tree_root(leaves.data) + + for i in 0 ..< leaves.data.len - 2: + leaves[i] = a + leaves[i + 1] = b + leaves[i + 2] = c + check hash_tree_root(leaves) == hash_tree_root(leaves.data) + + var leaves2 = HashList[Digest, 1'i64 shl 48]() # Large number! 
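+    # maxLen fixes only the *virtual* tree shape (depth 48 here) - the hash
+    # cache grows with the actual number of leaves, so a huge limit stays
+    # cheap until elements are added.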
+    check:
+      leaves2.add a
+      leaves2.add b
+      leaves2.add c
+      hash_tree_root(leaves2) == hash_tree_root(leaves2.data)
+
+    var leaves3 = HashList[Digest, 7]() # Non-power-of-2
+    check:
+      hash_tree_root(leaves3) == hash_tree_root(leaves3.data)
+      leaves3.add a
+      leaves3.add b
+      leaves3.add c
+      hash_tree_root(leaves3) == hash_tree_root(leaves3.data)
+
+  test "basictype":
+    var leaves = HashList[uint64, 1'i64 shl 3]()
+    while leaves.len < leaves.maxLen:
+      check:
+        leaves.add leaves.len.uint64
+        hash_tree_root(leaves) == hash_tree_root(leaves.data)
+
+suite "SSZ dynamic navigator":
+  test "navigating fields":
+    var fooOrig = Foo(bar: Bar(b: BarList @[1'u64, 2, 3], baz: Baz(i: 10'u64)))
+    let fooEncoded = SSZ.encode(fooOrig)
+
+    var navFoo = DynamicSszNavigator.init(fooEncoded, Foo)
+
+    var navBar = navFoo.navigate("bar")
+    check navBar.toJson(pretty = false) == """{"b":[1,2,3],"baz":{"i":10}}"""
+
+    var navB = navBar.navigate("b")
+    check navB.toJson(pretty = false) == "[1,2,3]"
+
+    var navBaz = navBar.navigate("baz")
+    var navI = navBaz.navigate("i")
+    check navI.toJson == "10"
+
+    expect KeyError:
+      discard navBar.navigate("biz")
+
+type
+  Obj = object
+    arr: array[8, Digest]
+
+    li: List[Digest, 8]
+
+  HashObj = object
+    arr: HashArray[8, Digest]
+
+    li: HashList[Digest, 8]
+
+suite "hash":
+  test "HashArray":
+    var
+      o = Obj()
+      ho = HashObj()
+
+    template both(body) =
+      block:
+        template it: auto {.inject.} = o
+        body
+      block:
+        template it: auto {.inject.} = ho
+        body
+
+    let htro = hash_tree_root(o)
+    let htrho = hash_tree_root(ho)
+
+    check:
+      o.arr == ho.arr.data
+      o.li == ho.li.data
+      htro == htrho
+
+    both: it.arr[0].data[0] = byte 1
+
+    both: check: it.li.add Digest()
+
+    var y: HashArray[32, uint64]
+    check: hash_tree_root(y) == hash_tree_root(y.data)
+    for i in 0..<y.len:
+      y[i] = 42'u64
+      check: hash_tree_root(y) == hash_tree_root(y.data)

From: kdeme
Date: Thu, 21 Oct 2021 21:23:26 +0200 Subject: [PATCH 2/5] Update github actions --- .github/workflows/ci.yml | 44 +++++++--------------------------------- 1 file changed, 7 insertions(+), 37 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ddb848..2fb9e82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,38 +7,18 @@ jobs: fail-fast: false max-parallel: 20 matrix: - branch: [master] + test_lang: [c, cpp] target: - os: linux cpu: amd64 - TEST_LANG: c - - os: linux - cpu: amd64 - TEST_LANG: cpp - - os: linux - cpu: i386 - TEST_LANG: c - os: linux cpu: i386 - TEST_LANG: cpp - - os: macos - cpu: amd64 - TEST_LANG: c - os: macos cpu: amd64 - TEST_LANG: cpp - os: windows cpu: amd64 - TEST_LANG: c - - os: windows - cpu: amd64 - TEST_LANG: cpp - os: windows cpu: i386 - TEST_LANG: c - - os: windows - cpu: i386 - TEST_LANG: cpp include: - target: os: linux @@ -50,7 +30,7 @@ jobs: os: windows builder: windows-2019 - name: '${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ matrix.target.TEST_LANG }} (${{ matrix.branch }})' + name: '${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ matrix.test_lang }}' runs-on: ${{ matrix.builder }} steps: - name: Checkout nim-ssz-serialization @@ -103,10 +83,10 @@ jobs: run: | mkdir -p external if [[ '${{ matrix.target.cpu }}' == 'amd64' ]]; then - MINGW_URL="https://sourceforge.net/projects/mingw-w64/files/Toolchains targetting Win64/Personal Builds/mingw-builds/8.1.0/threads-posix/seh/x86_64-8.1.0-release-posix-seh-rt_v6-rev0.7z" + MINGW_URL="https://github.com/brechtsanders/winlibs_mingw/releases/download/11.2.0-12.0.1-9.0.0-r1/winlibs-x86_64-posix-seh-gcc-11.2.0-mingw-w64-9.0.0-r1.7z" ARCH=64 else - MINGW_URL="https://sourceforge.net/projects/mingw-w64/files/Toolchains targetting Win32/Personal Builds/mingw-builds/8.1.0/threads-posix/dwarf/i686-8.1.0-release-posix-dwarf-rt_v6-rev0.7z" + MINGW_URL="https://github.com/brechtsanders/winlibs_mingw/releases/download/11.2.0-12.0.1-9.0.0-r1/winlibs-i686-posix-dwarf-gcc-11.2.0-mingw-w64-9.0.0-r1.7z" ARCH=32 fi curl -L "$MINGW_URL" -o "external/mingw-${{ matrix.target.cpu }}.7z" @@ -145,11 +125,10 @@ jobs: id: nim-cache uses: actions/cache@v2 with: - path: nim - key: 'nim-${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ steps.versions.outputs.nimbus_build_system }}' + path: NimBinaries + key: 'NimBinaries-${{ matrix.target.os }}-${{ matrix.target.cpu }}-${{ steps.versions.outputs.nimbus_build_system }}' - name: Build Nim and associated tools - if: steps.nim-cache.outputs.cache-hit != 'true' shell: bash run: | curl -O -L -s -S https://raw.githubusercontent.com/status-im/nimbus-build-system/master/scripts/build_nim.sh @@ -165,15 +144,6 @@ jobs: fi env MAKE="$MAKE_CMD -j2" ARCH_OVERRIDE=$PLATFORM CC=gcc bash build_nim.sh nim csources dist/nimble NimBinaries - # clean up to save cache space - cd nim - rm koch - rm -rf nimcache - rm -rf csources - rm -rf tests - rm -rf dist - rm -rf .git - - name: Setup environment shell: bash run: echo '${{ github.workspace }}/nim/bin' >> $GITHUB_PATH @@ -183,4 +153,4 @@ jobs: working-directory: nim-ssz-serialization run: | nimble install -y --depsOnly - env TEST_LANG="${{ matrix.target.TEST_LANG }}" nimble test + env TEST_LANG="${{ matrix.test_lang }}" nimble test From 627ddfacaac20ed67f9c47a67ddbe99693f48672 Mon Sep 17 00:00:00 2001 From: kdeme Date: Fri, 22 Oct 2021 11:51:06 +0200 Subject: [PATCH 3/5] Add few helper templates for SSZ List --- ssz_serialization/types.nim | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/ssz_serialization/types.nim b/ssz_serialization/types.nim index 0e2dc47..68d4c70 100644 --- a/ssz_serialization/types.nim +++ b/ssz_serialization/types.nim @@ -145,6 +145,9 @@ type template asSeq*(x: List): auto = distinctBase(x) +template init*[T](L: type List, x: seq[T], N: static Limit): auto = + List[T, N](x) + template init*[T, N](L: type List[T, N], x: seq[T]): auto = List[T, N](x) @@ -162,6 +165,7 @@ template items* (x: List): untyped = items(distinctBase x) template pairs* (x: List): untyped = pairs(distinctBase x) template mitems*(x: var List): untyped = mitems(distinctBase x) template mpairs*(x: var List): untyped = mpairs(distinctBase x) +template contains* (x: List, val: auto): untyped = contains(distinctBase x, val) proc add*(x: var List, val: auto): bool = if x.len < x.maxLen: From 9e5455225c44e1f4a242c24ba1dcd56d6a274bed Mon Sep 17 00:00:00 2001 From: kdeme Date: Fri, 22 Oct 2021 16:56:18 +0200 Subject: [PATCH 4/5] Remove the List.init() helper again Seems to conflict with SszWriter.init ... sometimes. Need a magician for this. --- ssz_serialization/types.nim | 3 --- 1 file changed, 3 deletions(-) diff --git a/ssz_serialization/types.nim b/ssz_serialization/types.nim index 68d4c70..2fbe6a2 100644 --- a/ssz_serialization/types.nim +++ b/ssz_serialization/types.nim @@ -145,9 +145,6 @@ type template asSeq*(x: List): auto = distinctBase(x) -template init*[T](L: type List, x: seq[T], N: static Limit): auto = - List[T, N](x) - template init*[T, N](L: type List[T, N], x: seq[T]): auto = List[T, N](x) From e409ba448c1a81a1e8bba59cc6134c0fa09cf61f Mon Sep 17 00:00:00 2001 From: kdeme Date: Fri, 22 Oct 2021 17:02:37 +0200 Subject: [PATCH 5/5] Add test_bitseqs.nim to the tests --- tests/test_all.nim | 1 + tests/test_bitseqs.nim | 162 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 tests/test_bitseqs.nim diff --git a/tests/test_all.nim b/tests/test_all.nim index 075d2b8..0ec9ac4 100644 --- a/tests/test_all.nim +++ b/tests/test_all.nim @@ -1,3 +1,4 @@ import + ./test_bitseqs, ./test_ssz_roundtrip, ./test_ssz_serialization diff --git a/tests/test_bitseqs.nim b/tests/test_bitseqs.nim new file mode 100644 index 0000000..48547b9 --- /dev/null +++ b/tests/test_bitseqs.nim @@ -0,0 +1,162 @@ +{.used.} + +import + unittest2, + std/[sequtils, strformat], + ../ssz_serialization/bitseqs + +suite "Bit fields": + test "roundtrips BitArray": + var + a = BitArray[100]() + b = BitArray[100]() + c = BitArray[100]() + + check: + not a[0] + + a.setBit 1 + + check: + not a[0] + a[1] + toSeq(a.oneIndices()) == [1] + + a + b == a + a - b == a + a - a == c # empty + + b + a == a + b - b == c # b is empty + + b.setBit 2 + + check: + (a + b)[2] + (b - a)[2] + not (b - a)[1] + + a.setBit 99 + + check: + (a + b)[99] + (b - a)[2] + not (b - a)[1] + not (b - a)[99] + toSeq(a.oneIndices()) == [1, 99] + + a.incl(b) + + check: + not a[0] + a[1] + a[2] + + a.clear() + check: + not a[1] + + test "roundtrips BitSeq": + var + a = BitSeq.init(100) + b = BitSeq.init(100) + + check: + not a[0] + a.isZeros() + + a.setBit 1 + + check: + not a[0] + a[1] + a.countOnes() == 1 + a.countZeros() == 99 + not a.isZeros() + a.countOverlap(a) == 1 + + b.setBit 2 + + a.incl(b) + + check: + not a[0] + a[1] + a[2] + a.countOverlap(a) == 2 + a.countOverlap(b) == 1 + b.countOverlap(a) == 1 + b.countOverlap(b) == 1 + a.clear() + check: + not a[1] + + test "iterating words": + for bitCount in [8, 3, 7, 8, 14, 15, 16, 19, 260]: + checkpoint &"trying bit count {bitCount}" + var + a = 
BitSeq.init(bitCount)
+        b = BitSeq.init(bitCount)
+        bitsInWord = sizeof(uint) * 8
+        expectedWordCount = (bitCount div bitsInWord) + 1
+
+      for i in 0 ..< expectedWordCount:
+        let every3rdBit = i * sizeof(uint) * 8 + 2
+        a[every3rdBit] = true
+        b[every3rdBit] = true
+
+      for word in words(a):
+        check word == 4
+        word = 2
+
+      for wa, wb in words(a, b):
+        check wa == 2 and wb == 4
+        wa = 1
+        wb = 2
+
+      for i in 0 ..< expectedWordCount:
+        for j in 0 ..< bitsInWord:
+          let bitPos = i * bitsInWord + j
+          if bitPos < bitCount:
+            check a[bitPos] == (j == 0)
+            check b[bitPos] == (j == 1)
+
+  test "overlaps":
+    for bitCount in [1, 62, 63, 64, 91, 127, 128, 129]:
+      checkpoint &"trying bit count {bitCount}"
+      var
+        a = BitSeq.init(bitCount)
+        b = BitSeq.init(bitCount)
+
+      for pos in [4, 8, 9, 12, 29, 32, 63, 64, 67]:
+        if pos + 2 < bitCount:
+          a.setBit(pos)
+          b.setBit(pos + 2)
+
+      check:
+        not a.overlaps(b)
+        not b.overlaps(a)
+        a.countOverlap(b) == 0
+
+  test "isZeros":
+    template carryOutTests(N: static int) =
+      var a = BitArray[N]()
+      check a.isZeros()
+
+      for i in 0 ..< N:
+        var b = a
+        b.setBit(i)
+        check(not b.isZeros())
+
+    carryOutTests(1)
+    carryOutTests(10)
+    carryOutTests(31)
+    carryOutTests(32)
+    carryOutTests(63)
+    carryOutTests(64)
+    carryOutTests(65)
+    carryOutTests(95)
+    carryOutTests(96)
+    carryOutTests(97)
+    carryOutTests(12494)