Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom struct tags #347

Merged
merged 7 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions _generated/custom_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package _generated

//go:generate msgp
//msgp:tag mytag

type CustomTag struct {
Foo string `mytag:"foo_custom_name"`
Bar int `mytag:"bar1234"`
}
64 changes: 64 additions & 0 deletions _generated/custom_tag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package _generated

import (
"encoding/json"
"fmt"
"reflect"
"testing"

"bytes"

"github.com/tinylib/msgp/msgp"
)

func TestCustomTag(t *testing.T) {
t.Run("File Scope", func(t *testing.T) {
ts := CustomTag{
Foo: "foostring13579",
Bar: 999_999}
encDecCustomTag(t, ts, "mytag")
})
}

func encDecCustomTag(t *testing.T, testStruct msgp.Encodable, tag string) {
var b bytes.Buffer
msgp.Encode(&b, testStruct)

// Check tag names using JSON as an intermediary layer
// TODO: is there a way to avoid the JSON layer? We'd need to directly decode raw msgpack -> map[string]any
refJSON, err := json.Marshal(testStruct)
if err != nil {
t.Error(fmt.Sprintf("error encoding struct as JSON: %v", err))
}
ref := make(map[string]any)
// Encoding and decoding the original struct via JSON is necessary
// for field comparisons to work, since JSON -> map[string]any
// relies on type inferences such as all numbers being float64s
json.Unmarshal(refJSON, &ref)

var encJSON bytes.Buffer
msgp.UnmarshalAsJSON(&encJSON, b.Bytes())
encoded := make(map[string]any)
json.Unmarshal(encJSON.Bytes(), &encoded)

tsType := reflect.TypeOf(testStruct)
for i := 0; i < tsType.NumField(); i++ {
// Check encoded field name
field := tsType.Field(i)
encodedValue, ok := encoded[field.Tag.Get(tag)]
if !ok {
t.Error("missing encoded value for field", field.Name)
continue
}
// Check encoded field value (against original value post-JSON enc + dec)
jsonName, ok := field.Tag.Lookup("json")
if !ok {
jsonName = field.Name
}
refValue := ref[jsonName]
if !reflect.DeepEqual(refValue, encodedValue) {
t.Error(fmt.Sprintf("incorrect encoded value for field %s. reference: %v, encoded: %v",
field.Name, refValue, encodedValue))
}
}
}
20 changes: 18 additions & 2 deletions parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,15 @@ type passDirective func(gen.Method, []string, *gen.Printer) error
var directives = map[string]directive{
"shim": applyShim,
"ignore": ignore,
"tuple": astuple,
}
"tuple": astuple}

// map of all recognized directives which will be applied
// before process() is called
//
// to add an early directive, define a func([]string, *FileSet) error
// and then add it to this list.
var earlyDirectives = map[string]directive{
"tag": tag}

var passDirectives = map[string]passDirective{
"ignore": passignore,
Expand Down Expand Up @@ -128,3 +135,12 @@ func astuple(text []string, f *FileSet) error {
}
return nil
}

//msgp:tag {tagname}
func tag(text []string, f *FileSet) error {
if len(text) != 2 {
return nil
}
f.tagName = strings.TrimSpace(text[1])
return nil
}
33 changes: 32 additions & 1 deletion parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type FileSet struct {
Identities map[string]gen.Elem // processed from specs
Directives []string // raw preprocessor directives
Imports []*ast.ImportSpec // imports
tagName string // tag to read field names from
}

// File parses a file at the relative path
Expand Down Expand Up @@ -82,6 +83,7 @@ func File(name string, unexported bool) (*FileSet, error) {
return nil, fmt.Errorf("no definitions in %s", name)
}

fs.applyEarlyDirectives()
fs.process()
fs.applyDirectives()
fs.propInline()
Expand Down Expand Up @@ -112,6 +114,29 @@ func (f *FileSet) applyDirectives() {
f.Directives = newdirs
}

// applyEarlyDirectives applies all early directives needed before process() is called.
// additional directives remain in f.Directives for future processing
func (f *FileSet) applyEarlyDirectives() {
newdirs := make([]string, 0, len(f.Directives))
for _, d := range f.Directives {
parts := strings.Split(d, " ")
if len(parts) == 0 {
continue
}
if fn, ok := earlyDirectives[parts[0]]; ok {
pushstate(parts[0])
err := fn(parts, f)
if err != nil {
warnf("early directive error: %s", err)
}
popstate()
} else {
newdirs = append(newdirs, d)
}
}
f.Directives = newdirs
}

// A linkset is a graph of unresolved
// identities.
//
Expand Down Expand Up @@ -329,7 +354,13 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField {
var extension, flatten bool
// parse tag; otherwise field name is field tag
if f.Tag != nil {
body := reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg")
var body string
if fs.tagName != "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get(fs.tagName)
}
if body == "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msg")
}
if body == "" {
body = reflect.StructTag(strings.Trim(f.Tag.Value, "`")).Get("msgpack")
}
Expand Down
Loading