diff --git a/pkg/grpc_testing/recorder.go b/pkg/grpc_testing/recorder.go new file mode 100644 index 00000000000..d6b6d2aac2b --- /dev/null +++ b/pkg/grpc_testing/recorder.go @@ -0,0 +1,69 @@ +// Copyright 2021 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpc_testing + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type GrpcRecorder struct { + mux sync.RWMutex + requests []RequestInfo +} + +type RequestInfo struct { + FullMethod string + Authority string +} + +func (ri *GrpcRecorder) UnaryInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + ri.record(toRequestInfo(ctx, info)) + resp, err := handler(ctx, req) + return resp, err + } +} + +func (ri *GrpcRecorder) RecordedRequests() []RequestInfo { + ri.mux.RLock() + defer ri.mux.RUnlock() + reqs := make([]RequestInfo, len(ri.requests)) + copy(reqs, ri.requests) + return reqs +} + +func toRequestInfo(ctx context.Context, info *grpc.UnaryServerInfo) RequestInfo { + req := RequestInfo{ + FullMethod: info.FullMethod, + } + md, ok := metadata.FromIncomingContext(ctx) + if ok { + as := md.Get(":authority") + if len(as) != 0 { + req.Authority = as[0] + } + } + return req +} + +func (ri *GrpcRecorder) record(r RequestInfo) { + ri.mux.Lock() + defer ri.mux.Unlock() + ri.requests = append(ri.requests, r) +} diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 001302f991b..aca637c8627 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -539,7 +539,7 @@ func (e *Etcd) servePeers() (err error) { for _, p := range e.Peers { u := p.Listener.Addr().String() - gs := v3rpc.Server(e.Server, peerTLScfg) + gs := v3rpc.Server(e.Server, peerTLScfg, nil) m := cmux.New(p.Listener) go gs.Serve(m.Match(cmux.HTTP2())) srv := &http.Server{ diff --git a/server/embed/serve.go b/server/embed/serve.go index 17b55384ebb..c3e786321cd 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -110,7 +110,7 @@ func (sctx *serveCtx) serve( }() if sctx.insecure { - gs = v3rpc.Server(s, nil, gopts...) + gs = v3rpc.Server(s, nil, nil, gopts...) v3electionpb.RegisterElectionServer(gs, servElection) v3lockpb.RegisterLockServer(gs, servLock) if sctx.serviceRegister != nil { @@ -148,7 +148,7 @@ func (sctx *serveCtx) serve( if tlsErr != nil { return tlsErr } - gs = v3rpc.Server(s, tlscfg, gopts...) + gs = v3rpc.Server(s, tlscfg, nil, gopts...) v3electionpb.RegisterElectionServer(gs, servElection) v3lockpb.RegisterLockServer(gs, servLock) if sctx.serviceRegister != nil { diff --git a/server/etcdserver/api/v3rpc/grpc.go b/server/etcdserver/api/v3rpc/grpc.go index 26c52b385b4..ea3dd75705f 100644 --- a/server/etcdserver/api/v3rpc/grpc.go +++ b/server/etcdserver/api/v3rpc/grpc.go @@ -36,19 +36,21 @@ const ( maxSendBytes = math.MaxInt32 ) -func Server(s *etcdserver.EtcdServer, tls *tls.Config, gopts ...grpc.ServerOption) *grpc.Server { +func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server { var opts []grpc.ServerOption opts = append(opts, grpc.CustomCodec(&codec{})) if tls != nil { bundle := credentials.NewBundle(credentials.Config{TLSConfig: tls}) opts = append(opts, grpc.Creds(bundle.TransportCredentials())) } - chainUnaryInterceptors := []grpc.UnaryServerInterceptor{ newLogUnaryInterceptor(s), newUnaryInterceptor(s), grpc_prometheus.UnaryServerInterceptor, } + if interceptor != nil { + chainUnaryInterceptors = append(chainUnaryInterceptors, interceptor) + } chainStreamInterceptors := []grpc.StreamServerInterceptor{ newStreamInterceptor(s), diff --git a/tests/integration/cluster.go b/tests/integration/cluster.go index a98fe44ea18..c8bb969544d 100644 --- a/tests/integration/cluster.go +++ b/tests/integration/cluster.go @@ -39,6 +39,7 @@ import ( "go.etcd.io/etcd/client/pkg/v3/types" "go.etcd.io/etcd/client/v2" "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/pkg/v3/grpc_testing" "go.etcd.io/etcd/raft/v3" "go.etcd.io/etcd/server/v3/config" "go.etcd.io/etcd/server/v3/embed" @@ -602,6 +603,8 @@ type member struct { isLearner bool closed bool + + grpcServerRecorder *grpc_testing.GrpcRecorder } func (m *member) GRPCURL() string { return m.grpcURL } @@ -733,7 +736,7 @@ func mustNewMember(t testutil.TB, mcfg memberConfig) *member { m.WarningApplyDuration = embed.DefaultWarningApplyDuration m.V2Deprecation = config.V2_DEPR_DEFAULT - + m.grpcServerRecorder = &grpc_testing.GrpcRecorder{} m.Logger = memberLogger(t, mcfg.name) t.Cleanup(func() { // if we didn't cleanup the logger, the consecutive test @@ -945,8 +948,8 @@ func (m *member) Launch() error { return err } } - m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerOpts...) - m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg) + m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerRecorder.UnaryInterceptor(), m.grpcServerOpts...) + m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg, m.grpcServerRecorder.UnaryInterceptor()) m.serverClient = v3client.New(m.s) lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient)) epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient)) @@ -1081,6 +1084,10 @@ func (m *member) Launch() error { return nil } +func (m *member) RecordedRequests() []grpc_testing.RequestInfo { + return m.grpcServerRecorder.RecordedRequests() +} + func (m *member) WaitOK(t testutil.TB) { m.WaitStarted(t) for m.s.Leader() == 0 { @@ -1370,8 +1377,9 @@ func (p SortableMemberSliceByPeerURLs) Swap(i, j int) { p[i], p[j] = p[j], p[i] type ClusterV3 struct { *cluster - mu sync.Mutex - clients []*clientv3.Client + mu sync.Mutex + clients []*clientv3.Client + clusterClient *clientv3.Client } // NewClusterV3 returns a launched cluster with a grpc client connection @@ -1417,6 +1425,11 @@ func (c *ClusterV3) Terminate(t testutil.TB) { t.Error(err) } } + if c.clusterClient != nil { + if err := c.clusterClient.Close(); err != nil { + t.Error(err) + } + } c.mu.Unlock() c.cluster.Terminate(t) } @@ -1429,6 +1442,25 @@ func (c *ClusterV3) Client(i int) *clientv3.Client { return c.clients[i] } +func (c *ClusterV3) ClusterClient() (client *clientv3.Client, err error) { + if c.clusterClient == nil { + endpoints := []string{} + for _, m := range c.Members { + endpoints = append(endpoints, m.grpcURL) + } + cfg := clientv3.Config{ + Endpoints: endpoints, + DialTimeout: 5 * time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + } + c.clusterClient, err = newClientV3(cfg, cfg.Logger) + if err != nil { + return nil, err + } + } + return c.clusterClient, nil +} + // NewClientV3 creates a new grpc client connection to the member func (c *ClusterV3) NewClientV3(memberIndex int) (*clientv3.Client, error) { return NewClientV3(c.Members[memberIndex]) diff --git a/tests/integration/grpc_test.go b/tests/integration/grpc_test.go new file mode 100644 index 00000000000..49cbd1df511 --- /dev/null +++ b/tests/integration/grpc_test.go @@ -0,0 +1,182 @@ +// Copyright 2021 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "context" + tls "crypto/tls" + "fmt" + "strings" + "testing" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" +) + +func TestAuthority(t *testing.T) { + tcs := []struct { + name string + useTCP bool + useTLS bool + // Pattern used to generate endpoints for client. Fields filled + // %d - will be filled with member grpc port + // %s - will be filled with member name + clientURLPattern string + + // Pattern used to validate authority received by server. Fields filled: + // %s - list of endpoints concatenated with ";" + expectAuthorityPattern string + }{ + { + name: "unix:path", + clientURLPattern: "unix:localhost:%s", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "unix://absolute_path", + clientURLPattern: "unix://localhost:%s", + expectAuthorityPattern: "#initially=[%s]", + }, + // "unixs" is not standard schema supported by etcd + { + name: "unixs:absolute_path", + useTLS: true, + clientURLPattern: "unixs:localhost:%s", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "unixs://absolute_path", + useTLS: true, + clientURLPattern: "unixs://localhost:%s", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "http://domain[:port]", + useTCP: true, + clientURLPattern: "http://localhost:%d", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "https://domain[:port]", + useTLS: true, + useTCP: true, + clientURLPattern: "https://localhost:%d", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "http://address[:port]", + useTCP: true, + clientURLPattern: "http://127.0.0.1:%d", + expectAuthorityPattern: "#initially=[%s]", + }, + { + name: "https://address[:port]", + useTCP: true, + useTLS: true, + clientURLPattern: "https://127.0.0.1:%d", + expectAuthorityPattern: "#initially=[%s]", + }, + } + for _, tc := range tcs { + for _, clusterSize := range []int{1, 3} { + t.Run(fmt.Sprintf("Size: %d, Scenario: %q", clusterSize, tc.name), func(t *testing.T) { + BeforeTest(t) + cfg := ClusterConfig{ + Size: clusterSize, + UseTCP: tc.useTCP, + UseIP: tc.useTCP, + } + cfg, tlsConfig := setupTLS(t, tc.useTLS, cfg) + clus := NewClusterV3(t, &cfg) + defer clus.Terminate(t) + endpoints := templateEndpoints(t, tc.clientURLPattern, clus) + + kv := setupClient(t, tc.clientURLPattern, clus, tlsConfig) + defer kv.Close() + + _, err := kv.Put(context.TODO(), "foo", "bar") + if err != nil { + t.Fatal(err) + } + + assertAuthority(t, fmt.Sprintf(tc.expectAuthorityPattern, strings.Join(endpoints, ";")), clus) + }) + } + } +} + +func setupTLS(t *testing.T, useTLS bool, cfg ClusterConfig) (ClusterConfig, *tls.Config) { + t.Helper() + if useTLS { + cfg.ClientTLS = &testTLSInfo + tlsConfig, err := testTLSInfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + return cfg, tlsConfig + } + return cfg, nil +} + +func setupClient(t *testing.T, endpointPattern string, clus *ClusterV3, tlsConfig *tls.Config) *clientv3.Client { + t.Helper() + endpoints := templateEndpoints(t, endpointPattern, clus) + kv, err := clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: 5 * time.Second, + DialOptions: []grpc.DialOption{grpc.WithBlock()}, + TLS: tlsConfig, + }) + if err != nil { + t.Fatal(err) + } + return kv +} + +func templateEndpoints(t *testing.T, pattern string, clus *ClusterV3) []string { + t.Helper() + endpoints := []string{} + for _, m := range clus.Members { + ent := pattern + if strings.Contains(ent, "%d") { + ent = fmt.Sprintf(ent, GrpcPortNumber(m.UniqNumber, m.MemberNumber)) + } + if strings.Contains(ent, "%s") { + ent = fmt.Sprintf(ent, m.Name) + } + if strings.Contains(ent, "%") { + t.Fatalf("Failed to template pattern, %% symbol left %q", ent) + } + endpoints = append(endpoints, ent) + } + return endpoints +} + +func assertAuthority(t *testing.T, expectedAuthority string, clus *ClusterV3) { + t.Helper() + requestsFound := 0 + for _, m := range clus.Members { + for _, r := range m.RecordedRequests() { + requestsFound++ + if r.Authority != expectedAuthority { + t.Errorf("Got unexpected authority header, member: %q, request: %q, got authority: %q, expected %q", m.Name, r.FullMethod, r.Authority, expectedAuthority) + } + } + } + if requestsFound == 0 { + t.Errorf("Expected at least one request") + } +}