Skip to content

Commit

Permalink
Merge pull request #2418 from magodo/avoid_replace_discriminated_vari…
Browse files Browse the repository at this point in the history
…ant_by_parent_model

Avoid replacing the discriminated variant's model by its parent's model
  • Loading branch information
tombuildsstuff authored Aug 7, 2023
2 parents cd3d7c6 + b0287fd commit a9e3091
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 84 deletions.
8 changes: 4 additions & 4 deletions tools/generator-go-sdk/generator/templater_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ func (c modelsTemplater) codeForUnmarshalParentFunction(data ServiceGeneratorDat
// if this is a Discriminated Type (e.g. Parent) then we need to generate a unmarshal{Name}Implementations
// function which can be used in any usages
lines := make([]string, 0)
if c.model.TypeHintIn != nil && c.model.ParentTypeName == nil {
if c.model.IsDiscriminatedParentType() {
modelsImplementingThisClass := make([]string, 0)
for modelName, model := range data.models {
if model.ParentTypeName == nil || model.TypeHintIn == nil || model.TypeHintValue == nil || modelName == c.name {
Expand Down Expand Up @@ -486,7 +486,7 @@ func unmarshal%[1]sImplementation(input []byte) (%[1]s, error) {

func (c modelsTemplater) codeForUnmarshalStructFunction(data ServiceGeneratorData) (*string, error) {
// this is a parent, therefore there'll be no struct fields to check here
if c.model.TypeHintIn != nil && c.model.TypeHintValue == nil && c.model.ParentTypeName == nil {
if c.model.IsDiscriminatedParentType() {
out := ""
return &out, nil
}
Expand All @@ -500,7 +500,7 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data ServiceGeneratorDat
topLevelObject := topLevelObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == resourcemanager.ReferenceApiObjectDefinitionType {
model, ok := data.models[*topLevelObject.ReferenceName]
if ok && model.TypeHintIn != nil {
if ok && model.IsDiscriminatedParentType() {
fieldsRequiringUnmarshalling = append(fieldsRequiringUnmarshalling, fieldName)
continue
}
Expand All @@ -518,7 +518,7 @@ func (c modelsTemplater) codeForUnmarshalStructFunction(data ServiceGeneratorDat
topLevelObject := topLevelObjectDefinition(fieldDetails.ObjectDefinition)
if topLevelObject.Type == resourcemanager.ReferenceApiObjectDefinitionType {
model, ok := data.models[*topLevelObject.ReferenceName]
if ok && model.TypeHintIn != nil {
if ok && model.IsDiscriminatedParentType() {
fieldsRequiringUnmarshalling = append(fieldsRequiringUnmarshalling, fieldName)
continue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,98 @@ func (s Train) MarshalJSON() ([]byte, error) {
assertTemplatedCodeMatches(t, expected, *actual)
}

func TestTemplaterModelsFieldImplementation(t *testing.T) {
actual, err := modelsTemplater{
name: "TrainFactory",
model: resourcemanager.ModelDetails{
Fields: map[string]resourcemanager.FieldDetails{
"Train": {
Required: true,
JsonName: "train",
ObjectDefinition: resourcemanager.ApiObjectDefinition{
Type: resourcemanager.ReferenceApiObjectDefinitionType,
ReferenceName: stringPointer("Train"),
},
},
},
},
}.template(ServiceGeneratorData{
packageName: "somepackage",
models: map[string]resourcemanager.ModelDetails{
"ModeOfTransit": {
TypeHintIn: stringPointer("Type"),
Fields: map[string]resourcemanager.FieldDetails{
"Type": {
IsTypeHint: true,
JsonName: "type",
ObjectDefinition: resourcemanager.ApiObjectDefinition{
Type: resourcemanager.StringApiObjectDefinitionType,
},
Required: true,
},
},
},
"Train": {
Fields: map[string]resourcemanager.FieldDetails{
"Number": {
Required: true,
JsonName: "number",
ObjectDefinition: resourcemanager.ApiObjectDefinition{
Type: resourcemanager.StringApiObjectDefinitionType,
},
},
"Operator": {
Required: true,
JsonName: "operator",
ObjectDefinition: resourcemanager.ApiObjectDefinition{
Type: resourcemanager.StringApiObjectDefinitionType,
},
},
},
ParentTypeName: stringPointer("ModeOfTransit"),
TypeHintIn: stringPointer("Type"),
TypeHintValue: stringPointer("train"),
},
"TrainFactory": {
Fields: map[string]resourcemanager.FieldDetails{
"Train": {
Required: true,
JsonName: "train",
ObjectDefinition: resourcemanager.ApiObjectDefinition{
Type: resourcemanager.ReferenceApiObjectDefinitionType,
ReferenceName: stringPointer("Train"),
},
},
},
},
},
source: AccTestLicenceType,
})
if err != nil {
t.Fatal(err.Error())
}
expected := strings.ReplaceAll(`package somepackage
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/hashicorp/go-azure-helpers/lang/dates"
"github.com/hashicorp/go-azure-helpers/resourcemanager/edgezones"
"github.com/hashicorp/go-azure-helpers/resourcemanager/identity"
"github.com/hashicorp/go-azure-helpers/resourcemanager/systemdata"
"github.com/hashicorp/go-azure-helpers/resourcemanager/zones"
)
// acctests licence placeholder
type TrainFactory struct {
Train Train ''json:"train"''
}
`, "''", "`")
assertTemplatedCodeMatches(t, expected, *actual)
}

func TestTemplaterModelsImplementationInheritedFromParentType(t *testing.T) {
actual, err := modelsTemplater{
name: "FirstImplementation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,10 @@ func TestParseDiscriminatedChildTypeThatShouldntBe(t *testing.T) {
}
}

func TestParseDiscriminatedChildTypeWhereParentShouldBeUsed(t *testing.T) {
func TestParseDiscriminatedChildTypeWhereParentShouldNotBeUsed(t *testing.T) {
// Some Swagger files contain Models with a reference to a Discriminated Type (e.g. an implementation
// where a Parent should be used instead) - this asserts that we switch these out so that the Field
// references the Parent rather than the Implementation.
// where a Parent should be used instead) - this asserts that we shouldn't switch these out to
// referencing the Parent, instead should just use the Implementation itself.
result, err := ParseSwaggerFileForTesting(t, "model_discriminators_child_used_as_parent.json")
if err != nil {
t.Fatalf("parsing: %+v", err)
Expand Down Expand Up @@ -570,9 +570,9 @@ func TestParseDiscriminatedChildTypeWhereParentShouldBeUsed(t *testing.T) {
if nested.ObjectDefinition.ReferenceName == nil {
t.Fatalf("expected the Field `Nested` within the Model `ExampleWrapper` to be a Reference but it was nil")
}
// NOTE: this is the primary assertion here, since the Swagger defined "Dog" should be swapped out for "Animal"
if *nested.ObjectDefinition.ReferenceName != "Animal" {
t.Fatalf("expected the Field `Nested` within the Model `ExampleWrapper` to be a Reference to `Animal` but got %q", *nested.ObjectDefinition.ReferenceName)
// NOTE: this is the primary assertion here, since the Swagger defined "Dog" should be just "Dog", instead of swapped out to be "Animal"
if *nested.ObjectDefinition.ReferenceName != "Dog" {
t.Fatalf("expected the Field `Nested` within the Model `ExampleWrapper` to be a Reference to `Dog` but got %q", *nested.ObjectDefinition.ReferenceName)
}

animal, ok := resource.Models["Animal"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ func (d *SwaggerDefinition) parseResourcesWithinSwaggerTag(tag *string, resource
// then switch out any custom types (e.g. Identity)
result = switchOutCustomTypesAsNeeded(result)

nestedResult, err = d.replaceDiscriminatedTypesWithParents(*nestedResult)
if err != nil {
return nil, fmt.Errorf("replacing discriminated types with parent types: %+v", err)
}

// finally remove any models and constants which aren't referenced / have been replaced
constantsAndModels, resourceIdNamesToUris := removeUnusedItems(*operations, resourceIds.NamesToResourceIDs, result)

Expand Down Expand Up @@ -133,75 +128,6 @@ func pullOutModelForListOperations(input map[string]models.OperationDetails, kno
return &output, nil
}

func (d *SwaggerDefinition) replaceDiscriminatedTypesWithParents(inputResult internal.ParseResult) (*internal.ParseResult, error) {
// some Swaggers define both top-level request/response objects as implementations of discriminators, rather than the parent object
// in our case since we generate the unmarshal funcs etc based on the presence of the parent/interface, we switch these out
// should these be discriminators in the Swagger? likely no, but alas, DRY Swaggers.

nestedResult := internal.ParseResult{
Constants: map[string]resourcemanager.ConstantDetails{},
Models: map[string]models.ModelDetails{},
}
// models will be manually mapped below
nestedResult.AppendConstants(inputResult.Constants)

// TODO: we should consider doing the same for Top Level Requests/Responses in the future too

for name, model := range inputResult.Models {
fields := make(map[string]models.FieldDetails)
for key, value := range model.Fields {
if value.ObjectDefinition != nil {
obj, err := d.replaceDiscriminatedTypeWithinObjectDefinitionWithParent(value.ObjectDefinition, inputResult)
if err != nil {
return nil, fmt.Errorf("replacing object definition for model %q / field %q: %+v", name, key, err)
}
value.ObjectDefinition = obj
}

fields[key] = value
}
model.Fields = fields
nestedResult.Models[name] = model
}

return &nestedResult, nil
}

func (d *SwaggerDefinition) replaceDiscriminatedTypeWithinObjectDefinitionWithParent(input *models.ObjectDefinition, known internal.ParseResult) (*models.ObjectDefinition, error) {
if input.NestedItem != nil {
item, err := d.replaceDiscriminatedTypeWithinObjectDefinitionWithParent(input.NestedItem, known)
if err != nil {
return nil, fmt.Errorf("replacing nested item: %+v", err)
}
input.NestedItem = item
return input, nil
}

if input.Type == models.ObjectDefinitionReference {
// find the parent name and use that
if input.ReferenceName == nil {
return nil, fmt.Errorf("internal-error: reference was missing a reference name")
}
model, modelOk := known.Models[*input.ReferenceName]
_, constantOk := known.Constants[*input.ReferenceName]
if !constantOk && !modelOk {
return nil, fmt.Errorf("a constant or model called %q was not found", *input.ReferenceName)
}
if modelOk && model.ParentTypeName != nil {
parent, ok := known.Models[*model.ParentTypeName]
if !ok {
return nil, fmt.Errorf("parent model %q was not found", *model.ParentTypeName)
}
if parent.ParentTypeName != nil {
return nil, fmt.Errorf("unexpected discriminator within discriminator for parent %q", *parent.ParentTypeName)
}
input.ReferenceName = model.ParentTypeName
}
}

return input, nil
}

func switchOutCustomTypesAsNeeded(input internal.ParseResult) internal.ParseResult {
result := internal.ParseResult{
Constants: map[string]resourcemanager.ConstantDetails{},
Expand Down
8 changes: 8 additions & 0 deletions tools/sdk/resourcemanager/api_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ type ModelDetails struct {
TypeHintValue *string `json:"typeHintValue"`
}

func (m ModelDetails) IsDiscriminatedParentType() bool {
return m.ParentTypeName == nil && m.TypeHintIn != nil && m.TypeHintValue == nil
}

func (m ModelDetails) IsDiscriminatedImplType() bool {
return m.ParentTypeName != nil && m.TypeHintIn != nil && m.TypeHintValue != nil
}

type FieldDetails struct {
// Default is an optional value which should be used as the default for this field
Default *interface{} `json:"default"`
Expand Down

0 comments on commit a9e3091

Please sign in to comment.