From 1a78435c66fc651202c3050b681d923bab8651c6 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Tue, 21 Mar 2023 11:08:26 -0600 Subject: [PATCH] GODRIVER-2766 support inherited defaultDocumentType (#1202) --- bson/bsoncodec/struct_codec.go | 7 ++- bson/decoder_test.go | 111 ++++++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 4 deletions(-) diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index be3f2081e9..da1ae18e02 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -326,7 +326,12 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r } field = field.Addr() - dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate} + dctx := DecodeContext{ + Registry: r.Registry, + Truncate: fd.truncate || r.Truncate, + defaultDocumentType: r.defaultDocumentType, + } + if fd.decoder == nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 8517f5861b..3391600558 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -23,8 +23,14 @@ import ( ) func TestBasicDecode(t *testing.T) { + t.Parallel() + for _, tc := range unmarshalingTestCases() { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := reflect.New(tc.sType).Elem() vr := bsonrw.NewBSONDocumentReader(tc.data) reg := DefaultRegistry @@ -38,9 +44,17 @@ func TestBasicDecode(t *testing.T) { } func TestDecoderv2(t *testing.T) { + t.Parallel() + t.Run("Decode", func(t *testing.T) { + t.Parallel() + for _, tc := range unmarshalingTestCases() { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := reflect.New(tc.sType).Interface() vr := bsonrw.NewBSONDocumentReader(tc.data) dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr) @@ -51,6 +65,8 @@ func TestDecoderv2(t *testing.T) { }) } t.Run("lookup error", func(t *testing.T) { + t.Parallel() + type certainlydoesntexistelsewhereihope func(string, string) string // Avoid unused code lint error. _ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" }) @@ -63,6 +79,8 @@ func TestDecoderv2(t *testing.T) { assert.Equal(t, want, got, "Received unexpected error.") }) t.Run("Unmarshaler", func(t *testing.T) { + t.Parallel() + testCases := []struct { name string err error @@ -90,7 +108,11 @@ func TestDecoderv2(t *testing.T) { } for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + unmarshaler := &testUnmarshaler{err: tc.err} dec, err := NewDecoder(tc.vr) noerr(t, err) @@ -110,6 +132,8 @@ func TestDecoderv2(t *testing.T) { } t.Run("Unmarshaler/success bsonrw.ValueReader", func(t *testing.T) { + t.Parallel() + want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)) unmarshaler := &testUnmarshaler{} vr := bsonrw.NewBSONDocumentReader(want) @@ -125,7 +149,11 @@ func TestDecoderv2(t *testing.T) { }) }) t.Run("NewDecoder", func(t *testing.T) { + t.Parallel() + t.Run("error", func(t *testing.T) { + t.Parallel() + _, got := NewDecoder(nil) want := errors.New("cannot create a new Decoder with a nil ValueReader") if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { @@ -133,6 +161,8 @@ func TestDecoderv2(t *testing.T) { } }) t.Run("success", func(t *testing.T) { + t.Parallel() + got, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) noerr(t, err) if got == nil { @@ -141,7 +171,11 @@ func TestDecoderv2(t *testing.T) { }) }) t.Run("NewDecoderWithContext", func(t *testing.T) { + t.Parallel() + t.Run("errors", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} _, got := NewDecoderWithContext(dc, nil) want := errors.New("cannot create a new Decoder with a nil ValueReader") @@ -150,6 +184,8 @@ func TestDecoderv2(t *testing.T) { } }) t.Run("success", func(t *testing.T) { + t.Parallel() + got, err := NewDecoderWithContext(bsoncodec.DecodeContext{}, bsonrw.NewBSONDocumentReader([]byte{})) noerr(t, err) if got == nil { @@ -164,6 +200,8 @@ func TestDecoderv2(t *testing.T) { }) }) t.Run("Decode doesn't zero struct", func(t *testing.T) { + t.Parallel() + type foo struct { Item string Qty int @@ -182,6 +220,8 @@ func TestDecoderv2(t *testing.T) { assert.Equal(t, want, got, "Results do not match.") }) t.Run("Reset", func(t *testing.T) { + t.Parallel() + vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{}) dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} dec, err := NewDecoderWithContext(dc, vr1) @@ -196,6 +236,8 @@ func TestDecoderv2(t *testing.T) { } }) t.Run("SetContext", func(t *testing.T) { + t.Parallel() + dc1 := bsoncodec.DecodeContext{Registry: DefaultRegistry} dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{})) @@ -210,6 +252,8 @@ func TestDecoderv2(t *testing.T) { } }) t.Run("SetRegistry", func(t *testing.T) { + t.Parallel() + r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() dc1 := bsoncodec.DecodeContext{Registry: r1} dc2 := bsoncodec.DecodeContext{Registry: r2} @@ -225,6 +269,8 @@ func TestDecoderv2(t *testing.T) { } }) t.Run("DecodeToNil", func(t *testing.T) { + t.Parallel() + data := docToBytes(D{{"item", "canvas"}, {"qty", 4}}) vr := bsonrw.NewBSONDocumentReader(data) dec, err := NewDecoder(vr) @@ -236,7 +282,9 @@ func TestDecoderv2(t *testing.T) { t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err) } }) - t.Run("SetDocumentType embedded map as empty interface", func(t *testing.T) { + t.Run("DefaultDocuemntD embedded map as empty interface", func(t *testing.T) { + t.Parallel() + type someMap map[string]interface{} in := make(someMap) @@ -267,7 +315,9 @@ func TestDecoderv2(t *testing.T) { bsonFooOutType := reflect.TypeOf(bsonOut["foo"]) assert.Equal(t, mType, bsonFooOutType, "expected %v to equal %v", mType.String(), bsonFooOutType.String()) }) - t.Run("SetDocumentType for decoding into interface{} alias", func(t *testing.T) { + t.Run("DefaultDocuemntD for decoding into interface{} alias", func(t *testing.T) { + t.Parallel() + var in interface{} = map[string]interface{}{"bar": "baz"} bytes, err := Marshal(in) @@ -291,7 +341,9 @@ func TestDecoderv2(t *testing.T) { assert.Equal(t, dType, bsonOutType, "expected %v to equal %v", dType.String(), bsonOutType.String()) }) - t.Run("SetDocumentType for decoding into non-interface{} alias", func(t *testing.T) { + t.Run("DefaultDocuemntD for decoding into non-interface{} alias", func(t *testing.T) { + t.Parallel() + var in interface{} = map[string]interface{}{"bar": "baz"} bytes, err := Marshal(in) @@ -315,6 +367,59 @@ func TestDecoderv2(t *testing.T) { assert.NotEqual(t, dType, bsonOutType, "expected %v to not equal %v", dType.String(), bsonOutType.String()) }) + t.Run("DefaultDocumentD for deep struct values", func(t *testing.T) { + t.Parallel() + + type emb struct { + Foo map[int]interface{} `bson:"foo"` + } + + objID := primitive.NewObjectID() + + in := emb{ + Foo: map[int]interface{}{ + 1: map[string]interface{}{"bar": "baz"}, + 2: map[int]interface{}{ + 3: map[string]interface{}{"bar": "baz"}, + }, + 4: map[primitive.ObjectID]interface{}{ + objID: map[string]interface{}{"bar": "baz"}, + }, + }, + } + + bytes, err := Marshal(in) + if err != nil { + t.Fatal(err) + } + + dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes)) + if err != nil { + t.Fatal(err) + } + + dec.DefaultDocumentD() + + var out emb + if err := dec.Decode(&out); err != nil { + t.Fatal(err) + } + + mType := reflect.TypeOf(primitive.M{}) + bsonOutType := reflect.TypeOf(out) + assert.NotEqual(t, mType, bsonOutType, + "expected %v to not equal %v", mType.String(), bsonOutType.String()) + + want := emb{ + Foo: map[int]interface{}{ + 1: primitive.D{{Key: "bar", Value: "baz"}}, + 2: primitive.D{{Key: "3", Value: primitive.D{{Key: "bar", Value: "baz"}}}}, + 4: primitive.D{{Key: objID.Hex(), Value: primitive.D{{Key: "bar", Value: "baz"}}}}, + }, + } + + assert.Equal(t, want, out, "expected %v, got %v", want, out) + }) } type testUnmarshaler struct {