Skip to content

Commit

Permalink
Add all-validator support (#443)
Browse files Browse the repository at this point in the history
* Add all-validator support

* Use single argument instead of variadic

Since we are working on v2, keeping the signature
compatible with v1 is not actually necessary.

Use a clearer single argument instead. Also add comments
for the new argument.
  • Loading branch information
leventeliu authored Jul 23, 2021
1 parent dc87da6 commit 5b8ad84
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 19 deletions.
58 changes: 48 additions & 10 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import (
"google.golang.org/grpc/status"
)

// The validateAller interface at protoc-gen-validate main branch.
// See https://github.com/envoyproxy/protoc-gen-validate/pull/468.
type validateAller interface {
ValidateAll() error
}

// The validate interface starting with protoc-gen-validate v0.6.0.
// See https://github.com/envoyproxy/protoc-gen-validate/pull/455.
type validator interface {
Expand All @@ -22,9 +28,25 @@ type validatorLegacy interface {
Validate() error
}

// Calls the Validate function on a proto message using either the current or legacy interface if the Validate function
// is present. If validation fails, the error is wrapped with `InvalidArgument` and returned.
func validate(req interface{}) error {
func validate(req interface{}, all bool) error {
if all {
switch v := req.(type) {
case validateAller:
if err := v.ValidateAll(); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(true); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
case validatorLegacy:
// Fallback to legacy validator
if err := v.Validate(); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
}
return nil
}
switch v := req.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
Expand All @@ -41,9 +63,13 @@ func validate(req interface{}) error {
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryServerInterceptor(all bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validate(req); err != nil {
if err := validate(req, all); err != nil {
return nil, err
}
return handler(ctx, req)
Expand All @@ -53,9 +79,13 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
//
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryClientInterceptor(all bool) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req); err != nil {
if err := validate(req, all); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
Expand All @@ -64,26 +94,34 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor {

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
//
// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
// calls to `stream.Recv()`.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
func StreamServerInterceptor(all bool) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{stream}
wrapper := &recvWrapper{
all: all,
ServerStream: stream,
}
return handler(srv, wrapper)
}
}

type recvWrapper struct {
all bool
grpc.ServerStream
}

func (s *recvWrapper) RecvMsg(m interface{}) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := validate(m); err != nil {
if err := validate(m, s.all); err != nil {
return err
}
return nil
Expand Down
42 changes: 34 additions & 8 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,58 @@ import (
)

func TestValidateWrapper(t *testing.T) {
assert.NoError(t, validate(testpb.GoodPing))
assert.Error(t, validate(testpb.BadPing))

assert.NoError(t, validate(testpb.GoodPingResponse))
assert.Error(t, validate(testpb.BadPingResponse))
assert.NoError(t, validate(testpb.GoodPing, false))
assert.Error(t, validate(testpb.BadPing, false))
assert.NoError(t, validate(testpb.GoodPing, true))
assert.Error(t, validate(testpb.BadPing, true))

assert.NoError(t, validate(testpb.GoodPingError, false))
assert.Error(t, validate(testpb.BadPingError, false))
assert.NoError(t, validate(testpb.GoodPingError, true))
assert.Error(t, validate(testpb.BadPingError, true))

assert.NoError(t, validate(testpb.GoodPingResponse, false))
assert.NoError(t, validate(testpb.GoodPingResponse, true))
assert.Error(t, validate(testpb.BadPingResponse, false))
assert.Error(t, validate(testpb.BadPingResponse, true))
}

func TestValidatorTestSuite(t *testing.T) {
s := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamServerInterceptor()),
grpc.UnaryInterceptor(UnaryServerInterceptor()),
grpc.StreamInterceptor(StreamServerInterceptor(false)),
grpc.UnaryInterceptor(UnaryServerInterceptor(false)),
},
},
}
suite.Run(t, s)
sAll := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(StreamServerInterceptor(true)),
grpc.UnaryInterceptor(UnaryServerInterceptor(true)),
},
},
}
suite.Run(t, sAll)

cs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(UnaryClientInterceptor()),
grpc.WithUnaryInterceptor(UnaryClientInterceptor(false)),
},
},
}
suite.Run(t, cs)
csAll := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(UnaryClientInterceptor(true)),
},
},
}
suite.Run(t, csAll)
}

type ValidatorTestSuite struct {
Expand Down
10 changes: 9 additions & 1 deletion testing/testpb/test.manual_validator.pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (x *PingRequest) Validate(bool) error {
return nil
}

func (x *PingErrorRequest) Validate(bool) error {
func (x *PingErrorRequest) Validate() error {
if x.SleepTimeMs > 10000 {
return errors.New("cannot sleep for more than 10s")
}
Expand Down Expand Up @@ -44,6 +44,14 @@ func (x *PingResponse) Validate() error {
return nil
}

// Implements the new ValidateAll interface from protoc-gen-validate.
func (x *PingResponse) ValidateAll() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
}
return nil
}

var (
GoodPing = &PingRequest{Value: "something", SleepTimeMs: 9999}
GoodPingError = &PingErrorRequest{Value: "something", SleepTimeMs: 9999}
Expand Down

0 comments on commit 5b8ad84

Please sign in to comment.