Skip to content

Commit

Permalink
tool: add update-types sync command to sync existing types with proto
Browse files Browse the repository at this point in the history
Add a new command 'update-types sync' that synchronizes existing Go types
with their proto definitions. This ensures that Go field types and comments
stay in sync with the proto definitions while preserving:
- Special annotations (e.g. +required)
- Reference fields
- Ignored fields
- Manual edits
  • Loading branch information
jingyih committed Dec 18, 2024
1 parent 66b4710 commit bb9656a
Show file tree
Hide file tree
Showing 12 changed files with 841 additions and 98 deletions.
41 changes: 41 additions & 0 deletions dev/tools/controllerbuilder/pkg/codegen/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package codegen

import "strings"

// special-case proto messages that are currently not mapped to KRM Go structs
var protoMessagesNotMappedToGoStruct = map[string]string{
"google.protobuf.Timestamp": "string",
"google.protobuf.Duration": "string",
"google.protobuf.Int64Value": "int64",
"google.protobuf.StringValue": "string",
"google.protobuf.Struct": "map[string]string",
}

var Acronyms = []string{
"ID", "HTML", "URL", "HTTP", "HTTPS", "SSH",
"IP", "GB", "FS", "PD", "KMS", "GCE", "VTPM",
}

// IsAcronym returns true if the given string is an acronym
func IsAcronym(s string) bool {
for _, acronym := range Acronyms {
if strings.EqualFold(s, acronym) {
return true
}
}
return false
}
129 changes: 50 additions & 79 deletions dev/tools/controllerbuilder/pkg/codegen/typegenerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@ import (
"k8s.io/klog/v2"
)

// Some special-case values that are not obvious how to map in KRM
var protoMessagesNotMappedToGoStruct = map[string]string{
"google.protobuf.Timestamp": "string",
"google.protobuf.Duration": "string",
"google.protobuf.Int64Value": "int64",
"google.protobuf.StringValue": "string",
"google.protobuf.Struct": "map[string]string",
}

