Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mesh: add validation for the new pbmesh resources #18410

Merged
merged 16 commits into from
Aug 22, 2023
2 changes: 1 addition & 1 deletion internal/catalog/exports.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover
// FailoverPolicyMapper maintains the bidirectional tracking relationship of a
// FailoverPolicy to the Services related to it.
type FailoverPolicyMapper interface {
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
UntrackFailover(failoverID *pbresource.ID)
FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID
}
Expand Down
20 changes: 10 additions & 10 deletions internal/catalog/internal/controllers/failover/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type FailoverMapper interface {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service
// events into FailoverPolicy events properly.
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])

// UntrackFailover forgets the links inserted by TrackFailover for the
// provided FailoverPolicyID.
Expand Down Expand Up @@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
rt.Logger.Error("error retrieving corresponding service", "error", err)
return err
}
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service])
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service])
if service != nil {
destServices[resource.NewReferenceKey(serviceID)] = service
}
Expand Down Expand Up @@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
return nil
}

func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id)
func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id)
}

func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) {
return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id)
func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) {
return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id)
}

func computeNewStatus(
failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy],
service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy],
service *resource.DecodedResource[*pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Status {
if service == nil {
return &pbresource.Status{
Expand Down Expand Up @@ -238,7 +238,7 @@ func computeNewStatus(

func serviceHasPort(
dest *pbcatalog.FailoverDestination,
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Condition {
key := resource.NewReferenceKey(dest.Ref)
destService, ok := destServices[key]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func New() *Mapper {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service events
// into FailoverPolicy events properly.
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) {
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) {
destRefs := failover.Data.GetUnderlyingDestinationRefs()
destRefs = append(destRefs, &pbresource.Reference{
Type: types.ServiceType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1)
failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1)
failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1)

fail2 := rtest.Resource(types.FailoverPolicyType, "www").
WithData(t, &pbcatalog.FailoverPolicy{
Expand All @@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail2)
failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2)
failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2)

fail1_updated := rtest.Resource(types.FailoverPolicyType, "api").
WithData(t, &pbcatalog.FailoverPolicy{
Expand All @@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1_updated)
failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated)
failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated)

m := New()

Expand Down
12 changes: 6 additions & 6 deletions internal/catalog/internal/types/failover_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) {

err := MutateFailoverPolicy(res)

got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)

if tc.expectErr == "" {
require.NoError(t, err)
Expand Down Expand Up @@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) {
require.NoError(t, MutateFailoverPolicy(res))

// Verify that mutate didn't actually change the object.
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)

err := ValidateFailoverPolicy(res)

// Verify that validate didn't actually change the object.
got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)

if tc.expectErr == "" {
Expand Down Expand Up @@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) {
resourcetest.ValidateAndNormalize(t, registry, tc.failover)
resourcetest.ValidateAndNormalize(t, registry, tc.expect)

svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect)
svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect)

inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy)

Expand Down
47 changes: 46 additions & 1 deletion internal/mesh/internal/types/computed_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package types

import (
"github.com/hashicorp/go-multierror"

"github.com/hashicorp/consul/internal/resource"
pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v1alpha1"
"github.com/hashicorp/consul/proto-public/pbresource"
Expand All @@ -27,6 +29,49 @@ func RegisterComputedRoutes(r resource.Registry) {
r.Register(resource.Registration{
Type: ComputedRoutesV1Alpha1Type,
Proto: &pbmesh.ComputedRoutes{},
Validate: nil,
Validate: ValidateComputedRoutes,
})
}

func ValidateComputedRoutes(res *pbresource.Resource) error {
var config pbmesh.ComputedRoutes

if err := res.Data.UnmarshalTo(&config); err != nil {
return resource.NewErrDataParse(&config, err)
}

var merr error

if len(config.PortedConfigs) == 0 {
merr = multierror.Append(merr, resource.ErrInvalidField{
Name: "ported_configs",
Wrapped: resource.ErrEmpty,
})
}

// TODO(rb): do more elaborate validation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still valid?

Copy link
Member Author

@rboyer rboyer Aug 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to come back to this in NET-5066. I'm not sure how much more validation is required since it is a generated resource type, but we'll likely want a bit more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it probably is. Presumably both the Config and Targets fields should have their internals validated.


for port, pmc := range config.PortedConfigs {
wrapErr := func(err error) error {
return resource.ErrInvalidMapValue{
Map: "ported_configs",
Key: port,
Wrapped: err,
}
}
if pmc.Config == nil {
merr = multierror.Append(merr, wrapErr(resource.ErrInvalidField{
Name: "config",
Wrapped: resource.ErrEmpty,
}))
}
if len(pmc.Targets) == 0 {
merr = multierror.Append(merr, wrapErr(resource.ErrInvalidField{
Name: "targets",
Wrapped: resource.ErrEmpty,
}))
}
}

return merr
}
92 changes: 92 additions & 0 deletions internal/mesh/internal/types/computed_routes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package types

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/hashicorp/consul/internal/resource/resourcetest"
pbmesh "github.com/hashicorp/consul/proto-public/pbmesh/v1alpha1"
"github.com/hashicorp/consul/proto/private/prototest"
"github.com/hashicorp/consul/sdk/testutil"
)

func TestValidateComputedRoutes(t *testing.T) {
type testcase struct {
routes *pbmesh.ComputedRoutes
expectErr string
}

run := func(t *testing.T, tc testcase) {
res := resourcetest.Resource(ComputedRoutesType, "api").
WithData(t, tc.routes).
Build()

err := ValidateComputedRoutes(res)

// Verify that validate didn't actually change the object.
got := resourcetest.MustDecode[*pbmesh.ComputedRoutes](t, res)
prototest.AssertDeepEqual(t, tc.routes, got.Data)

if tc.expectErr == "" {
require.NoError(t, err)
} else {
testutil.RequireErrorContains(t, err, tc.expectErr)
}
}

cases := map[string]testcase{
"empty": {
routes: &pbmesh.ComputedRoutes{},
expectErr: `invalid "ported_configs" field: cannot be empty`,
},
"empty config": {
routes: &pbmesh.ComputedRoutes{
PortedConfigs: map[string]*pbmesh.ComputedPortRoutes{
"http": {
Config: nil,
Targets: map[string]*pbmesh.BackendTargetDetails{
"foo": {},
},
},
},
},
expectErr: `invalid value of key "http" within ported_configs: invalid "config" field: cannot be empty`,
},
"empty targets": {
rboyer marked this conversation as resolved.
Show resolved Hide resolved
routes: &pbmesh.ComputedRoutes{
PortedConfigs: map[string]*pbmesh.ComputedPortRoutes{
"http": {
Config: &pbmesh.ComputedPortRoutes_Tcp{
Tcp: &pbmesh.InterpretedTCPRoute{},
},
},
},
},
expectErr: `invalid value of key "http" within ported_configs: invalid "targets" field: cannot be empty`,
},
"valid": {
routes: &pbmesh.ComputedRoutes{
PortedConfigs: map[string]*pbmesh.ComputedPortRoutes{
"http": {
Config: &pbmesh.ComputedPortRoutes_Tcp{
Tcp: &pbmesh.InterpretedTCPRoute{},
},
Targets: map[string]*pbmesh.BackendTargetDetails{
"foo": {},
},
},
},
},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
run(t, tc)
})
}
}
Loading