Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2766 Support inherited defaultDocumentType #1202

Merged
merged 1 commit into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,12 @@ func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val r
}
field = field.Addr()

dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
dctx := DecodeContext{
Registry: r.Registry,
Truncate: fd.truncate || r.Truncate,
defaultDocumentType: r.defaultDocumentType,
}

if fd.decoder == nil {
return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
}
Expand Down
111 changes: 108 additions & 3 deletions bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ import (
)

func TestBasicDecode(t *testing.T) {
t.Parallel()

for _, tc := range unmarshalingTestCases() {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got := reflect.New(tc.sType).Elem()
vr := bsonrw.NewBSONDocumentReader(tc.data)
reg := DefaultRegistry
Expand All @@ -38,9 +44,17 @@ func TestBasicDecode(t *testing.T) {
}

func TestDecoderv2(t *testing.T) {
t.Parallel()

t.Run("Decode", func(t *testing.T) {
t.Parallel()

for _, tc := range unmarshalingTestCases() {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got := reflect.New(tc.sType).Interface()
vr := bsonrw.NewBSONDocumentReader(tc.data)
dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr)
Expand All @@ -51,6 +65,8 @@ func TestDecoderv2(t *testing.T) {
})
}
t.Run("lookup error", func(t *testing.T) {
t.Parallel()

type certainlydoesntexistelsewhereihope func(string, string) string
// Avoid unused code lint error.
_ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" })
Expand All @@ -63,6 +79,8 @@ func TestDecoderv2(t *testing.T) {
assert.Equal(t, want, got, "Received unexpected error.")
})
t.Run("Unmarshaler", func(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
err error
Expand Down Expand Up @@ -90,7 +108,11 @@ func TestDecoderv2(t *testing.T) {
}

for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

unmarshaler := &testUnmarshaler{err: tc.err}
dec, err := NewDecoder(tc.vr)
noerr(t, err)
Expand All @@ -110,6 +132,8 @@ func TestDecoderv2(t *testing.T) {
}

t.Run("Unmarshaler/success bsonrw.ValueReader", func(t *testing.T) {
t.Parallel()

want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
unmarshaler := &testUnmarshaler{}
vr := bsonrw.NewBSONDocumentReader(want)
Expand All @@ -125,14 +149,20 @@ func TestDecoderv2(t *testing.T) {
})
})
t.Run("NewDecoder", func(t *testing.T) {
t.Parallel()

t.Run("error", func(t *testing.T) {
t.Parallel()

_, got := NewDecoder(nil)
want := errors.New("cannot create a new Decoder with a nil ValueReader")
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("Was expecting error but got different error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
t.Parallel()

got, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
noerr(t, err)
if got == nil {
Expand All @@ -141,7 +171,11 @@ func TestDecoderv2(t *testing.T) {
})
})
t.Run("NewDecoderWithContext", func(t *testing.T) {
t.Parallel()

t.Run("errors", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
_, got := NewDecoderWithContext(dc, nil)
want := errors.New("cannot create a new Decoder with a nil ValueReader")
Expand All @@ -150,6 +184,8 @@ func TestDecoderv2(t *testing.T) {
}
})
t.Run("success", func(t *testing.T) {
t.Parallel()

got, err := NewDecoderWithContext(bsoncodec.DecodeContext{}, bsonrw.NewBSONDocumentReader([]byte{}))
noerr(t, err)
if got == nil {
Expand All @@ -164,6 +200,8 @@ func TestDecoderv2(t *testing.T) {
})
})
t.Run("Decode doesn't zero struct", func(t *testing.T) {
t.Parallel()

type foo struct {
Item string
Qty int
Expand All @@ -182,6 +220,8 @@ func TestDecoderv2(t *testing.T) {
assert.Equal(t, want, got, "Results do not match.")
})
t.Run("Reset", func(t *testing.T) {
t.Parallel()

vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{})
dc := bsoncodec.DecodeContext{Registry: DefaultRegistry}
dec, err := NewDecoderWithContext(dc, vr1)
Expand All @@ -196,6 +236,8 @@ func TestDecoderv2(t *testing.T) {
}
})
t.Run("SetContext", func(t *testing.T) {
t.Parallel()

dc1 := bsoncodec.DecodeContext{Registry: DefaultRegistry}
dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{}))
Expand All @@ -210,6 +252,8 @@ func TestDecoderv2(t *testing.T) {
}
})
t.Run("SetRegistry", func(t *testing.T) {
t.Parallel()

r1, r2 := DefaultRegistry, NewRegistryBuilder().Build()
dc1 := bsoncodec.DecodeContext{Registry: r1}
dc2 := bsoncodec.DecodeContext{Registry: r2}
Expand All @@ -225,6 +269,8 @@ func TestDecoderv2(t *testing.T) {
}
})
t.Run("DecodeToNil", func(t *testing.T) {
t.Parallel()

data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
vr := bsonrw.NewBSONDocumentReader(data)
dec, err := NewDecoder(vr)
Expand All @@ -236,7 +282,9 @@ func TestDecoderv2(t *testing.T) {
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
}
})
t.Run("SetDocumentType embedded map as empty interface", func(t *testing.T) {
t.Run("DefaultDocuemntD embedded map as empty interface", func(t *testing.T) {
t.Parallel()

type someMap map[string]interface{}

in := make(someMap)
Expand Down Expand Up @@ -267,7 +315,9 @@ func TestDecoderv2(t *testing.T) {
bsonFooOutType := reflect.TypeOf(bsonOut["foo"])
assert.Equal(t, mType, bsonFooOutType, "expected %v to equal %v", mType.String(), bsonFooOutType.String())
})
t.Run("SetDocumentType for decoding into interface{} alias", func(t *testing.T) {
t.Run("DefaultDocuemntD for decoding into interface{} alias", func(t *testing.T) {
t.Parallel()

var in interface{} = map[string]interface{}{"bar": "baz"}

bytes, err := Marshal(in)
Expand All @@ -291,7 +341,9 @@ func TestDecoderv2(t *testing.T) {
assert.Equal(t, dType, bsonOutType,
"expected %v to equal %v", dType.String(), bsonOutType.String())
})
t.Run("SetDocumentType for decoding into non-interface{} alias", func(t *testing.T) {
t.Run("DefaultDocuemntD for decoding into non-interface{} alias", func(t *testing.T) {
t.Parallel()

var in interface{} = map[string]interface{}{"bar": "baz"}

bytes, err := Marshal(in)
Expand All @@ -315,6 +367,59 @@ func TestDecoderv2(t *testing.T) {
assert.NotEqual(t, dType, bsonOutType,
"expected %v to not equal %v", dType.String(), bsonOutType.String())
})
t.Run("DefaultDocumentD for deep struct values", func(t *testing.T) {
t.Parallel()

type emb struct {
Foo map[int]interface{} `bson:"foo"`
}

objID := primitive.NewObjectID()

in := emb{
Foo: map[int]interface{}{
1: map[string]interface{}{"bar": "baz"},
2: map[int]interface{}{
3: map[string]interface{}{"bar": "baz"},
},
4: map[primitive.ObjectID]interface{}{
objID: map[string]interface{}{"bar": "baz"},
},
},
}

bytes, err := Marshal(in)
if err != nil {
t.Fatal(err)
}

dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes))
if err != nil {
t.Fatal(err)
}

dec.DefaultDocumentD()

var out emb
if err := dec.Decode(&out); err != nil {
t.Fatal(err)
}

mType := reflect.TypeOf(primitive.M{})
bsonOutType := reflect.TypeOf(out)
assert.NotEqual(t, mType, bsonOutType,
"expected %v to not equal %v", mType.String(), bsonOutType.String())

want := emb{
Foo: map[int]interface{}{
1: primitive.D{{Key: "bar", Value: "baz"}},
2: primitive.D{{Key: "3", Value: primitive.D{{Key: "bar", Value: "baz"}}}},
4: primitive.D{{Key: objID.Hex(), Value: primitive.D{{Key: "bar", Value: "baz"}}}},
},
}

assert.Equal(t, want, out, "expected %v, got %v", want, out)
})
}

type testUnmarshaler struct {
Expand Down