From e5d5390ef42395d99ae7c0ca072672b6f723da06 Mon Sep 17 00:00:00 2001 From: Jorropo Date: Mon, 12 Jun 2023 08:54:21 +0200 Subject: [PATCH] perf: outline logic in Decode to allow for stack allocations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I took extra care to keep this a backward compatible change, although I think `Decode` should return a `DecodedMultihash` value struct, not a pointer. I also updated the error type to hold the `DecodedMultihash` by value because this allows for 1 instead of 2 allocations when erroring. ``` name old time/op new time/op delta Decode-12 102ns ± 3% 18ns ± 3% -82.47% (p=0.000 n=9+9) name old alloc/op new alloc/op delta Decode-12 64.0B ± 0% 0.0B -100.00% (p=0.000 n=10+10) name old allocs/op new allocs/op delta Decode-12 1.00 ± 0% 0.00 -100.00% (p=0.000 n=10+10) ``` I originally found this problem by benchmarking `go-cid`: ``` github.com/ipfs/go-cid.CidFromBytes /home/hugo/go/pkg/mod/github.com/ipfs/go-cid@v0.4.0/cid.go Total: 4.64GB 10.75GB (flat, cum) 100% 638 . . if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 { 639 . . if len(data) < 34 { 640 . . return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")} 641 . . } 642 . . 643 . 6.11GB h, err := mh.Cast(data[:34]) _, err := Decode(buf) multihash.go:215 644 . . if err != nil { 645 . . return 0, Undef, ErrInvalidCid{err} 646 . . } ``` We can see that it calls `mh.Cast`, and `mh.Cast` calls `Decode` and immediately drops the `DecodedMultihash`. The point of this call is purely to validate the multihash by checking `err`. 
--- allocate_go119_test.go | 11 ++++++ allocate_go120_test.go | 12 +++++++ multihash.go | 22 +++++++++--- multihash_test.go | 78 ++++++++++++++++++++++++++++-------------- 4 files changed, 94 insertions(+), 29 deletions(-) create mode 100644 allocate_go119_test.go create mode 100644 allocate_go120_test.go diff --git a/allocate_go119_test.go b/allocate_go119_test.go new file mode 100644 index 0000000..984cb67 --- /dev/null +++ b/allocate_go119_test.go @@ -0,0 +1,11 @@ +//go:build !go1.20 + +package multihash + +import "testing" + +func mustNotAllocateMore(_ *testing.T, _ float64, f func()) { + // the compiler isn't able to detect our outlined stack allocation on Go versions before + // 1.20 so let's not test for it. We don't mind if outdated versions are slightly slower. + f() +} diff --git a/allocate_go120_test.go b/allocate_go120_test.go new file mode 100644 index 0000000..78f1286 --- /dev/null +++ b/allocate_go120_test.go @@ -0,0 +1,12 @@ +//go:build go1.20 + +package multihash + +import "testing" + +func mustNotAllocateMore(t *testing.T, n float64, f func()) { + t.Helper() + if b := testing.AllocsPerRun(10, f); b > n { + t.Errorf("it allocated %f max %f !", b, n) + } +} diff --git a/multihash.go b/multihash.go index 58e631d..1ef8d92 100644 --- a/multihash.go +++ b/multihash.go @@ -27,7 +27,7 @@ var ( // ErrInconsistentLen is returned when a decoded multihash has an inconsistent length type ErrInconsistentLen struct { - dm *DecodedMultihash + dm DecodedMultihash lengthFound int } @@ -222,12 +222,26 @@ func Cast(buf []byte) (Multihash, error) { // Decode parses multihash bytes into a DecodedMultihash. func Decode(buf []byte) (*DecodedMultihash, error) { - rlen, code, hdig, err := readMultihashFromBuf(buf) + // outline decode allowing the &dm expression to be inlined into the caller. + // This moves the heap allocation into the caller and if the caller doesn't + // leak dm the compiler will use a stack allocation instead. 
+ // If you do not outline this, &dm always heap allocates since the pointer is + // returned, which causes a heap allocation because Decode's stack frame is + // about to disappear. + dm, err := decode(buf) if err != nil { return nil, err } + return &dm, nil +} + +func decode(buf []byte) (dm DecodedMultihash, err error) { + rlen, code, hdig, err := readMultihashFromBuf(buf) + if err != nil { + return DecodedMultihash{}, err + } - dm := &DecodedMultihash{ + dm = DecodedMultihash{ Code: code, Name: Codes[code], Length: len(hdig), @@ -235,7 +249,7 @@ func Decode(buf []byte) (*DecodedMultihash, error) { } if len(buf) != rlen { - return nil, ErrInconsistentLen{dm, rlen} + return dm, ErrInconsistentLen{dm, rlen} } return dm, nil diff --git a/multihash_test.go b/multihash_test.go index 6230e29..4080f8e 100644 --- a/multihash_test.go +++ b/multihash_test.go @@ -151,27 +151,29 @@ func TestDecode(t *testing.T) { nb := append(pre[:n], ob...) - dec, err := Decode(nb) - if err != nil { - t.Error(err) - continue - } + mustNotAllocateMore(t, 0, func() { + dec, err := Decode(nb) + if err != nil { + t.Error(err) + return + } - if dec.Code != tc.code { - t.Error("decoded code mismatch: ", dec.Code, tc.code) - } + if dec.Code != tc.code { + t.Error("decoded code mismatch: ", dec.Code, tc.code) + } - if dec.Name != tc.name { - t.Error("decoded name mismatch: ", dec.Name, tc.name) - } + if dec.Name != tc.name { + t.Error("decoded name mismatch: ", dec.Name, tc.name) + } - if dec.Length != len(ob) { - t.Error("decoded length mismatch: ", dec.Length, len(ob)) - } + if dec.Length != len(ob) { + t.Error("decoded length mismatch: ", dec.Length, len(ob)) + } - if !bytes.Equal(dec.Digest, ob) { - t.Error("decoded byte mismatch: ", dec.Digest, ob) - } + if !bytes.Equal(dec.Digest, ob) { + t.Error("decoded byte mismatch: ", dec.Digest, ob) + } + }) } } @@ -242,15 +244,20 @@ func TestCast(t *testing.T) { nb := append(pre[:n], ob...) 
- if _, err := Cast(nb); err != nil { - t.Error(err) - continue - } + mustNotAllocateMore(t, 0, func() { + if _, err := Cast(nb); err != nil { + t.Error(err) + return + } + }) - if _, err = Cast(ob); err == nil { - t.Error("cast failed to detect non-multihash") - continue - } + // 1 for the error object. + mustNotAllocateMore(t, 1, func() { + if _, err = Cast(ob); err == nil { + t.Error("cast failed to detect non-multihash") + return + } + }) } } @@ -343,8 +350,29 @@ func BenchmarkDecode(b *testing.B) { pre[1] = byte(uint8(len(ob))) nb := append(pre, ob...) + b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { Decode(nb) } } + +func BenchmarkCast(b *testing.B) { + tc := testCases[0] + ob, err := hex.DecodeString(tc.hex) + if err != nil { + b.Error(err) + return + } + + pre := make([]byte, 2) + pre[0] = byte(uint8(tc.code)) + pre[1] = byte(uint8(len(ob))) + nb := append(pre, ob...) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Cast(nb) + } +}