diff --git a/internal/ent/generated/runtime.go b/internal/ent/generated/runtime.go index 394a7833e..46d4376f6 100644 --- a/internal/ent/generated/runtime.go +++ b/internal/ent/generated/runtime.go @@ -89,7 +89,21 @@ func init() { // originDescTarget is the schema descriptor for target field. originDescTarget := originFields[2].Descriptor() // origin.TargetValidator is a validator for the "target" field. It is called by the builders before save. - origin.TargetValidator = originDescTarget.Validators[0].(func(string) error) + origin.TargetValidator = func() func(string) error { + validators := originDescTarget.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(target string) error { + for _, fn := range fns { + if err := fn(target); err != nil { + return err + } + } + return nil + } + }() // originDescPortNumber is the schema descriptor for port_number field. originDescPortNumber := originFields[3].Descriptor() // origin.PortNumberValidator is a validator for the "port_number" field. It is called by the builders before save. diff --git a/internal/ent/schema/origin.go b/internal/ent/schema/origin.go index dec519b10..9eab0f6de 100644 --- a/internal/ent/schema/origin.go +++ b/internal/ent/schema/origin.go @@ -10,6 +10,8 @@ import ( "go.infratographer.com/x/entx" "go.infratographer.com/x/gidx" + + "go.infratographer.com/load-balancer-api/internal/ent/schema/validations" ) // Origin holds the schema definition for the Origin entity. @@ -39,6 +41,7 @@ func (Origin) Fields() []ent.Field { ), field.String("target"). NotEmpty(). + Validate(validations.IPAddress). // Comment("origin address"). Annotations( entgql.OrderField("target"), diff --git a/internal/ent/schema/validations/doc.go b/internal/ent/schema/validations/doc.go new file mode 100644 index 000000000..c81ae2cbe --- /dev/null +++ b/internal/ent/schema/validations/doc.go @@ -0,0 +1,2 @@ +// Package validations contains validation functions for ent fields +package validations diff --git a/internal/ent/schema/validations/errors.go b/internal/ent/schema/validations/errors.go new file mode 100644 index 000000000..e9f0ee8f5 --- /dev/null +++ b/internal/ent/schema/validations/errors.go @@ -0,0 +1,6 @@ +package validations + +import "errors" + +// ErrInvalidIPAddress is returned when the given string is not a valid IP address +var ErrInvalidIPAddress = errors.New("invalid ip address") diff --git a/internal/ent/schema/validations/validations.go b/internal/ent/schema/validations/validations.go new file mode 100644 index 000000000..867dd271e --- /dev/null +++ b/internal/ent/schema/validations/validations.go @@ -0,0 +1,15 @@ +// Package validations contains validation functions for ent fields +package validations + +import ( + "net" +) + +// IPAddress validates if the given string is a valid IP address +func IPAddress(ip string) error { + if net.ParseIP(ip) == nil { + return ErrInvalidIPAddress + } + + return nil +} diff --git a/internal/graphapi/origin_test.go b/internal/graphapi/origin_test.go index 66b18a77a..56508886b 100644 --- a/internal/graphapi/origin_test.go +++ b/internal/graphapi/origin_test.go @@ -141,6 +141,17 @@ func TestMutate_OriginCreate(t *testing.T) { Active: false, }, }, + { + TestName: "invalid target ip", + Input: graphclient.CreateLoadBalancerOriginInput{ + Name: "original", + Target: "not a valid target ip", + PortNumber: 22, + PoolID: pool1.ID, + Active: newBool(false), + }, + errorMsg: "invalid ip address", + }, } for _, tt := range testCases { diff --git a/internal/graphapi/port.resolvers.go b/internal/graphapi/port.resolvers.go index 1704ed1f4..f7237bec7 100644 --- a/internal/graphapi/port.resolvers.go +++ b/internal/graphapi/port.resolvers.go @@ -8,10 +8,9 @@ import ( "context" "strings" + "go.infratographer.com/load-balancer-api/internal/ent/generated" "go.infratographer.com/permissions-api/pkg/permissions" "go.infratographer.com/x/gidx" - - "go.infratographer.com/load-balancer-api/internal/ent/generated" ) // LoadBalancerPortCreate is the resolver for the loadBalancerPortCreate field.