Skip to content

Commit

Permalink
Add GetMessage and GetEnum mehods to template
Browse files Browse the repository at this point in the history
  • Loading branch information
b00f committed Jun 19, 2024
1 parent cd87b7c commit 827a23d
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 16 deletions.
2 changes: 2 additions & 0 deletions renderer.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ type textRenderer struct {
}

func (mr *textRenderer) Apply(template *Template) ([]byte, error) {
funcMap["getMessage"] = template.GetMessage
funcMap["getEnum"] = template.GetEnum
tmpl, err := text_template.New("Text Template").Funcs(funcMap).Funcs(sprig.TxtFuncMap()).Parse(mr.inputTemplate)
if err != nil {
return nil, err
Expand Down
108 changes: 101 additions & 7 deletions template.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package gendoc

import (
"bytes"
"encoding/json"
"fmt"
"sort"
"strings"
"unicode"

"github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/pactus-project/protoc-gen-doc/extensions"
Expand Down Expand Up @@ -69,7 +71,7 @@ func NewTemplate(descs []*protokit.FileDescriptor) *Template {
if f.IsMap {
index, msg := getMessageByName(&file.Messages, f.Type)
if msg == nil || len(msg.Fields) != 2 {
panic(fmt.Sprintf("unable to find key/va;ue for %s", f.Name))
panic(fmt.Sprintf("unable to find key/value for %s", f.Name))
}

keyField := msg.Fields[0]
Expand Down Expand Up @@ -106,6 +108,32 @@ func NewTemplate(descs []*protokit.FileDescriptor) *Template {
return &Template{Files: files, Scalars: makeScalars()}
}

// GetMessage returns a message by type name
func (t *Template) GetMessage(typeName string) *Message {
for _, f := range t.Files {
for _, msg := range f.Messages {
if msg.Name == typeName {
return msg
}
}

}
return nil
}

// GetEnum returns an enum by type name
func (t *Template) GetEnum(typeName string) *Enum {
for _, f := range t.Files {
for _, e := range f.Enums {
if e.Name == typeName {
return e
}
}

}
return nil
}

func getMessageByName(orderedMessages *orderedMessages, name string) (int, *Message) {
for index, msg := range *orderedMessages {
if msg.Name == name {
Expand Down Expand Up @@ -263,14 +291,18 @@ func (m Message) FieldsWithOption(optionName string) []*MessageField {
// repeated (in which case it'll be "repeated").
type MessageField struct {
Name string `json:"name"`
JsonName string `json:"jsonName"`
Description string `json:"description"`
Label string `json:"label"`
Type string `json:"type"`
JsonType string `json:"jsonType"`
LongType string `json:"longType"`
FullType string `json:"fullType"`
IsMap bool `json:"ismap"`
IsOneof bool `json:"isoneof"`
OneofDecl string `json:"oneofdecl"`
IsMap bool `json:"isMap"`
IsEnum bool `json:"isEnum"`
IsOneOf bool `json:"isOneOf"`
IsRepeated bool `json:"isRepeated"`
OneOfDecl string `json:"oneOfDecl"`
DefaultValue string `json:"defaultValue"`

Options map[string]interface{} `json:"options,omitempty"`
Expand Down Expand Up @@ -351,6 +383,7 @@ func (v EnumValue) Option(name string) interface{} { return v.Options[name] }
// Service contains details about a service definition within a proto file.
type Service struct {
Name string `json:"name"`
JsonName string `json:"jsonName"`
LongName string `json:"longName"`
FullName string `json:"fullName"`
Description string `json:"description"`
Expand Down Expand Up @@ -399,6 +432,7 @@ func (s Service) MethodsWithOption(optionName string) []*ServiceMethod {
// ServiceMethod contains details about an individual method within a service.
type ServiceMethod struct {
Name string `json:"name"`
JsonName string `json:"jsonName"`
Description string `json:"description"`
RequestType string `json:"requestType"`
RequestLongType string `json:"requestLongType"`
Expand Down Expand Up @@ -510,21 +544,32 @@ func parseMessageExtension(pe *protokit.ExtensionDescriptor) *MessageExtension {

func parseMessageField(pf *protokit.FieldDescriptor, oneofDecls []*descriptor.OneofDescriptorProto) *MessageField {
t, lt, ft := parseType(pf)
jt := parseJsonType(pf)

m := &MessageField{
Name: pf.GetName(),
JsonName: camelToSnake(pf.GetName()),
Description: description(pf.GetComments().String()),
Label: labelName(pf.GetLabel(), pf.IsProto3(), pf.GetProto3Optional()),
Type: t,
JsonType: jt,
LongType: lt,
FullType: ft,
DefaultValue: pf.GetDefaultValue(),
Options: mergeOptions(extractOptions(pf.GetOptions()), extensions.Transform(pf.OptionExtensions)),
IsOneof: pf.OneofIndex != nil,
IsOneOf: pf.OneofIndex != nil,
}

if *pf.Type == descriptor.FieldDescriptorProto_TYPE_ENUM {
m.IsEnum = true
}

if m.IsOneOf {
m.OneOfDecl = oneofDecls[pf.GetOneofIndex()].GetName()
}

if m.IsOneof {
m.OneofDecl = oneofDecls[pf.GetOneofIndex()].GetName()
if m.Label == "repeated" {
m.IsRepeated = true
}

// Check if this is a map.
Expand All @@ -542,6 +587,7 @@ func parseMessageField(pf *protokit.FieldDescriptor, oneofDecls []*descriptor.On
func parseService(ps *protokit.ServiceDescriptor) *Service {
service := &Service{
Name: ps.GetName(),
JsonName: camelToSnake(ps.GetName()),
LongName: ps.GetLongName(),
FullName: ps.GetFullName(),
Description: description(ps.GetComments().String()),
Expand All @@ -558,6 +604,7 @@ func parseService(ps *protokit.ServiceDescriptor) *Service {
func parseServiceMethod(pm *protokit.MethodDescriptor) *ServiceMethod {
return &ServiceMethod{
Name: pm.GetName(),
JsonName: camelToSnake(pm.GetName()),
Description: description(pm.GetComments().String()),
RequestType: baseName(pm.GetInputType()),
RequestLongType: strings.TrimPrefix(pm.GetInputType(), "."+pm.GetPackage()+"."),
Expand Down Expand Up @@ -602,6 +649,53 @@ func parseType(tc typeContainer) (string, string, string) {
return name, name, name
}

func parseJsonType(tc typeContainer) string {
switch tc.GetType() {
case descriptor.FieldDescriptorProto_TYPE_STRING:
case descriptor.FieldDescriptorProto_TYPE_BYTES:
return "string"

case descriptor.FieldDescriptorProto_TYPE_BOOL:
return "boolean"

case descriptor.FieldDescriptorProto_TYPE_GROUP:
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
case descriptor.FieldDescriptorProto_TYPE_ENUM:
return "object"

case descriptor.FieldDescriptorProto_TYPE_UINT32:
case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
case descriptor.FieldDescriptorProto_TYPE_FLOAT:
case descriptor.FieldDescriptorProto_TYPE_INT64:
case descriptor.FieldDescriptorProto_TYPE_UINT64:
case descriptor.FieldDescriptorProto_TYPE_INT32:
case descriptor.FieldDescriptorProto_TYPE_FIXED64:
case descriptor.FieldDescriptorProto_TYPE_FIXED32:
case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
case descriptor.FieldDescriptorProto_TYPE_SFIXED64:
case descriptor.FieldDescriptorProto_TYPE_SINT32:
case descriptor.FieldDescriptorProto_TYPE_SINT64:
return "numeric"
}

return "unknown"
}

func camelToSnake(s string) string {
var buf bytes.Buffer
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
buf.WriteRune('_')
}
buf.WriteRune(unicode.ToLower(r))
} else {
buf.WriteRune(r)
}
}
return buf.String()
}

func description(comment string) string {
val := strings.TrimLeft(comment, "*/\n ")
if strings.HasPrefix(val, "@exclude") {
Expand Down
18 changes: 9 additions & 9 deletions template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "int32", field.LongType)
require.Equal(t, "int32", field.FullType)
require.Empty(t, field.DefaultValue)
require.False(t, field.IsOneof)
require.False(t, field.IsOneOf)
require.NotEmpty(t, field.Options)
require.True(t, *field.Option(E_ExtendField.Name).(*bool))

Expand All @@ -263,7 +263,7 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "BookingStatus.StatusCode", field.LongType)
require.Equal(t, "com.example.BookingStatus.StatusCode", field.FullType)
require.Empty(t, field.DefaultValue)
require.False(t, field.IsOneof)
require.False(t, field.IsOneOf)

field = findField("category", findMessage("Vehicle", vehicleFile))
require.Equal(t, "category", field.Name)
Expand All @@ -273,7 +273,7 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "Vehicle.Category", field.LongType)
require.Equal(t, "com.example.Vehicle.Category", field.FullType)
require.Empty(t, field.DefaultValue)
require.False(t, field.IsOneof)
require.False(t, field.IsOneOf)

field = findField("properties", findMessage("Vehicle", vehicleFile))
require.Equal(t, "properties", field.Name)
Expand All @@ -283,7 +283,7 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "map<string, string>", field.FullType)
require.Empty(t, field.DefaultValue)
require.True(t, field.IsMap)
require.False(t, field.IsOneof)
require.False(t, field.IsOneOf)

field = findField("rates", findMessage("Vehicle", vehicleFile))
require.Equal(t, "rates", field.Name)
Expand All @@ -292,7 +292,7 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "sint32", field.LongType)
require.Equal(t, "sint32", field.FullType)
require.False(t, field.IsMap)
require.False(t, field.IsOneof)
require.False(t, field.IsOneOf)

field = findField("kilometers", findMessage("Vehicle", vehicleFile))
require.Equal(t, "kilometers", field.Name)
Expand All @@ -301,8 +301,8 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "int32", field.LongType)
require.Equal(t, "int32", field.FullType)
require.False(t, field.IsMap)
require.True(t, field.IsOneof)
require.Equal(t, "travel", field.OneofDecl)
require.True(t, field.IsOneOf)
require.Equal(t, "travel", field.OneOfDecl)

field = findField("human_name", findMessage("Vehicle", vehicleFile))
require.Equal(t, "human_name", field.Name)
Expand All @@ -311,8 +311,8 @@ func TestFieldProperties(t *testing.T) {
require.Equal(t, "string", field.LongType)
require.Equal(t, "string", field.FullType)
require.False(t, field.IsMap)
require.True(t, field.IsOneof)
require.Equal(t, "drivers", field.OneofDecl)
require.True(t, field.IsOneOf)
require.Equal(t, "drivers", field.OneOfDecl)
}

func TestFieldPropertiesProto3(t *testing.T) {
Expand Down

0 comments on commit 827a23d

Please sign in to comment.