type TypeGenerator struct {
generatorBase
api *protoapi.Proto
Expand Down Expand Up @@ -78,7 +69,7 @@ func (g *TypeGenerator) visitMessage(messageDescriptor protoreflect.MessageDescr

g.visitedMessages = append(g.visitedMessages, messageDescriptor)

msgs, err := findDependenciesForMessage(messageDescriptor)
msgs, err := FindDependenciesForMessage(messageDescriptor)
if err != nil {
return err
}
Expand Down Expand Up @@ -123,7 +114,7 @@ func (g *TypeGenerator) WriteVisitedMessages() error {
}
out := g.getOutputFile(k)

goTypeName := goNameForProtoMessage(msg)
goTypeName := GoNameForProtoMessage(msg)
skipGenerated := true
goType, err := g.findTypeDeclaration(goTypeName, out.OutputDir(), skipGenerated)
if err != nil {
Expand Down Expand Up @@ -151,7 +142,7 @@ func (g *TypeGenerator) WriteVisitedMessages() error {
}

func WriteMessage(out io.Writer, msg protoreflect.MessageDescriptor) {
goType := goNameForProtoMessage(msg)
goType := GoNameForProtoMessage(msg)

fmt.Fprintf(out, "\n")
fmt.Fprintf(out, "// +kcc:proto=%s\n", msg.FullName())
Expand All @@ -163,51 +154,58 @@ func WriteMessage(out io.Writer, msg protoreflect.MessageDescriptor) {
fmt.Fprintf(out, "}\n")
}

func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protoreflect.MessageDescriptor, fieldIndex int) {
sourceLocations := msg.ParentFile().SourceLocations().ByDescriptor(field)

jsonName := getJSONForKRM(field)
goFieldName := goFieldName(field)
goType := ""

func GoType(field protoreflect.FieldDescriptor) (string, error) {
if field.IsMap() {
entryMsg := field.Message()
keyKind := entryMsg.Fields().ByName("key").Kind()
valueKind := entryMsg.Fields().ByName("value").Kind()
if keyKind == protoreflect.StringKind && valueKind == protoreflect.StringKind {
goType = "map[string]string"
return "map[string]string", nil
} else if keyKind == protoreflect.StringKind && valueKind == protoreflect.Int64Kind {
goType = "map[string]int64"
return "map[string]int64", nil
} else {
fmt.Fprintf(out, "\n\t// TODO: map type %v %v for %v\n\n", keyKind, valueKind, field.Name())
return
return "", fmt.Errorf("unsupported map type with key %v and value %v", keyKind, valueKind)
}
}

var goType string
switch field.Kind() {
case protoreflect.MessageKind:
goType = GoNameForProtoMessage(field.Message())
case protoreflect.EnumKind:
goType = "string"
default:
goType = goTypeForProtoKind(field.Kind())
}

if field.Cardinality() == protoreflect.Repeated {
goType = "[]" + goType
} else {
switch field.Kind() {
case protoreflect.MessageKind:
goType = goNameForProtoMessage(field.Message())
goType = "*" + goType
}

case protoreflect.EnumKind:
goType = "string" //string(field.Enum().Name())
// Special case for proto "bytes" type
if goType == "*[]byte" {
goType = "[]byte"
}
// Special case for proto "google.protobuf.Struct" type
if goType == "*map[string]string" {
goType = "map[string]string"
}

default:
goType = goTypeForProtoKind(field.Kind())
}
return goType, nil
}

if field.Cardinality() == protoreflect.Repeated {
goType = "[]" + goType
} else {
goType = "*" + goType
}
func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protoreflect.MessageDescriptor, fieldIndex int) {
sourceLocations := msg.ParentFile().SourceLocations().ByDescriptor(field)

// Special case for proto "bytes" type
if goType == "*[]byte" {
goType = "[]byte"
}
// Special case for proto "google.protobuf.Struct" type
if goType == "*map[string]string" {
goType = "map[string]string"
}
jsonName := GetJSONForKRM(field)
GoFieldName := goFieldName(field)

goType, err := GoType(field)
if err != nil {
fmt.Fprintf(out, "\n\t// TODO: %v\n\n", err)
return
}

// Blank line between fields for readability
Expand All @@ -228,7 +226,7 @@ func WriteField(out io.Writer, field protoreflect.FieldDescriptor, msg protorefl

fmt.Fprintf(out, "\t// +kcc:proto=%s\n", field.FullName())
fmt.Fprintf(out, "\t%s %s `json:\"%s,omitempty\"`\n",
goFieldName,
GoFieldName,
goType,
jsonName,
)
Expand All @@ -253,7 +251,7 @@ func deduplicateAndSort(messages []protoreflect.MessageDescriptor) []protoreflec
return messages
}

func goNameForProtoMessage(msg protoreflect.MessageDescriptor) string {
func GoNameForProtoMessage(msg protoreflect.MessageDescriptor) string {
fullName := string(msg.FullName())

// Some special-case values that are not obvious how to map in KRM
Expand Down Expand Up @@ -307,16 +305,16 @@ func goTypeForProtoKind(kind protoreflect.Kind) string {
return goType
}

// getJSONForKRM returns the KRM JSON name for the field,
// GetJSONForKRM returns the KRM JSON name for the field,
// honoring KRM conventions
func getJSONForKRM(protoField protoreflect.FieldDescriptor) string {
func GetJSONForKRM(protoField protoreflect.FieldDescriptor) string {
tokens := strings.Split(string(protoField.Name()), "_")
for i, token := range tokens {
if i == 0 {
// Do not capitalize first token
continue
}
if isAcronym(token) {
if IsAcronym(token) {
token = strings.ToUpper(token)
} else {
token = strings.Title(token)
Expand All @@ -331,7 +329,7 @@ func getJSONForKRM(protoField protoreflect.FieldDescriptor) string {
func goFieldName(protoField protoreflect.FieldDescriptor) string {
tokens := strings.Split(string(protoField.Name()), "_")
for i, token := range tokens {
if isAcronym(token) {
if IsAcronym(token) {
token = strings.ToUpper(token)
} else {
token = strings.Title(token)
Expand All @@ -341,35 +339,8 @@ func goFieldName(protoField protoreflect.FieldDescriptor) string {
return strings.Join(tokens, "")
}

func isAcronym(s string) bool {
switch s {
case "id":
return true
case "html", "url":
return true
case "http", "https", "ssh":
return true
case "ip":
return true
case "gb":
return true
case "fs":
return true
case "pd":
return true
case "kms":
return true
case "gce":
return true
case "vtpm":
return true
default:
return false
}
}

// findDependenciesForMessage recursively explores the dependent proto messages of the given message.
func findDependenciesForMessage(message protoreflect.MessageDescriptor) ([]protoreflect.MessageDescriptor, error) {
// FindDependenciesForMessage recursively explores the dependent proto messages of the given message.
func FindDependenciesForMessage(message protoreflect.MessageDescriptor) ([]protoreflect.MessageDescriptor, error) {
msgs := make(map[string]protoreflect.MessageDescriptor)
for i := 0; i < message.Fields().Len(); i++ {
field := message.Fields().Get(i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func runInsert(opt *insertFieldOptions) func(*cobra.Command, []string) error {
}
}

func runFieldInserter(ctx context.Context, opt *insertFieldOptions) error {
func runFieldInserter(_ context.Context, opt *insertFieldOptions) error {
fieldInserter := typeupdater.NewFieldInserter(&typeupdater.InsertFieldOptions{
ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath,
ParentMessageFullName: opt.parent,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package updatetypes

import (
"context"
"fmt"

"github.com/GoogleCloudPlatform/k8s-config-connector/dev/tools/controllerbuilder/pkg/typeupdater"
"github.com/spf13/cobra"
)

type syncProtoPackageOptions struct {
*baseUpdateTypeOptions

legacyMode bool
}

func buildSyncCommand(baseOptions *baseUpdateTypeOptions) *cobra.Command {
opt := &syncProtoPackageOptions{
baseUpdateTypeOptions: baseOptions,
}

cmd := &cobra.Command{
Use: "sync",
Short: "sync the KRM types with the proto package",
Long: `Sync the KRM types with the proto package. This command will update the KRM types
to match the proto package. If --message is specified, only the specified message and its
dependent messages will be synced. If --message is not specified, all messages in the proto
package indicated by --service will be synced.`,
PreRunE: validateSyncOptions(opt),
RunE: runSync(opt),
}

bindSyncFlags(cmd, opt)

return cmd
}

func bindSyncFlags(cmd *cobra.Command, opt *syncProtoPackageOptions) {
opt.BindFlags(cmd)
cmd.Flags().BoolVar(&opt.legacyMode, "legacy-mode", false, "Set to true if the resource has KRM fields that are missing proto annotations.")
}

func validateSyncOptions(opt *syncProtoPackageOptions) func(*cobra.Command, []string) error {
return func(cmd *cobra.Command, args []string) error {
if err := validateRequiredFlags(opt); err != nil {
return err
}
return nil
}
}

func validateRequiredFlags(opt *syncProtoPackageOptions) error {
if opt.apiDirectory == "" {
return fmt.Errorf("--api-dir is required")
}
if opt.apiGoPackagePath == "" {
return fmt.Errorf("--api-go-package-path is required")
}
if opt.ServiceName == "" {
return fmt.Errorf("--service is required")
}
return nil
}

func runSync(opt *syncProtoPackageOptions) func(*cobra.Command, []string) error {
return func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
if err := runPackageSyncer(ctx, opt); err != nil {
return err
}
return nil
}
}

func runPackageSyncer(ctx context.Context, opt *syncProtoPackageOptions) error {
syncer := typeupdater.NewProtoPackageSyncer(&typeupdater.SyncProtoPackageOptions{
ServiceName: opt.ServiceName,
APIVersion: opt.APIVersion,
ProtoSourcePath: opt.GenerateOptions.ProtoSourcePath,
APIDirectory: opt.apiDirectory,
GoPackagePath: opt.apiGoPackagePath,
LegacyMode: opt.legacyMode,
})
return syncer.Run()
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func BuildCommand(baseOptions *options.GenerateOptions) *cobra.Command {

// subcommands
cmd.AddCommand(buildInsertCommand(opt))
cmd.AddCommand(buildSyncCommand(opt))

return cmd
}
Loading

0 comments on commit bb9656a

Please sign in to comment.