diff --git a/sha3/sha3.go b/sha3/sha3.go index 4884d172a4..33bd73b0f6 100644 --- a/sha3/sha3.go +++ b/sha3/sha3.go @@ -40,7 +40,7 @@ type state struct { // Extendable-Output Functions (May 2014)" dsbyte byte - storage storageBuf + storage [maxRate]byte // Specific to SHA-3 and SHAKE. outputLen int // the default output size in bytes @@ -61,15 +61,15 @@ func (d *state) Reset() { d.a[i] = 0 } d.state = spongeAbsorbing - d.buf = d.storage.asBytes()[:0] + d.buf = d.storage[:0] } func (d *state) clone() *state { ret := *d if ret.state == spongeAbsorbing { - ret.buf = ret.storage.asBytes()[:len(ret.buf)] + ret.buf = ret.storage[:len(ret.buf)] } else { - ret.buf = ret.storage.asBytes()[d.rate-cap(d.buf) : d.rate] + ret.buf = ret.storage[d.rate-cap(d.buf) : d.rate] } return &ret @@ -83,13 +83,13 @@ func (d *state) permute() { // If we're absorbing, we need to xor the input into the state // before applying the permutation. xorIn(d, d.buf) - d.buf = d.storage.asBytes()[:0] + d.buf = d.storage[:0] keccakF1600(&d.a) case spongeSqueezing: // If we're squeezing, we need to apply the permutation before // copying more output. keccakF1600(&d.a) - d.buf = d.storage.asBytes()[:d.rate] + d.buf = d.storage[:d.rate] copyOut(d, d.buf) } } @@ -98,7 +98,7 @@ func (d *state) permute() { // the multi-bitrate 10..1 padding rule, and permutes the state. func (d *state) padAndPermute(dsbyte byte) { if d.buf == nil { - d.buf = d.storage.asBytes()[:0] + d.buf = d.storage[:0] } // Pad with this instance's domain-separator bits. We know that there's // at least one byte of space in d.buf because, if it were full, @@ -106,7 +106,7 @@ func (d *state) padAndPermute(dsbyte byte) { // first one bit for the padding. See the comment in the state struct. d.buf = append(d.buf, dsbyte) zerosStart := len(d.buf) - d.buf = d.storage.asBytes()[:d.rate] + d.buf = d.storage[:d.rate] for i := zerosStart; i < d.rate; i++ { d.buf[i] = 0 } @@ -117,7 +117,7 @@ func (d *state) padAndPermute(dsbyte byte) { // Apply the permutation d.permute() d.state = spongeSqueezing - d.buf = d.storage.asBytes()[:d.rate] + d.buf = d.storage[:d.rate] copyOut(d, d.buf) } @@ -128,7 +128,7 @@ func (d *state) Write(p []byte) (written int, err error) { panic("sha3: Write after Read") } if d.buf == nil { - d.buf = d.storage.asBytes()[:0] + d.buf = d.storage[:0] } written = len(p) diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go index afcb722e7a..21e8cbad7b 100644 --- a/sha3/sha3_test.go +++ b/sha3/sha3_test.go @@ -76,86 +76,73 @@ type KeccakKats struct { } } -func testUnalignedAndGeneric(t *testing.T, testf func(impl string)) { - xorInOrig, copyOutOrig := xorIn, copyOut - xorIn, copyOut = xorInGeneric, copyOutGeneric - testf("generic") - if xorImplementationUnaligned != "generic" { - xorIn, copyOut = xorInUnaligned, copyOutUnaligned - testf("unaligned") - } - xorIn, copyOut = xorInOrig, copyOutOrig -} - // TestKeccakKats tests the SHA-3 and Shake implementations against all the // ShortMsgKATs from https://github.com/gvanas/KeccakCodePackage // (The testvectors are stored in keccakKats.json.deflate due to their length.) func TestKeccakKats(t *testing.T) { - testUnalignedAndGeneric(t, func(impl string) { - // Read the KATs. - deflated, err := os.Open(katFilename) - if err != nil { - t.Errorf("error opening %s: %s", katFilename, err) - } - file := flate.NewReader(deflated) - dec := json.NewDecoder(file) - var katSet KeccakKats - err = dec.Decode(&katSet) - if err != nil { - t.Errorf("error decoding KATs: %s", err) - } + // Read the KATs. + deflated, err := os.Open(katFilename) + if err != nil { + t.Errorf("error opening %s: %s", katFilename, err) + } + file := flate.NewReader(deflated) + dec := json.NewDecoder(file) + var katSet KeccakKats + err = dec.Decode(&katSet) + if err != nil { + t.Errorf("error decoding KATs: %s", err) + } - for algo, function := range testDigests { - d := function() - for _, kat := range katSet.Kats[algo] { - d.Reset() - in, err := hex.DecodeString(kat.Message) - if err != nil { - t.Errorf("error decoding KAT: %s", err) - } - d.Write(in[:kat.Length/8]) - got := strings.ToUpper(hex.EncodeToString(d.Sum(nil))) - if got != kat.Digest { - t.Errorf("function=%s, implementation=%s, length=%d\nmessage:\n %s\ngot:\n %s\nwanted:\n %s", - algo, impl, kat.Length, kat.Message, got, kat.Digest) - t.Logf("wanted %+v", kat) - t.FailNow() - } - continue + for algo, function := range testDigests { + d := function() + for _, kat := range katSet.Kats[algo] { + d.Reset() + in, err := hex.DecodeString(kat.Message) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } + d.Write(in[:kat.Length/8]) + got := strings.ToUpper(hex.EncodeToString(d.Sum(nil))) + if got != kat.Digest { + t.Errorf("function=%s, length=%d\nmessage:\n %s\ngot:\n %s\nwanted:\n %s", + algo, kat.Length, kat.Message, got, kat.Digest) + t.Logf("wanted %+v", kat) + t.FailNow() } + continue } + } - for algo, v := range testShakes { - for _, kat := range katSet.Kats[algo] { - N, err := hex.DecodeString(kat.N) - if err != nil { - t.Errorf("error decoding KAT: %s", err) - } + for algo, v := range testShakes { + for _, kat := range katSet.Kats[algo] { + N, err := hex.DecodeString(kat.N) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } - S, err := hex.DecodeString(kat.S) - if err != nil { - t.Errorf("error decoding KAT: %s", err) - } - d := v.constructor(N, S) - in, err := hex.DecodeString(kat.Message) - if err != nil { - t.Errorf("error decoding KAT: %s", err) - } + S, err := hex.DecodeString(kat.S) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } + d := v.constructor(N, S) + in, err := hex.DecodeString(kat.Message) + if err != nil { + t.Errorf("error decoding KAT: %s", err) + } - d.Write(in[:kat.Length/8]) - out := make([]byte, len(kat.Digest)/2) - d.Read(out) - got := strings.ToUpper(hex.EncodeToString(out)) - if got != kat.Digest { - t.Errorf("function=%s, implementation=%s, length=%d N:%s\n S:%s\nmessage:\n %s \ngot:\n %s\nwanted:\n %s", - algo, impl, kat.Length, kat.N, kat.S, kat.Message, got, kat.Digest) - t.Logf("wanted %+v", kat) - t.FailNow() - } - continue + d.Write(in[:kat.Length/8]) + out := make([]byte, len(kat.Digest)/2) + d.Read(out) + got := strings.ToUpper(hex.EncodeToString(out)) + if got != kat.Digest { + t.Errorf("function=%s, length=%d N:%s\n S:%s\nmessage:\n %s \ngot:\n %s\nwanted:\n %s", + algo, kat.Length, kat.N, kat.S, kat.Message, got, kat.Digest) + t.Logf("wanted %+v", kat) + t.FailNow() } + continue } - }) + } } // TestKeccak does a basic test of the non-standardized Keccak hash functions. @@ -219,119 +206,111 @@ func TestShakeSum(t *testing.T) { // TestUnalignedWrite tests that writing data in an arbitrary pattern with // small input buffers. func TestUnalignedWrite(t *testing.T) { - testUnalignedAndGeneric(t, func(impl string) { - buf := sequentialBytes(0x10000) - for alg, df := range testDigests { - d := df() - d.Reset() - d.Write(buf) - want := d.Sum(nil) - d.Reset() - for i := 0; i < len(buf); { - // Cycle through offsets which make a 137 byte sequence. - // Because 137 is prime this sequence should exercise all corner cases. - offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1} - for _, j := range offsets { - if v := len(buf) - i; v < j { - j = v - } - d.Write(buf[i : i+j]) - i += j + buf := sequentialBytes(0x10000) + for alg, df := range testDigests { + d := df() + d.Reset() + d.Write(buf) + want := d.Sum(nil) + d.Reset() + for i := 0; i < len(buf); { + // Cycle through offsets which make a 137 byte sequence. + // Because 137 is prime this sequence should exercise all corner cases. + offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1} + for _, j := range offsets { + if v := len(buf) - i; v < j { + j = v } - } - got := d.Sum(nil) - if !bytes.Equal(got, want) { - t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want) + d.Write(buf[i : i+j]) + i += j } } + got := d.Sum(nil) + if !bytes.Equal(got, want) { + t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want) + } + } - // Same for SHAKE - for alg, df := range testShakes { - want := make([]byte, 16) - got := make([]byte, 16) - d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr)) - - d.Reset() - d.Write(buf) - d.Read(want) - d.Reset() - for i := 0; i < len(buf); { - // Cycle through offsets which make a 137 byte sequence. - // Because 137 is prime this sequence should exercise all corner cases. - offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1} - for _, j := range offsets { - if v := len(buf) - i; v < j { - j = v - } - d.Write(buf[i : i+j]) - i += j + // Same for SHAKE + for alg, df := range testShakes { + want := make([]byte, 16) + got := make([]byte, 16) + d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr)) + + d.Reset() + d.Write(buf) + d.Read(want) + d.Reset() + for i := 0; i < len(buf); { + // Cycle through offsets which make a 137 byte sequence. + // Because 137 is prime this sequence should exercise all corner cases. + offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1} + for _, j := range offsets { + if v := len(buf) - i; v < j { + j = v } - } - d.Read(got) - if !bytes.Equal(got, want) { - t.Errorf("Unaligned writes, implementation=%s, alg=%s\ngot %q, want %q", impl, alg, got, want) + d.Write(buf[i : i+j]) + i += j } } - }) + d.Read(got) + if !bytes.Equal(got, want) { + t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want) + } + } } // TestAppend checks that appending works when reallocation is necessary. func TestAppend(t *testing.T) { - testUnalignedAndGeneric(t, func(impl string) { - d := New224() + d := New224() - for capacity := 2; capacity <= 66; capacity += 64 { - // The first time around the loop, Sum will have to reallocate. - // The second time, it will not. - buf := make([]byte, 2, capacity) - d.Reset() - d.Write([]byte{0xcc}) - buf = d.Sum(buf) - expected := "0000DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39" - if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected { - t.Errorf("got %s, want %s", got, expected) - } + for capacity := 2; capacity <= 66; capacity += 64 { + // The first time around the loop, Sum will have to reallocate. + // The second time, it will not. + buf := make([]byte, 2, capacity) + d.Reset() + d.Write([]byte{0xcc}) + buf = d.Sum(buf) + expected := "0000DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39" + if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected { + t.Errorf("got %s, want %s", got, expected) } - }) + } } // TestAppendNoRealloc tests that appending works when no reallocation is necessary. func TestAppendNoRealloc(t *testing.T) { - testUnalignedAndGeneric(t, func(impl string) { - buf := make([]byte, 1, 200) - d := New224() - d.Write([]byte{0xcc}) - buf = d.Sum(buf) - expected := "00DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39" - if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected { - t.Errorf("%s: got %s, want %s", impl, got, expected) - } - }) + buf := make([]byte, 1, 200) + d := New224() + d.Write([]byte{0xcc}) + buf = d.Sum(buf) + expected := "00DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39" + if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected { + t.Errorf("got %s, want %s", got, expected) + } } // TestSqueezing checks that squeezing the full output a single time produces // the same output as repeatedly squeezing the instance. func TestSqueezing(t *testing.T) { - testUnalignedAndGeneric(t, func(impl string) { - for algo, v := range testShakes { - d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) - d0.Write([]byte(testString)) - ref := make([]byte, 32) - d0.Read(ref) - - d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) - d1.Write([]byte(testString)) - var multiple []byte - for range ref { - one := make([]byte, 1) - d1.Read(one) - multiple = append(multiple, one...) - } - if !bytes.Equal(ref, multiple) { - t.Errorf("%s (%s): squeezing %d bytes one at a time failed", algo, impl, len(ref)) - } + for algo, v := range testShakes { + d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d0.Write([]byte(testString)) + ref := make([]byte, 32) + d0.Read(ref) + + d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d1.Write([]byte(testString)) + var multiple []byte + for range ref { + one := make([]byte, 1) + d1.Read(one) + multiple = append(multiple, one...) + } + if !bytes.Equal(ref, multiple) { + t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref)) } - }) + } } // sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing. diff --git a/sha3/xor.go b/sha3/xor.go index 7337cca88e..6ada5c9574 100644 --- a/sha3/xor.go +++ b/sha3/xor.go @@ -2,22 +2,39 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build (!amd64 && !386 && !ppc64le) || purego - package sha3 -// A storageBuf is an aligned array of maxRate bytes. -type storageBuf [maxRate]byte - -func (b *storageBuf) asBytes() *[maxRate]byte { - return (*[maxRate]byte)(b) -} +import ( + "crypto/subtle" + "encoding/binary" + "unsafe" -var ( - xorIn = xorInGeneric - copyOut = copyOutGeneric - xorInUnaligned = xorInGeneric - copyOutUnaligned = copyOutGeneric + "golang.org/x/sys/cpu" ) -const xorImplementationUnaligned = "generic" +// xorIn xors the bytes in buf into the state. +func xorIn(d *state, buf []byte) { + if cpu.IsBigEndian { + for i := 0; len(buf) >= 8; i++ { + a := binary.LittleEndian.Uint64(buf) + d.a[i] ^= a + buf = buf[8:] + } + } else { + ab := (*[25 * 64 / 8]byte)(unsafe.Pointer(&d.a)) + subtle.XORBytes(ab[:], ab[:], buf) + } +} + +// copyOut copies uint64s to a byte buffer. +func copyOut(d *state, b []byte) { + if cpu.IsBigEndian { + for i := 0; len(b) >= 8; i++ { + binary.LittleEndian.PutUint64(b, d.a[i]) + b = b[8:] + } + } else { + ab := (*[25 * 64 / 8]byte)(unsafe.Pointer(&d.a)) + copy(b, ab[:]) + } +} diff --git a/sha3/xor_generic.go b/sha3/xor_generic.go deleted file mode 100644 index 8d94771127..0000000000 --- a/sha3/xor_generic.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sha3 - -import "encoding/binary" - -// xorInGeneric xors the bytes in buf into the state; it -// makes no non-portable assumptions about memory layout -// or alignment. -func xorInGeneric(d *state, buf []byte) { - n := len(buf) / 8 - - for i := 0; i < n; i++ { - a := binary.LittleEndian.Uint64(buf) - d.a[i] ^= a - buf = buf[8:] - } -} - -// copyOutGeneric copies uint64s to a byte buffer. -func copyOutGeneric(d *state, b []byte) { - for i := 0; len(b) >= 8; i++ { - binary.LittleEndian.PutUint64(b, d.a[i]) - b = b[8:] - } -} diff --git a/sha3/xor_unaligned.go b/sha3/xor_unaligned.go deleted file mode 100644 index 870e2d16e0..0000000000 --- a/sha3/xor_unaligned.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build (amd64 || 386 || ppc64le) && !purego - -package sha3 - -import "unsafe" - -// A storageBuf is an aligned array of maxRate bytes. -type storageBuf [maxRate / 8]uint64 - -func (b *storageBuf) asBytes() *[maxRate]byte { - return (*[maxRate]byte)(unsafe.Pointer(b)) -} - -// xorInUnaligned uses unaligned reads and writes to update d.a to contain d.a -// XOR buf. -func xorInUnaligned(d *state, buf []byte) { - n := len(buf) - bw := (*[maxRate / 8]uint64)(unsafe.Pointer(&buf[0]))[: n/8 : n/8] - if n >= 72 { - d.a[0] ^= bw[0] - d.a[1] ^= bw[1] - d.a[2] ^= bw[2] - d.a[3] ^= bw[3] - d.a[4] ^= bw[4] - d.a[5] ^= bw[5] - d.a[6] ^= bw[6] - d.a[7] ^= bw[7] - d.a[8] ^= bw[8] - } - if n >= 104 { - d.a[9] ^= bw[9] - d.a[10] ^= bw[10] - d.a[11] ^= bw[11] - d.a[12] ^= bw[12] - } - if n >= 136 { - d.a[13] ^= bw[13] - d.a[14] ^= bw[14] - d.a[15] ^= bw[15] - d.a[16] ^= bw[16] - } - if n >= 144 { - d.a[17] ^= bw[17] - } - if n >= 168 { - d.a[18] ^= bw[18] - d.a[19] ^= bw[19] - d.a[20] ^= bw[20] - } -} - -func copyOutUnaligned(d *state, buf []byte) { - ab := (*[maxRate]uint8)(unsafe.Pointer(&d.a[0])) - copy(buf, ab[:]) -} - -var ( - xorIn = xorInUnaligned - copyOut = copyOutUnaligned -) - -const xorImplementationUnaligned = "unaligned"