From 5b8ad84c13c447ced1e1ebfd72fb3b5e79b45bdc Mon Sep 17 00:00:00 2001 From: Levente Liu Date: Fri, 23 Jul 2021 09:44:02 +0800 Subject: [PATCH] Add all-validator support (#443) * Add all-validator support * Use single argument instead of variadic Since we are working on v2, keeping the signature compatible with v1 is not actually necessary. Use a clearer single argument instead. Also add comments for the new argument. --- interceptors/validator/validator.go | 58 ++++++++++++++++++---- interceptors/validator/validator_test.go | 42 +++++++++++++--- testing/testpb/test.manual_validator.pb.go | 10 +++- 3 files changed, 91 insertions(+), 19 deletions(-) diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index 864adde7a..f68884745 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -11,6 +11,12 @@ import ( "google.golang.org/grpc/status" ) +// The validateAller interface at protoc-gen-validate main branch. +// See https://github.com/envoyproxy/protoc-gen-validate/pull/468. +type validateAller interface { + ValidateAll() error +} + // The validate interface starting with protoc-gen-validate v0.6.0. // See https://github.com/envoyproxy/protoc-gen-validate/pull/455. type validator interface { @@ -22,9 +28,25 @@ type validatorLegacy interface { Validate() error } -// Calls the Validate function on a proto message using either the current or legacy interface if the Validate function -// is present. If validation fails, the error is wrapped with `InvalidArgument` and returned. -func validate(req interface{}) error { +func validate(req interface{}, all bool) error { + if all { + switch v := req.(type) { + case validateAller: + if err := v.ValidateAll(); err != nil { + return status.Error(codes.InvalidArgument, err.Error()) + } + case validator: + if err := v.Validate(true); err != nil { + return status.Error(codes.InvalidArgument, err.Error()) + } + case validatorLegacy: + // Fallback to legacy validator + if err := v.Validate(); err != nil { + return status.Error(codes.InvalidArgument, err.Error()) + } + } + return nil + } switch v := req.(type) { case validatorLegacy: if err := v.Validate(); err != nil { @@ -41,9 +63,13 @@ func validate(req interface{}) error { // UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. // // Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. -func UnaryServerInterceptor() grpc.UnaryServerInterceptor { +// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor +// returns ALL validation error as a wrapped multi-error. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. +func UnaryServerInterceptor(all bool) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if err := validate(req); err != nil { + if err := validate(req, all); err != nil { return nil, err } return handler(ctx, req) @@ -53,9 +79,13 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor { // UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages. // // Invalid messages will be rejected with `InvalidArgument` before sending the request to server. -func UnaryClientInterceptor() grpc.UnaryClientInterceptor { +// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor +// returns ALL validation error as a wrapped multi-error. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. +func UnaryClientInterceptor(all bool) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(req); err != nil { + if err := validate(req, all); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -64,18 +94,26 @@ func UnaryClientInterceptor() grpc.UnaryClientInterceptor { // StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. // +// If `all` is false, the interceptor returns first validation error. Otherwise the interceptor +// returns ALL validation error as a wrapped multi-error. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. // The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the // type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace // handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on // calls to `stream.Recv()`. -func StreamServerInterceptor() grpc.StreamServerInterceptor { +func StreamServerInterceptor(all bool) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - wrapper := &recvWrapper{stream} + wrapper := &recvWrapper{ + all: all, + ServerStream: stream, + } return handler(srv, wrapper) } } type recvWrapper struct { + all bool grpc.ServerStream } @@ -83,7 +121,7 @@ func (s *recvWrapper) RecvMsg(m interface{}) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - if err := validate(m); err != nil { + if err := validate(m, s.all); err != nil { return err } return nil diff --git a/interceptors/validator/validator_test.go b/interceptors/validator/validator_test.go index 53d4c73da..cc4337a4a 100644 --- a/interceptors/validator/validator_test.go +++ b/interceptors/validator/validator_test.go @@ -18,32 +18,58 @@ import ( ) func TestValidateWrapper(t *testing.T) { - assert.NoError(t, validate(testpb.GoodPing)) - assert.Error(t, validate(testpb.BadPing)) - - assert.NoError(t, validate(testpb.GoodPingResponse)) - assert.Error(t, validate(testpb.BadPingResponse)) + assert.NoError(t, validate(testpb.GoodPing, false)) + assert.Error(t, validate(testpb.BadPing, false)) + assert.NoError(t, validate(testpb.GoodPing, true)) + assert.Error(t, validate(testpb.BadPing, true)) + + assert.NoError(t, validate(testpb.GoodPingError, false)) + assert.Error(t, validate(testpb.BadPingError, false)) + assert.NoError(t, validate(testpb.GoodPingError, true)) + assert.Error(t, validate(testpb.BadPingError, true)) + + assert.NoError(t, validate(testpb.GoodPingResponse, false)) + assert.NoError(t, validate(testpb.GoodPingResponse, true)) + assert.Error(t, validate(testpb.BadPingResponse, false)) + assert.Error(t, validate(testpb.BadPingResponse, true)) } func TestValidatorTestSuite(t *testing.T) { s := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(StreamServerInterceptor()), - grpc.UnaryInterceptor(UnaryServerInterceptor()), + grpc.StreamInterceptor(StreamServerInterceptor(false)), + grpc.UnaryInterceptor(UnaryServerInterceptor(false)), }, }, } suite.Run(t, s) + sAll := &ValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(StreamServerInterceptor(true)), + grpc.UnaryInterceptor(UnaryServerInterceptor(true)), + }, + }, + } + suite.Run(t, sAll) cs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(UnaryClientInterceptor()), + grpc.WithUnaryInterceptor(UnaryClientInterceptor(false)), }, }, } suite.Run(t, cs) + csAll := &ClientValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ClientOpts: []grpc.DialOption{ + grpc.WithUnaryInterceptor(UnaryClientInterceptor(true)), + }, + }, + } + suite.Run(t, csAll) } type ValidatorTestSuite struct { diff --git a/testing/testpb/test.manual_validator.pb.go b/testing/testpb/test.manual_validator.pb.go index 6bbf3147f..ec1e8f639 100644 --- a/testing/testpb/test.manual_validator.pb.go +++ b/testing/testpb/test.manual_validator.pb.go @@ -15,7 +15,7 @@ func (x *PingRequest) Validate(bool) error { return nil } -func (x *PingErrorRequest) Validate(bool) error { +func (x *PingErrorRequest) Validate() error { if x.SleepTimeMs > 10000 { return errors.New("cannot sleep for more than 10s") } @@ -44,6 +44,14 @@ func (x *PingResponse) Validate() error { return nil } +// Implements the new ValidateAll interface from protoc-gen-validate. +func (x *PingResponse) ValidateAll() error { + if x.Counter > math.MaxInt16 { + return errors.New("ping allocation exceeded") + } + return nil +} + var ( GoodPing = &PingRequest{Value: "something", SleepTimeMs: 9999} GoodPingError = &PingErrorRequest{Value: "something", SleepTimeMs: 9999}