Skip to content

Commit

Permalink
improve type decoders
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac committed Mar 5, 2024
1 parent fc7e954 commit ac3cfbe
Show file tree
Hide file tree
Showing 23 changed files with 2,559 additions and 532 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: Unit tests

on:
pull_request:
push:
paths:
- "**.go"
Expand Down
119 changes: 77 additions & 42 deletions cmd/ndc-go-sdk/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (
"path"
"sort"
"strings"

"github.com/hasura/ndc-sdk-go/schema"
)

const (
Expand All @@ -30,19 +28,42 @@ const (
`
)

type connectorTypeBuilder struct {
packageName string
imports map[string]string
builder *strings.Builder
}

func (ctb connectorTypeBuilder) String() string {
var bs strings.Builder
bs.WriteString(genFileHeader(ctb.packageName))
if len(ctb.imports) > 0 {
bs.WriteString("import (\n")
for pkg, alias := range ctb.imports {
if alias != "" {
alias = alias + " "
}
bs.WriteString(fmt.Sprintf(" %s\"%s\"\n", alias, pkg))
}
bs.WriteString(")\n")
}
_, _ = bs.WriteString(ctb.builder.String())
return bs.String()
}

type connectorGenerator struct {
basePath string
moduleName string
rawSchema *RawConnectorSchema
typeBuilders map[string]*strings.Builder
typeBuilders map[string]*connectorTypeBuilder
}

func NewConnectorGenerator(basePath string, moduleName string, rawSchema *RawConnectorSchema) *connectorGenerator {
return &connectorGenerator{
basePath: basePath,
moduleName: moduleName,
rawSchema: rawSchema,
typeBuilders: make(map[string]*strings.Builder),
typeBuilders: make(map[string]*connectorTypeBuilder),
}
}

Expand Down Expand Up @@ -278,13 +299,13 @@ func (cg *connectorGenerator) genObjectMethods() error {
continue
}
sb := cg.getTypeBuilder(object.PackageName, object.PackageName)
_, _ = sb.WriteString(fmt.Sprintf(`
_, _ = sb.builder.WriteString(fmt.Sprintf(`
// ToMap encodes the struct to a value map
func (j %s) ToMap() map[string]any {
`, objectName))
lines := cg.genObjectToMap(object, "j", "result", false, false)
sb.WriteString(strings.Join(lines, "\n"))
sb.WriteString(`
sb.builder.WriteString(strings.Join(lines, "\n"))
sb.builder.WriteString(`
return result
}`)
}
Expand Down Expand Up @@ -363,7 +384,7 @@ func (cg *connectorGenerator) genCustomScalarMethods() error {
for _, scalarKey := range scalarKeys {
scalar := cg.rawSchema.CustomScalars[scalarKey]
sb := cg.getTypeBuilder(scalar.PackageName, scalar.PackageName)
_, _ = sb.WriteString(fmt.Sprintf(`
_, _ = sb.builder.WriteString(fmt.Sprintf(`
// ScalarName get the schema name of the scalar
func (j %s) ScalarName() string {
return "%s"
Expand All @@ -383,7 +404,7 @@ func (cg *connectorGenerator) genFunctionArgumentConstructors() error {
continue
}
sb := cg.getTypeBuilder(fn.PackageName, fn.PackageName)
_, _ = sb.WriteString(fmt.Sprintf(`
_, _ = sb.builder.WriteString(fmt.Sprintf(`
// FromValue decodes values from map
func (j *%s) FromValue(input map[string]any) error {
var err error
Expand All @@ -392,20 +413,25 @@ func (j *%s) FromValue(input map[string]any) error {
argumentKeys := getSortedKeys(fn.Arguments)
for _, key := range argumentKeys {
arg := fn.Arguments[key]
_, _ = sb.WriteString(genGetTypeValueDecoder(arg.Type, key, arg.FieldName))
cg.genGetTypeValueDecoder(sb, arg.Type, key, arg.FieldName)
}
sb.WriteString(` return nil
sb.builder.WriteString(` return nil
}`)
}

return nil
}

func (cg *connectorGenerator) getTypeBuilder(fileName string, packageName string) *strings.Builder {
func (cg *connectorGenerator) getTypeBuilder(fileName string, packageName string) *connectorTypeBuilder {
bs, ok := cg.typeBuilders[fileName]
if !ok {
bs = &strings.Builder{}
bs.WriteString(genFileHeader(packageName))
bs = &connectorTypeBuilder{
packageName: packageName,
imports: map[string]string{
"github.com/hasura/ndc-sdk-go/utils": "",
},
builder: &strings.Builder{},
}
cg.typeBuilders[fileName] = bs
}
return bs
Expand All @@ -414,60 +440,69 @@ func (cg *connectorGenerator) getTypeBuilder(fileName string, packageName string
func genFileHeader(packageName string) string {
return fmt.Sprintf(`// Code generated by github.com/hasura/ndc-sdk-go/codegen, DO NOT EDIT.
package %s
import (
"github.com/hasura/ndc-sdk-go/utils"
)
`, packageName)
}

func genGetTypeValueDecoder(ty *TypeInfo, key string, fieldName string) string {
var sb strings.Builder
func (cg *connectorGenerator) genGetTypeValueDecoder(sb *connectorTypeBuilder, ty *TypeInfo, key string, fieldName string) {
typeName := ty.TypeAST.String()
switch typeName {
case "bool":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetBool(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetBool(input, "%s")`, fieldName, key))
case "*bool":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetBoolPtr(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetBoolPtr(input, "%s")`, fieldName, key))
case "string":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetString(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetString(input, "%s")`, fieldName, key))
case "*string":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetStringPtr(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetStringPtr(input, "%s")`, fieldName, key))
case "int", "int8", "int16", "int32", "int64", "rune", "byte":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetInt[%s](input, "%s")`, fieldName, typeName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetInt[%s](input, "%s")`, fieldName, typeName, key))
case "uint", "uint8", "uint16", "uint32", "uint64":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetUint[%s](input, "%s")`, fieldName, typeName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetUint[%s](input, "%s")`, fieldName, typeName, key))
case "*int", "*int8", "*int16", "*int32", "*int64", "*rune", "*byte":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetIntPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetIntPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
case "*uint", "*uint8", "*uint16", "*uint32", "*uint64":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetUintPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetUintPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
case "float32", "float64":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetFloat[%s](input, "%s")`, fieldName, typeName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetFloat[%s](input, "%s")`, fieldName, typeName, key))
case "*float32", "*float64":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetFloatPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetFloatPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
case "complex64", "complex128":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetComplex[%s](input, "%s")`, fieldName, typeName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetComplex[%s](input, "%s")`, fieldName, typeName, key))
case "*complex64", "*complex128":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetComplexPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetComplexPtr[%s](input, "%s")`, fieldName, strings.TrimPrefix(typeName, "*"), key))
case "time.Time":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDateTime(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDateTime(input, "%s")`, fieldName, key))
case "*time.Time":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDateTimePtr(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDateTimePtr(input, "%s")`, fieldName, key))
case "time.Duration":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDuration(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDuration(input, "%s")`, fieldName, key))
case "*time.Duration":
_, _ = sb.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDurationPtr(input, "%s")`, fieldName, key))
_, _ = sb.builder.WriteString(fmt.Sprintf(` j.%s, err = utils.GetDurationPtr(input, "%s")`, fieldName, key))
default:
switch ty.Schema.(type) {
case *schema.NamedType:
_, _ = sb.WriteString(fmt.Sprintf(` err = utils.DecodeObjectValue(&j.%s, input, "%s")`, fieldName, key))
case *schema.NullableType:
_, _ = sb.WriteString(fmt.Sprintf(` err = utils.DecodeObjectValue(j.%s, input, "%s")`, fieldName, key))
if ty.IsNullable {
_, _ = sb.builder.WriteString(fmt.Sprintf(` err = utils.DecodeObjectValue(j.%s, input, "%s")`, fieldName, key))
} else {
_, _ = sb.builder.WriteString(fmt.Sprintf(` err = utils.DecodeObjectValue(&j.%s, input, "%s")`, fieldName, key))
}
}
_, _ = sb.WriteString(textBlockErrorCheck)
return sb.String()
_, _ = sb.builder.WriteString(textBlockErrorCheck)
}

// TODO: will use in the future
// func extractModuleAndTypeName(name string) (string, string) {
// parts := strings.Split(strings.TrimPrefix(name, "*"), "/")
// typeName := parts[len(parts)-1]
// typeNameParts := strings.Split(typeName, ".")
// if len(typeNameParts) < 2 {
// return "", typeName
// }
// if len(parts) == 1 {
// return typeNameParts[0], typeName
// }

// return strings.Join(append(parts[:len(parts)-1], typeNameParts[0]), "/"), typeName
// }

func getSortedKeys[V any](input map[string]V) []string {
var results []string
for key := range input {
Expand Down
24 changes: 20 additions & 4 deletions cmd/ndc-go-sdk/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ var defaultScalarTypes = schema.SchemaResponseScalarTypes{
"Boolean": *schema.NewScalarType(),
"DateTime": *schema.NewScalarType(),
"Duration": *schema.NewScalarType(),
"UUID": *schema.NewScalarType(),
}

var ndcOperationNameRegex = regexp.MustCompile(`^(Function|Procedure)([A-Z][A-Za-z0-9]*)$`)
Expand All @@ -36,11 +35,13 @@ var ndcScalarCommentRegex = regexp.MustCompile(`^@scalar(\s+([A-Z]\w*))?`)

type OperationKind string

var (
const (
OperationFunction OperationKind = "Function"
OperationProcedure OperationKind = "Procedure"
)

type TypeKind string

// TypeInfo represents the serialization information of a type
type TypeInfo struct {
Name string
Expand Down Expand Up @@ -178,6 +179,17 @@ func (rcs RawConnectorSchema) Schema() *schema.SchemaResponse {
return result
}

// IsCustomType checks if the type name is a custom scalar or an exported object
func (rcs RawConnectorSchema) IsCustomType(name string) bool {
if _, ok := rcs.CustomScalars[name]; ok {
return true
}
if obj, ok := rcs.Objects[name]; ok {
return !obj.IsAnonymous
}
return false
}

type SchemaParser struct {
moduleName string
files map[string]*ast.File
Expand Down Expand Up @@ -445,14 +457,18 @@ func (sp *SchemaParser) parseType(rawSchema *RawConnectorSchema, rootType *TypeI
}
}
if scalarName != "" {
rawSchema.ScalarSchemas[scalarName] = defaultScalarTypes[scalarName]
if scalar, ok := defaultScalarTypes[scalarName]; ok {
rawSchema.ScalarSchemas[scalarName] = scalar
} else {
rawSchema.ScalarSchemas[scalarName] = *schema.NewScalarType()
}
typeInfo.Schema = schema.NewNamedType(scalarName)
return typeInfo, nil
}
}

if typeInfo.IsScalar {
rawSchema.CustomScalars[typeInfo.SchemaName] = typeInfo
rawSchema.CustomScalars[typeInfo.Name] = typeInfo
rawSchema.ScalarSchemas[typeInfo.SchemaName] = *schema.NewScalarType()
return typeInfo, nil
}
Expand Down
28 changes: 28 additions & 0 deletions cmd/ndc-go-sdk/testdata/basic/expected/connector.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,34 @@ func execQuery(ctx context.Context, state *types.State, request *schema.QueryReq
return nil, schema.BadRequestError("cannot evaluate selection fields for scalar", nil)
}
return functions.FunctionGetBool(ctx, state)
case "getTypes":
rawArgs, err := utils.ResolveArgumentVariables(request.Arguments, variables)
if err != nil {
return nil, schema.BadRequestError("failed to resolve argument variables", map[string]any{
"cause": err.Error(),
})
}

var args functions.GetTypesArguments
if err = args.FromValue(rawArgs); err != nil {
return nil, schema.BadRequestError("failed to resolve arguments", map[string]any{
"cause": err.Error(),
})
}
rawResult, err := functions.FunctionGetTypes(ctx, state, &args)
if err != nil {
return nil, err
}

if rawResult == nil {
return nil, nil
}

result, err := utils.EncodeObjectWithColumnSelection(request.Query.Fields, rawResult)
if err != nil {
return nil, err
}
return result, nil
case "hello":
rawResult, err := functions.FunctionHello(ctx, state)
if err != nil {
Expand Down
Loading

0 comments on commit ac3cfbe

Please sign in to comment.