Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace directive #346

Merged
merged 11 commits into from
Jul 2, 2024
11 changes: 8 additions & 3 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (d *decodeGen) gBase(b *BaseElem) {

// open block for 'tmp'
var tmp string
if b.Convert {
if b.Convert && b.Value != IDENT { // we don't need block for 'tmp' in case of IDENT
tmp = randIdent()
d.p.printf("\n{ var %s %s", tmp, b.BaseType())
}
Expand All @@ -165,7 +165,12 @@ func (d *decodeGen) gBase(b *BaseElem) {
d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname)
}
case IDENT:
d.p.printf("\nerr = %s.DecodeMsg(dc)", vname)
if b.Convert {
lowered := b.ToBase() + "(" + vname + ")"
d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered)
} else {
d.p.printf("\nerr = %s.DecodeMsg(dc)", vname)
}
case Ext:
d.p.printf("\nerr = dc.ReadExtension(%s)", vname)
default:
Expand All @@ -178,7 +183,7 @@ func (d *decodeGen) gBase(b *BaseElem) {
d.p.wrapErrCheck(d.ctx.ArgsStr())

// close block for 'tmp'
if b.Convert {
if b.Convert && b.Value != IDENT {
if b.ShimMode == Cast {
d.p.printf("\n%s = %s(%s)\n}", vname, b.FromBase(), tmp)
} else {
Expand Down
6 changes: 6 additions & 0 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ var primitives = map[string]Primitive{
"int64": Int64,
"bool": Bool,
"interface{}": Intf,
"any": Intf,
"time.Time": Time,
"time.Duration": Duration,
"msgp.Extension": Ext,
Expand Down Expand Up @@ -385,6 +386,11 @@ func (s *Ptr) SetVarname(a string) {
case *BaseElem:
// identities have pointer receivers
if x.Value == IDENT {
// replace directive sets Convert=true and Needsref=true
// since BaseElem is behind a pointer we set Needsref=false
if x.Convert {
x.Needsref(false)
}
x.SetVarname(a)
} else {
x.SetVarname("*" + a)
Expand Down
13 changes: 8 additions & 5 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ func (u *unmarshalGen) gBase(b *BaseElem) {

refname := b.Varname() // assigned to
lowered := b.Varname() // passed as argument
if b.Convert {
// begin 'tmp' block
// begin 'tmp' block
if b.Convert && b.Value != IDENT { // we don't need block for 'tmp' in case of IDENT
refname = randIdent()
lowered = b.ToBase() + "(" + lowered + ")"
u.p.printf("\n{\nvar %s %s", refname, b.BaseType())
Expand All @@ -152,18 +152,21 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
case Ext:
u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered)
case IDENT:
if b.Convert {
lowered = b.ToBase() + "(" + lowered + ")"
}
u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered)
default:
u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName())
}
u.p.wrapErrCheck(u.ctx.ArgsStr())

if b.Convert {
// close 'tmp' block
// close 'tmp' block
if b.Convert && b.Value != IDENT {
if b.ShimMode == Cast {
u.p.printf("\n%s = %s(%s)\n", b.Varname(), b.FromBase(), refname)
} else {
u.p.printf("\n%s, err = %s(%s)", b.Varname(), b.FromBase(), refname)
u.p.printf("\n%s, err = %s(%s)\n", b.Varname(), b.FromBase(), refname)
u.p.wrapErrCheck(u.ctx.ArgsStr())
}
u.p.printf("}")
Expand Down
42 changes: 37 additions & 5 deletions parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package parse
import (
"fmt"
"go/ast"
"go/parser"
"strings"

"github.com/tinylib/msgp/gen"
Expand All @@ -21,9 +22,10 @@ type passDirective func(gen.Method, []string, *gen.Printer) error
// to add a directive, define a func([]string, *FileSet) error
// and then add it to this list.
var directives = map[string]directive{
"shim": applyShim,
"ignore": ignore,
"tuple": astuple,
"shim": applyShim,
"replace": replace,
"ignore": ignore,
"tuple": astuple,
}

var passDirectives = map[string]passDirective{
Expand Down Expand Up @@ -53,7 +55,7 @@ func yieldComments(c []*ast.CommentGroup) []string {
return out
}

//msgp:shim {Type} as:{Newtype} using:{toFunc/fromFunc} mode:{Mode}
//msgp:shim {Type} as:{NewType} using:{toFunc/fromFunc} mode:{Mode}
func applyShim(text []string, f *FileSet) error {
if len(text) < 4 || len(text) > 5 {
return fmt.Errorf("shim directive should have 3 or 4 arguments; found %d", len(text)-1)
Expand Down Expand Up @@ -90,7 +92,37 @@ func applyShim(text []string, f *FileSet) error {
}

infof("%s -> %s\n", name, be.Value.String())
f.findShim(name, be)
f.findShim(name, be, true)

return nil
}

//msgp:replace {Type} with:{NewType}
func replace(text []string, f *FileSet) error {
if len(text) != 3 {
return fmt.Errorf("replace directive should have only 2 arguments; found %d", len(text)-1)
}

name := text[1]
replacement := strings.TrimPrefix(strings.TrimSpace(text[2]), "with:")

expr, err := parser.ParseExpr(replacement)
if err != nil {
return err
}
e := f.parseExpr(expr)

if be, ok := e.(*gen.BaseElem); ok {
be.Convert = true
be.Alias(name)
if be.Value == gen.IDENT {
be.ShimToBase = "(*" + replacement + ")"
be.Needsref(true)
}
}

infof("%s -> %s\n", name, replacement)
f.findShim(name, e, false)

return nil
}
Expand Down
33 changes: 17 additions & 16 deletions parse/inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,49 +31,50 @@ import (
const maxComplex = 5

// begin recursive search for identities with the
// given name and replace them with be
func (f *FileSet) findShim(id string, be *gen.BaseElem) {
// given name and replace them with e
func (f *FileSet) findShim(id string, e gen.Elem, addID bool) {
for name, el := range f.Identities {
pushstate(name)
switch el := el.(type) {
case *gen.Struct:
for i := range el.Fields {
f.nextShim(&el.Fields[i].FieldElem, id, be)
f.nextShim(&el.Fields[i].FieldElem, id, e)
}
case *gen.Array:
f.nextShim(&el.Els, id, be)
f.nextShim(&el.Els, id, e)
case *gen.Slice:
f.nextShim(&el.Els, id, be)
f.nextShim(&el.Els, id, e)
case *gen.Map:
f.nextShim(&el.Value, id, be)
f.nextShim(&el.Value, id, e)
case *gen.Ptr:
f.nextShim(&el.Value, id, be)
f.nextShim(&el.Value, id, e)
}
popstate()
}
// we'll need this at the top level as well
f.Identities[id] = be
if addID {
f.Identities[id] = e
}
}

func (f *FileSet) nextShim(ref *gen.Elem, id string, be *gen.BaseElem) {
func (f *FileSet) nextShim(ref *gen.Elem, id string, e gen.Elem) {
if (*ref).TypeName() == id {
vn := (*ref).Varname()
*ref = be.Copy()
*ref = e.Copy()
(*ref).SetVarname(vn)
} else {
switch el := (*ref).(type) {
case *gen.Struct:
for i := range el.Fields {
f.nextShim(&el.Fields[i].FieldElem, id, be)
f.nextShim(&el.Fields[i].FieldElem, id, e)
}
case *gen.Array:
f.nextShim(&el.Els, id, be)
f.nextShim(&el.Els, id, e)
case *gen.Slice:
f.nextShim(&el.Els, id, be)
f.nextShim(&el.Els, id, e)
case *gen.Map:
f.nextShim(&el.Value, id, be)
f.nextShim(&el.Value, id, e)
case *gen.Ptr:
f.nextShim(&el.Value, id, be)
f.nextShim(&el.Value, id, e)
}
}
}
Expand Down
Loading