Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve target Model primary key type for BelongsTo|HasManyKey|HasOnKey #118

Merged
merged 4 commits into from
Jan 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ func (g *Generator) generateUserTypes(outdir string, api *design.APIDefinition)
codegen.SimpleImport("github.com/goadesign/goa"),
codegen.SimpleImport("github.com/jinzhu/gorm"),
codegen.SimpleImport("golang.org/x/net/context"),
codegen.SimpleImport("golang.org/x/net/context"),
codegen.SimpleImport("github.com/goadesign/goa/uuid"),
codegen.NewImport("uuid", "github.com/satori/go.uuid"),
}

if model.Cached {
Expand Down Expand Up @@ -168,8 +167,7 @@ func (g *Generator) generateUserHelpers(outdir string, api *design.APIDefinition
codegen.SimpleImport("github.com/goadesign/goa"),
codegen.SimpleImport("github.com/jinzhu/gorm"),
codegen.SimpleImport("golang.org/x/net/context"),
codegen.SimpleImport("golang.org/x/net/context"),
codegen.SimpleImport("github.com/goadesign/goa/uuid"),
codegen.NewImport("uuid", "github.com/satori/go.uuid"),
}

if model.Cached {
Expand Down
53 changes: 50 additions & 3 deletions relationalfield.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ func goDatatype(f *RelationalFieldDefinition, includePtr bool) string {
case Timestamp, NullableTimestamp:
return ptr + "time.Time"
case BelongsTo:
return ptr + "int"
return ptr + belongsToIDType(f, includePtr)
case HasMany:
return fmt.Sprintf("[]%s", f.HasMany)
case HasManyKey, HasOneKey:
return ptr + "int"
case HasManyKey:
return ptr + hasManyIDType(f, includePtr)
case HasOneKey:
return ptr + hasOneIDType(f, includePtr)
case HasOne:
return fmt.Sprintf("%s", f.HasOne)
default:
Expand All @@ -121,6 +123,51 @@ func goDatatype(f *RelationalFieldDefinition, includePtr bool) string {
return "UNKNOWN TYPE"
}

func goDatatypeByModel(m *RelationalModelDefinition, belongsToModelName string) string {
f := m.RelationalFields[belongsToModelName+"ID"]
if f == nil {
return "int"
}
return belongsToIDType(f, true)
}

func belongsToIDType(f *RelationalFieldDefinition, includePtr bool) string {
if f.Parent == nil {
return "int"
}
modelName := strings.Replace(f.FieldName, "ID", "", -1)
model := f.Parent.BelongsTo[modelName]

This comment was marked as off-topic.

return relatedIDType(model, includePtr)
}

func hasOneIDType(f *RelationalFieldDefinition, includePtr bool) string {
if f.Parent == nil {
return "int"
}
modelName := strings.Replace(f.FieldName, "ID", "", -1)
model := f.Parent.HasOne[modelName]
return relatedIDType(model, includePtr)
}

func hasManyIDType(f *RelationalFieldDefinition, includePtr bool) string {
if f.Parent == nil {
return "int"
}
modelName := strings.Replace(f.FieldName, "ID", "", -1)
model := f.Parent.HasMany[modelName]
return relatedIDType(model, includePtr)
}

func relatedIDType(m *RelationalModelDefinition, includePtr bool) string {
if m == nil {

This comment was marked as off-topic.

return "int"
}
if len(m.PrimaryKeys) > 1 {
panic("Can't determine field Type when using multiple primary keys")
}
return goDatatype(m.PrimaryKeys[0], includePtr)
}

func tags(f *RelationalFieldDefinition) string {
var sqltags []string
if f.SQLTag != "" {
Expand Down
3 changes: 2 additions & 1 deletion writers.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ func (w *UserTypesWriter) Execute(data *UserTypeTemplateData) error {
fm["viewFields"] = viewFields
fm["viewFieldNames"] = viewFieldNames
fm["goDatatype"] = goDatatype
fm["goDatatypeByModel"] = goDatatypeByModel
fm["plural"] = inflect.Pluralize
fm["gtt"] = codegen.GoTypeTransform
fm["gttn"] = codegen.GoTypeTransformName
Expand Down Expand Up @@ -450,7 +451,7 @@ func (m *{{$ut.ModelName}}DB) TableName() string {
// Belongs To Relationships

// {{$ut.ModelName}}FilterBy{{$bt.ModelName}} is a gorm filter for a Belongs To relationship.
func {{$ut.ModelName}}FilterBy{{$bt.ModelName}}({{goify (printf "%s%s" $bt.ModelName "ID") false}} int, originaldb *gorm.DB) func(db *gorm.DB) *gorm.DB {
func {{$ut.ModelName}}FilterBy{{$bt.ModelName}}({{goify (printf "%s%s" $bt.ModelName "ID") false}} {{ goDatatypeByModel $ut $bt.ModelName }}, originaldb *gorm.DB) func(db *gorm.DB) *gorm.DB {
if {{goify (printf "%s%s" $bt.ModelName "ID") false}} > 0 {
return func(db *gorm.DB) *gorm.DB {
return db.Where("{{if $bt.RelationalFields.ID.DatabaseFieldName}}{{ if ne $bt.RelationalFields.ID.DatabaseFieldName "id" }}{{$bt.RelationalFields.ID.DatabaseFieldName}} = ?", {{goify (printf "%s%s" $bt.ModelName "ID") false}}){{else}}{{$bt.LowerName}}_id = ?", {{goify (printf "%s%s" $bt.ModelName "ID") false}}){{end}}
Expand Down