From 6e4c69f947e2ed7dc1d2841ed44ad35629ba9bd6 Mon Sep 17 00:00:00 2001 From: w1ck3dg0ph3r Date: Thu, 5 Nov 2020 13:49:01 +0300 Subject: [PATCH] Allow unmarshalers to handle explicit null values --- Makefile | 3 ++- gen/decoder.go | 18 +++++++++----- tests/defined_null.go | 51 ++++++++++++++++++++++++++++++++++++++ tests/defined_null_test.go | 26 +++++++++++++++++++ 4 files changed, 91 insertions(+), 7 deletions(-) create mode 100644 tests/defined_null.go create mode 100644 tests/defined_null_test.go diff --git a/Makefile b/Makefile index c5273407..96ddf164 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,8 @@ generate: build ./tests/intern.go \ ./tests/nocopy.go \ ./tests/escaping.go \ - ./tests/nested_marshaler.go + ./tests/nested_marshaler.go \ + ./tests/defined_null.go bin/easyjson -snake_case ./tests/snake.go bin/easyjson -omit_empty ./tests/omitempty.go bin/easyjson -build_tags=use_easyjson -disable_members_unescape ./benchmark/data.go diff --git a/gen/decoder.go b/gen/decoder.go index 0a0faa26..328d4eac 100644 --- a/gen/decoder.go +++ b/gen/decoder.go @@ -110,15 +110,26 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field ws := strings.Repeat(" ", indent) // Check whether type is primitive, needs to be done after interface check. if dec := customDecoders[t.String()]; dec != "" { + fmt.Fprintln(g.out, " if in.IsNull() {") + fmt.Fprintln(g.out, " in.Skip()") + fmt.Fprintln(g.out, " } else {") fmt.Fprintln(g.out, ws+out+" = "+dec) + fmt.Fprintln(g.out, " }") return nil } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString { + fmt.Fprintln(g.out, " if in.IsNull() {") + fmt.Fprintln(g.out, " in.Skip()") + fmt.Fprintln(g.out, " } else {") if tags.intern && t.Kind() == reflect.String { dec = "in.StringIntern()" } fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, " }") return nil } else if dec := primitiveDecoders[t.Kind()]; dec != "" { + fmt.Fprintln(g.out, " if in.IsNull() {") + fmt.Fprintln(g.out, " in.Skip()") + fmt.Fprintln(g.out, " } else {") if tags.intern && t.Kind() == reflect.String { dec = "in.StringIntern()" } @@ -126,6 +137,7 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field dec = "in.UnsafeString()" } fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") + fmt.Fprintln(g.out, " }") return nil } @@ -514,12 +526,6 @@ func (g *Generator) genStructDecoder(t reflect.Type) error { fmt.Fprintln(g.out, " for !in.IsDelim('}') {") fmt.Fprintf(g.out, " key := in.UnsafeFieldName(%v)\n", g.skipMemberNameUnescaping) fmt.Fprintln(g.out, " in.WantColon()") - fmt.Fprintln(g.out, " if in.IsNull() {") - fmt.Fprintln(g.out, " in.Skip()") - fmt.Fprintln(g.out, " in.WantComma()") - fmt.Fprintln(g.out, " continue") - fmt.Fprintln(g.out, " }") - fmt.Fprintln(g.out, " switch key {") for _, f := range fs { if err := g.genStructFieldDecoder(t, f); err != nil { diff --git a/tests/defined_null.go b/tests/defined_null.go new file mode 100644 index 00000000..1c3e88d3 --- /dev/null +++ b/tests/defined_null.go @@ -0,0 +1,51 @@ +package tests + +import ( + "github.com/mailru/easyjson/jlexer" + "bytes" + "encoding/json" +) + +//easyjson:json +type NullStringStruct struct { + NS NullString + VNS VanillaNullString + S string +} + +type NullString struct { + V string + Valid bool + Set bool +} + +func (s *NullString) UnmarshalEasyJSON(l *jlexer.Lexer) { + s.Set = true + if l.IsNull() { + l.Skip() + s.V, s.Valid = "", false + return + } + s.V, s.Valid = l.String(), true +} + +type VanillaNullString struct { + V string + Valid bool + Set bool +} + +func (s *VanillaNullString) UnmarshalJSON(v []byte) error { + s.Set = true + if bytes.Equal(v, []byte("null")) { + s.V, s.Valid = "", false + return nil + } + err := json.Unmarshal(v, &s.V) + if err != nil { + s.Valid = false + return err + } + s.Valid = true + return nil +} \ No newline at end of file diff --git a/tests/defined_null_test.go b/tests/defined_null_test.go new file mode 100644 index 00000000..84e2817d --- /dev/null +++ b/tests/defined_null_test.go @@ -0,0 +1,26 @@ +package tests + +import ( + "testing" + "github.com/mailru/easyjson" + "reflect" +) + +func TestDefinedNull(t *testing.T) { + cases := []struct{data string; res NullStringStruct}{ + {`{"NS": "ns", "VNS": "vns", "S": "s"}`, NullStringStruct{NullString{"ns", true, true}, VanillaNullString{"vns", true, true}, "s"}}, + {`{"NS": null, "VNS": null, "S": null}`, NullStringStruct{NullString{"", false, true}, VanillaNullString{"", false, true}, ""}}, + {`{"Unknown": "Value"}`, NullStringStruct{NullString{"", false, false}, VanillaNullString{"", false, false}, ""}}, + } + for _, c := range cases { + var res NullStringStruct + if err := easyjson.Unmarshal([]byte(c.data), &res); err != nil { + t.Errorf("Unexpected Unmarshal erorr: %v", err) + } + if !reflect.DeepEqual(res, c.res) { + t.Errorf("Expected to unmarshal %+v, got %+v", c.res, res) + } + } +} + +