Skip to content

Commit

Permalink
codec: implement protobuf unknown fields checker (#6557)
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em authored Jul 29, 2020
1 parent a4d1f30 commit b0c73ae
Show file tree
Hide file tree
Showing 11 changed files with 14,813 additions and 1,636 deletions.
115 changes: 115 additions & 0 deletions codec/unknownproto/benchmarks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package unknownproto_test

import (
"sync"
"testing"

"github.com/gogo/protobuf/proto"

"github.com/cosmos/cosmos-sdk/codec/unknownproto"
"github.com/cosmos/cosmos-sdk/testutil/testdata"
)

var n1BBlob []byte

func init() {
n1B := &testdata.Nested1B{
Id: 1,
Age: 99,
Nested: &testdata.Nested2B{
Id: 2,
Route: "Wintery route",
Fee: 99,
Nested: &testdata.Nested3B{
Id: 3,
Name: "3A this one that one there those oens",
Age: 4588,
B4: []*testdata.Nested4B{
{
Id: 4,
Age: 88,
Name: "Nested4B",
},
},
},
},
}

var err error
n1BBlob, err = proto.Marshal(n1B)
if err != nil {
panic(err)
}
}

func BenchmarkRejectUnknownFields_serial(b *testing.B) {
benchmarkRejectUnknownFields(b, false)
}
func BenchmarkRejectUnknownFields_parallel(b *testing.B) {
benchmarkRejectUnknownFields(b, true)
}

func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
b.ReportAllocs()

if !parallel {
ckr := new(unknownproto.Checker)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
}
} else {
var mu sync.Mutex
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ckr := new(unknownproto.Checker)
for pb.Next() {
// To simulate the conditions of multiple transactions being processed in parallel.
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()
b.SetBytes(int64(len(n1BBlob)))
mu.Unlock()
}
})
}
}

func BenchmarkProtoUnmarshal_serial(b *testing.B) {
benchmarkProtoUnmarshal(b, false)
}
func BenchmarkProtoUnmarshal_parallel(b *testing.B) {
benchmarkProtoUnmarshal(b, true)
}
func benchmarkProtoUnmarshal(b *testing.B, parallel bool) {
b.ReportAllocs()

if !parallel {
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := proto.Unmarshal(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
}
} else {
var mu sync.Mutex
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
n1A := new(testdata.Nested1A)
if err := proto.Unmarshal(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()
b.SetBytes(int64(len(n1BBlob)))
mu.Unlock()
}
})
}
}
28 changes: 28 additions & 0 deletions codec/unknownproto/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
unknownproto implements functionality to "type check" protobuf serialized byte sequences
against an expected proto.Message to report:
a) Unknown fields in the stream -- this is indicative of mismatched services, perhaps a malicious actor
b) Mismatched wire types for a field -- this is indicative of mismatched services
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) as
ckr := new(unknownproto.Checker)
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
// Handle the error.
}
and ideally should be added before invoking proto.Unmarshal, if you'd like to enforce the features mentioned above.
By default, for security we report every single field that's unknown, whether a non-critical field or not. To customize
this behavior, please create a Checker and set the AllowUnknownNonCriticals to true, for example:
ckr := &unknownproto.Checker{
AllowUnknownNonCriticals: true,
}
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
// Handle the error.
}
*/
package unknownproto
32 changes: 32 additions & 0 deletions codec/unknownproto/unit_helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package unknownproto

import (
"fmt"
"testing"

"google.golang.org/protobuf/encoding/protowire"
)

func TestWireTypeToString(t *testing.T) {
tests := []struct {
typ protowire.Type
want string
}{
{typ: 0, want: "varint"},
{typ: 1, want: "fixed64"},
{typ: 2, want: "bytes"},
{typ: 3, want: "start_group"},
{typ: 4, want: "end_group"},
{typ: 5, want: "fixed32"},
{typ: 95, want: "unknown type: 95"},
}

for _, tt := range tests {
tt := tt
t.Run(fmt.Sprintf("wireType=%d", tt.typ), func(t *testing.T) {
if g, w := wireTypeToString(tt.typ), tt.want; g != w {
t.Fatalf("Mismatch:\nGot: %q\nWant: %q\n", g, w)
}
})
}
}
Loading

0 comments on commit b0c73ae

Please sign in to comment.