Skip to content

Commit

Permalink
test: add support for non-strict-equality decode tests
Browse files Browse the repository at this point in the history
Future work will enable us to decode structures containing deeply nested
pointers, which will fail strict equality checks.
  • Loading branch information
DHowett committed Dec 26, 2024
1 parent 2e1fca7 commit b06b43f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
13 changes: 7 additions & 6 deletions common_data_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
)

type TestData struct {
Name string
Value interface{}
DecodeValue interface{} // used when the document cannot encode parts of Value
Documents map[int][]byte
SkipDecode map[int]bool
SkipEncode map[int]bool
Name string
Value interface{}
DecodeValue interface{} // used when the document cannot encode parts of Value
TestDecodedValue func(interface{}) error
Documents map[int][]byte
SkipDecode map[int]bool
SkipEncode map[int]bool
}

type SparseBundleHeader struct {
Expand Down
21 changes: 15 additions & 6 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ func TestDecode(t *testing.T) {
expVal = expReflect.Interface()

results := make(map[int]interface{})
for fmt, doc := range test.Documents {
if test.SkipDecode[fmt] {
for format, doc := range test.Documents {
if test.SkipDecode[format] {
return
}
subtest(t, FormatNames[fmt], func(t *testing.T) {
subtest(t, FormatNames[format], func(t *testing.T) {
val := reflect.New(expReflect.Type()).Interface()
_, err := Unmarshal(doc, val)
if err != nil {
Expand All @@ -144,9 +144,18 @@ func TestDecode(t *testing.T) {
val = valReflect.Interface()
}

results[fmt] = val
if !reflect.DeepEqual(expVal, val) {
t.Logf("Expected: %#v\n", expVal)
results[format] = val
var passErr error
if test.TestDecodedValue == nil {
if !reflect.DeepEqual(expVal, val) {
passErr = fmt.Errorf("Expected: %#v", expVal)
}
} else {
passErr = test.TestDecodedValue(val)
}

if passErr != nil {
t.Logf("%v\n", passErr)
t.Logf("Received: %#v\n", val)
t.Fail()
}
Expand Down

0 comments on commit b06b43f

Please sign in to comment.