Skip to content

Commit

Permalink
feat: add support for go2proto generating ToProto() methods (#3690)
Browse files Browse the repository at this point in the history
Some notes:
- I haven't converted the schema types to use this yet.
- FromProto functions will be in a separate PR.
- maps/slices aren't supported yet

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
alecthomas and github-actions[bot] authored Dec 9, 2024
1 parent b2f5613 commit 4c5bcba
Show file tree
Hide file tree
Showing 9 changed files with 1,101 additions and 84 deletions.
13 changes: 11 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,18 @@ dev *args:
watchexec -r {{WATCHEXEC_ARGS}} -- "just build-sqlc && ftl dev --plain {{args}}"

# Build everything
build-all: build-protos-unconditionally build-backend build-frontend build-backend-tests build-generate build-sqlc build-zips lsp-generate build-jvm build-language-plugins
build-all: build-protos-unconditionally build-backend build-frontend build-backend-tests build-generate build-sqlc build-zips lsp-generate build-jvm build-language-plugins build-go2proto-testdata

# Run "go generate" on all packages
build-generate:
@mk internal/schema/aliaskind_enumer.go : internal/schema/metadataalias.go -- go generate -x ./internal/schema
@mk internal/log/log_level_string.go : internal/log/api.go -- go generate -x ./internal/log

# Generate testdata for go2proto. This should be run after any changes to go2proto.
build-go2proto-testdata:
@mk cmd/go2proto/testdata/go2proto.to.go cmd/go2proto/testdata/testdatapb/model.proto : cmd/go2proto/*.go cmd/go2proto/testdata/model.go -- go2proto -m -o ./cmd/go2proto/testdata/testdatapb/model.proto -O 'go_package="github.com/TBD54566975/ftl/cmd/go2proto/testdata/testdatapb"' xyz.block.ftl.go2proto.test ./cmd/go2proto/testdata.Root
@mk cmd/go2proto/testdata/testdatapb/model.pb.go : cmd/go2proto/testdata/testdatapb/model.proto -- '(cd ./cmd/go2proto/testdata/testdatapb && protoc --go_out=paths=source_relative:. model.proto) && go build ./cmd/go2proto/testdata'

# Build command-line tools
build +tools: build-protos build-zips build-frontend
@just build-without-frontend $@
Expand Down Expand Up @@ -236,9 +241,13 @@ ensure-frozen-migrations:
@scripts/ensure-frozen-migrations

# Run backend tests
test-backend:
test-backend: test-go2proto
@gotestsum --hide-summary skipped --format-hide-empty-pkg -- -short -fullpath ./...

# Run go2proto tests
test-go2proto: build-go2proto-testdata
@gotestsum --hide-summary skipped --format-hide-empty-pkg -- -short -fullpath ./cmd/go2proto/testdata

# Run shell script tests
test-scripts:
GIT_COMMITTER_NAME="CI" \
Expand Down
160 changes: 78 additions & 82 deletions cmd/go2proto/main.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package main

import (
"encoding/json"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"maps"
"os"
"path/filepath"
"reflect"
"sort"
"strconv"
"strings"
"text/template"

"github.com/alecthomas/kong"
"golang.org/x/tools/go/packages"
Expand Down Expand Up @@ -94,8 +95,9 @@ And this is the corresponding protobuf schema:
`

type File struct {
Imports []string
Decls []Decl
GoPackage string
Imports []string
Decls []Decl
}

func (f *File) AddImport(name string) {
Expand All @@ -116,6 +118,15 @@ func (f File) OrderedDecls() []Decl {
return decls
}

func (f File) TypeOf(name string) string {
for _, decl := range f.Decls {
if decl.DeclName() == name {
return decl.DeclName()
}
}
return ""
}

//sumtype:decl
type Decl interface {
decl()
Expand All @@ -131,11 +142,24 @@ func (Message) decl() {}
func (m Message) DeclName() string { return m.Name }

type Field struct {
ID int
Name string
Type string
Optional bool
Repeated bool
ID int
Name string
Type string
Optional bool
Repeated bool
ProtoGoType string
Pointer bool
}

var reservedWords = map[string]string{
"String": "String_",
}

func (f Field) EscapedName() string {
if name, ok := reservedWords[f.Name]; ok {
return name
}
return f.Name
}

type Enum struct {
Expand All @@ -155,17 +179,18 @@ func (Enum) decl() {}
func (e Enum) DeclName() string { return e.Name }

type SumType struct {
Name string
Elems map[string]int
Name string
Variants map[string]int
}

func (SumType) decl() {}
func (s SumType) DeclName() string { return s.Name }

type Config struct {
Output string `help:"Output file to write generated protobuf schema to." short:"o"`
Output string `help:"Output file to write generated protobuf schema to." short:"o" xor:"output"`
JSON bool `help:"Dump intermediate JSON represesentation." short:"j" xor:"output"`
Options map[string]string `placeholder:"OPTION=VALUE" help:"Additional options to include in the generated protobuf schema. Note: strings must be double quoted." short:"O" mapsep:"\\0"`
// Mappers bool `help:"Generate ToProto and FromProto mappers for each message." short:"m"`
Mappers bool `help:"Generate ToProto and FromProto mappers for each message." short:"m"`

Package string `arg:"" help:"Package name to use in the generated protobuf schema."`
Ref []string `arg:"" help:"Type to generate protobuf schema from in the form PKG.TYPE. eg. github.com/foo/bar/waz.Waz or ./waz.Waz" required:"true" placeholder:"PKG.TYPE"`
Expand Down Expand Up @@ -226,9 +251,27 @@ func main() {
kctx.FatalIfErrorf(err)
}

if cli.JSON {
b, err := json.MarshalIndent(file, "", " ")
kctx.FatalIfErrorf(err)
fmt.Println(string(b))
return
}

err = render(out, cli, file)
kctx.FatalIfErrorf(err)

if cli.Mappers {
w, err := os.CreateTemp(resolved.Path, "go2proto.to.go-*")
kctx.FatalIfErrorf(err)
defer os.Remove(w.Name())
defer w.Close()
err = renderToProto(w, cli, file)
kctx.FatalIfErrorf(err)
err = os.Rename(w.Name(), filepath.Join(resolved.Path, "go2proto.to.go"))
kctx.FatalIfErrorf(err)
}

if cli.Output != "" {
err = os.Rename(cli.Output+"~", cli.Output)
}
Expand Down Expand Up @@ -266,69 +309,6 @@ func genErrorf(pos token.Pos, format string, args ...any) error {
return &GenError{pos: pos, err: err}
}

var tmpl = template.Must(template.New("proto").
Funcs(template.FuncMap{
"typeof": func(t any) string { return reflect.Indirect(reflect.ValueOf(t)).Type().Name() },
"toLowerCamel": strcase.ToLowerCamel,
"toUpperCamel": strcase.ToUpperCamel,
"toLowerSnake": strcase.ToLowerSnake,
"toUpperSnake": strcase.ToUpperSnake,
"trimPrefix": strings.TrimPrefix,
}).
Parse(`
// THIS FILE IS GENERATED; DO NOT MODIFY
syntax = "proto3";
package {{ .Package }};
{{ range .Imports }}
import "{{.}}";
{{- end}}
{{ range $name, $value := .Options }}
option {{ $name }} = {{ $value }};
{{ end -}}
{{ range $decl := .OrderedDecls }}
{{- if eq (typeof $decl) "Message" }}
message {{ .Name }} {
{{- range $name, $field := .Fields }}
{{ if .Repeated }}repeated {{else if .Optional}}optional {{ end }}{{ .Type }} {{ .Name | toLowerSnake }} = {{ .ID }};
{{- end }}
}
{{- else if eq (typeof $decl) "Enum" }}
enum {{ .Name }} {
{{- range $value, $name := .ByValue }}
{{ $name | toUpperSnake }} = {{ $value }};
{{- end }}
}
{{- else if eq (typeof $decl) "SumType" }}
{{ $sumtype := . }}
message {{ .Name }} {
oneof value {
{{- range $name, $id := .Elems }}
{{ $name }} {{ trimPrefix $name $sumtype.Name | toLowerSnake }} = {{ $id }};
{{- end }}
}
}
{{- end }}
{{ end }}
`))

type RenderContext struct {
Config
File
}

func render(out *os.File, config Config, file File) error {
err := tmpl.Execute(out, RenderContext{
Config: config,
File: file,
})
if err != nil {
return fmt.Errorf("template error: %w", err)
}
return nil
}

func extract(config Config, pkg *PkgRefs) (File, error) {
state := State{
Seen: map[string]bool{},
Expand All @@ -340,6 +320,9 @@ func extract(config Config, pkg *PkgRefs) (File, error) {
if obj == nil {
return File{}, fmt.Errorf("%s: not found in package %s", sym, pkg.Pkg.ID)
}
if !strings.HasSuffix(pkg.Pkg.Name, "_test") {
state.Dest.GoPackage = pkg.Pkg.Name
}
if err := state.extractDecl(obj, obj.Type()); err != nil {
return File{}, fmt.Errorf("%s: %w", sym, err)
}
Expand Down Expand Up @@ -373,18 +356,18 @@ func (s *State) extractDecl(obj types.Object, t types.Type) error {
}
}

type builtinType struct {
type stdType struct {
ref string
path string
}

var builtinTypes = map[string]builtinType{
var stdTypes = map[string]stdType{
"time.Time": {"google.protobuf.Timestamp", "google/protobuf/timestamp.proto"},
"time.Duration": {"google.protobuf.Duration", "google/protobuf/duration.proto"},
}

func (s *State) extractStruct(n *types.Named, t *types.Struct) error {
if imp, ok := builtinTypes[n.String()]; ok {
if imp, ok := stdTypes[n.String()]; ok {
s.Dest.AddImport(imp.path)
return nil
}
Expand Down Expand Up @@ -435,8 +418,8 @@ func (s *State) extractSumType(obj types.Object, i *types.Interface) error {
}
s.Seen[sumTypeName] = true
decl := SumType{
Name: sumTypeName,
Elems: map[string]int{},
Name: sumTypeName,
Variants: map[string]int{},
}
scope := s.Pkg.Types.Scope()
for _, name := range scope.Names() {
Expand All @@ -463,14 +446,18 @@ func (s *State) extractSumType(obj types.Object, i *types.Interface) error {
if err := s.extractDecl(sym, sym.Type()); err != nil {
return genErrorf(sym.Pos(), "%s: %w", name, err)
}
decl.Elems[name] = directive.ID
decl.Variants[name] = directive.ID
}
}
s.Dest.Decls = append(s.Dest.Decls, decl)
return nil
}

func (s *State) extractEnum(t *types.Named) error {
if imp, ok := stdTypes[t.String()]; ok {
s.Dest.AddImport(imp.path)
return nil
}
enumName := t.Obj().Name()
if _, ok := s.Seen[enumName]; ok {
return nil
Expand Down Expand Up @@ -540,7 +527,7 @@ func (s *State) applyFieldType(t types.Type, field *Field) error {
return err
}
ref := t.Obj().Pkg().Path() + "." + t.Obj().Name()
if bt, ok := builtinTypes[ref]; ok {
if bt, ok := stdTypes[ref]; ok {
field.Type = bt.ref
} else {
field.Type = t.Obj().Name()
Expand All @@ -555,21 +542,30 @@ func (s *State) applyFieldType(t types.Type, field *Field) error {
}

case *types.Pointer:
field.Pointer = true
if _, ok := t.Elem().(*types.Slice); ok {
return fmt.Errorf("pointer to named type is not supported")
}
return s.applyFieldType(t.Elem(), field)

default:
field.ProtoGoType = t.String()
switch t.String() {
case "int":
field.Type = "int64"
field.ProtoGoType = "int64"

case "uint":
field.Type = "uint64"
field.ProtoGoType = "uint64"

case "float64":
field.Type = "double"
field.ProtoGoType = "float64"

case "float32":
field.Type = "float"
field.ProtoGoType = "float32"

case "string", "bool", "uint64", "int64", "uint32", "int32":
field.Type = t.String()
Expand Down
Loading

0 comments on commit 4c5bcba

Please sign in to comment.