From 2e13c4037da3a3c63c9c8e706a5bc530a578dd0d Mon Sep 17 00:00:00 2001 From: eshitachandwani Date: Tue, 26 Nov 2024 14:09:46 +0530 Subject: [PATCH] tests --- .../delegatingresolver/delegatingresolver.go | 11 + internal/testutils/proxy.go | 114 +++++----- test/proxy_test.go | 208 ++++++++++++++---- 3 files changed, 236 insertions(+), 97 deletions(-) diff --git a/internal/resolver/delegatingresolver/delegatingresolver.go b/internal/resolver/delegatingresolver/delegatingresolver.go index f947e3bb00b5..00d56d51f117 100644 --- a/internal/resolver/delegatingresolver/delegatingresolver.go +++ b/internal/resolver/delegatingresolver/delegatingresolver.go @@ -69,7 +69,16 @@ func mapAddress(address string) (*url.URL, error) { return url, nil } +// OnClientResolution is a no-op function in non-test code. In tests, it can +// be overwritten to send a signal to a channel, indicating that client-side +// name resolution was triggered. This enables tests to verify that resolution +// is bypassed when a proxy is in use. +var OnClientResolution = func(int) { /* no-op */ } + +var OnDelegatingResolverCalled = func(int) { /* no-op */ } + func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions, targetResolverBuilder resolver.Builder, targetResolutionEnabled bool) (resolver.Resolver, error) { + OnDelegatingResolverCalled(1) r := &delegatingResolver{ target: target, cc: cc, @@ -84,6 +93,7 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti if r.proxyURL == nil { // Add an info log here. logger.Info("No proxy URL detected") + OnClientResolution(1) return targetResolverBuilder.Build(target, cc, opts) } @@ -95,6 +105,7 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti } else { // targetResolverBuilder must not be nil. If it is nil, the channel should // error out and not call this function. + OnClientResolution(1) r.targetResolver, err = targetResolverBuilder.Build(target, &wrappingClientConn{r, "target"}, opts) if err != nil { return nil, fmt.Errorf("delegating_resolver: unable to build the target resolver : %v", err) diff --git a/internal/testutils/proxy.go b/internal/testutils/proxy.go index 7169ed4d8ab4..26c6b3b4968a 100644 --- a/internal/testutils/proxy.go +++ b/internal/testutils/proxy.go @@ -38,56 +38,8 @@ type ProxyServer struct { requestCheck func(*http.Request) error } -func (p *ProxyServer) run(errCh chan error, backendAddr string) { - fmt.Printf("run proxy server") - in, err := p.lis.Accept() - if err != nil { - return - } - p.in = in +func (p *ProxyServer) run(errCh chan error, doneCh chan struct{}, backendAddr string) { - req, err := http.ReadRequest(bufio.NewReader(in)) - if err != nil { - errCh <- fmt.Errorf("failed to read CONNECT req: %v", err) - // p.t.Errorf("failed to read CONNECT req: %v", err) - return - } - fmt.Printf("request sent : %+v\n", req) - if err := p.requestCheck(req); err != nil { - resp := http.Response{StatusCode: http.StatusMethodNotAllowed} - resp.Write(p.in) - p.in.Close() - errCh <- fmt.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) - // p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) - return - } - var out net.Conn - //To mimick name resolution on proxy, if proxy gets unresolved address in - // the CONNECT request, dial to the backend address directly. - if net.ParseIP(req.URL.Host) == nil { - fmt.Printf("dialing to backend add: %v from proxy", backendAddr) - out, err = net.Dial("tcp", backendAddr) - } else { - //If proxy gets resolved address in CONNECT req,dial to the actual server - // with address sent in the connect request. - fmt.Printf("\ndialing with host in connect req\n") - out, err = net.Dial("tcp", req.URL.Host) - } - if err != nil { - errCh <- fmt.Errorf("failed to dial to server: %v", err) - // p.t.Errorf("failed to dial to server: %v", err) - return - } - out.SetDeadline(time.Now().Add(defaultTestTimeout)) - //response OK to client - resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} - var buf bytes.Buffer - resp.Write(&buf) - p.in.Write(buf.Bytes()) - p.out = out - // perform the proxy function, i.e pass the data from client to server and server to client. - go io.Copy(p.in, p.out) - go io.Copy(p.out, p.in) } func (p *ProxyServer) Stop() { @@ -100,12 +52,72 @@ func (p *ProxyServer) Stop() { } } -func NewProxyServer(lis net.Listener, requestCheck func(*http.Request) error, errCh chan error, backendAddr string) *ProxyServer { +func NewProxyServer(lis net.Listener, requestCheck func(*http.Request) error, errCh chan error, doneCh chan struct{}, backendAddr string, resolutionOnClient bool) *ProxyServer { fmt.Printf("starting proxy server") p := &ProxyServer{ lis: lis, requestCheck: requestCheck, } - go p.run(errCh, backendAddr) + go func() { + fmt.Printf("run proxy server") + in, err := p.lis.Accept() + if err != nil { + return + } + p.in = in + + req, err := http.ReadRequest(bufio.NewReader(in)) + if err != nil { + errCh <- fmt.Errorf("failed to read CONNECT req: %v", err) + // p.t.Errorf("failed to read CONNECT req: %v", err) + return + } + fmt.Printf("request sent : %+v\n", req) + if err := p.requestCheck(req); err != nil { + resp := http.Response{StatusCode: http.StatusMethodNotAllowed} + resp.Write(p.in) + p.in.Close() + errCh <- fmt.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) + // p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) + return + } + var out net.Conn + //To mimick name resolution on proxy, if proxy gets unresolved address in + // the CONNECT request, dial to the backend address directly. + // if net.ParseIP(req.URL.Host) == nil { + // fmt.Printf("dialing to backend add: %v from proxy", backendAddr) + // out, err = net.Dial("tcp", backendAddr) + // } else { + // //If proxy gets resolved address in CONNECT req,dial to the actual server + // // with address sent in the connect request. + // fmt.Printf("\ndialing with host in connect req\n") + // out, err = net.Dial("tcp", req.URL.Host) + // } + + if resolutionOnClient { + fmt.Printf("resolving host in connect req\n") + out, err = net.Dial("tcp", req.URL.Host) + } else { + fmt.Printf("not resolving host in connect req\n") + out, err = net.Dial("tcp",backendAddr) + } + + if err != nil { + errCh <- fmt.Errorf("failed to dial to server: %v", err) + // p.t.Errorf("failed to dial to server: %v", err) + return + } + out.SetDeadline(time.Now().Add(defaultTestTimeout)) + //response OK to client + resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} + var buf bytes.Buffer + resp.Write(&buf) + p.in.Write(buf.Bytes()) + p.out = out + // perform the proxy function, i.e pass the data from client to server and server to client. + go io.Copy(p.in, p.out) + go io.Copy(p.out, p.in) + close(doneCh) + }() return p } diff --git a/test/proxy_test.go b/test/proxy_test.go index d5d0961ca8fd..6c502a8bc9b6 100644 --- a/test/proxy_test.go +++ b/test/proxy_test.go @@ -33,6 +33,7 @@ import ( "google.golang.org/grpc/internal/resolver/delegatingresolver" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" + testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -54,11 +55,13 @@ func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() { } } -// TestGrpcDialWithProxy tests grpc.Dial with no resolver mentioned in targetURI -// (i.e it uses passthrough) with a proxy. -func (s) TestGrpcDialWithProxy(t *testing.T) { +func overwriteAndRestoreProxyEnv(proxyAddr string) func() { + origHTTPSProxy := envconfig.HTTPSProxy + envconfig.HTTPSProxy = proxyAddr + return func() { envconfig.HTTPSProxy = origHTTPSProxy } +} - //Create and start a backend server +func createAndStartBackendServer(t *testing.T) string { backend := &stubserver.StubServer{ EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, } @@ -67,28 +70,44 @@ func (s) TestGrpcDialWithProxy(t *testing.T) { } t.Logf("Started TestService backend at: %q", backend.Address) t.Cleanup(func() { backend.Stop() }) - backendAddr := backend.Address + return backend.Address +} +// TestGrpcDialWithProxy tests grpc.Dial with no resolver mentioned in targetURI +// (i.e it uses passthrough) with a proxy. +func (s) TestGrpcDialWithProxy(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resolutionCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution // Access using package name + delegatingresolver.OnClientResolution = func(int) { // Access using package name + resolutionCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) // Restore using package name + //Create and start a backend server + backendAddr := createAndStartBackendServer(t) proxyLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } errCh := make(chan error, 1) + doneCh := make(chan struct{}) defer func() { close(errCh) }() reqCheck := func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } - //emchandwani: TODO :Change to the correct string if req.URL.Host != unresProxyURI { return fmt.Errorf("unexpected URL.Host %q, want %q", req.URL.Host, unresProxyURI) } return nil } fmt.Printf("backendaddr sent: %v\n", backendAddr) - p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, backendAddr) + p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, doneCh, backendAddr, false) t.Cleanup(func() { p.Stop() }) + //set proxy env + defer overwriteAndRestoreProxyEnv("proxyexample.com")() // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { if req.URL.Host == unresProxyURI { @@ -100,10 +119,6 @@ func (s) TestGrpcDialWithProxy(t *testing.T) { return nil, nil } defer overwrite(hpfe)() - // set proxy env variable - origHTTPSProxy := envconfig.HTTPSProxy - envconfig.HTTPSProxy = "proxyexample.com" - defer func() { envconfig.HTTPSProxy = origHTTPSProxy }() // Dial to proxy server. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -122,10 +137,15 @@ func (s) TestGrpcDialWithProxy(t *testing.T) { select { case err := <-errCh: t.Fatalf("proxy server encountered an error: %v", err) - default: + case <-doneCh: t.Logf("proxy server succeeded") } - + select { + case <-resolutionCh: + // Success: OnClientResolution was called. + default: + t.Error("Client side resolution should be called but isn't") + } } // setupDNS unregisters the DNS resolver and registers a manual resolver for the @@ -145,21 +165,14 @@ func setupDNS(t *testing.T) *manual.Resolver { func (s) TestGrpcDialWithProxyandResolution(t *testing.T) { //Create and start a backend server - backend := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, - } - if err := backend.StartServer(); err != nil { - t.Fatalf("Failed to start backend: %v", err) - } - t.Logf("Started TestService backend at: %q", backend.Address) - t.Cleanup(func() { backend.Stop() }) - backendAddr := backend.Address + backendAddr := createAndStartBackendServer(t) proxyLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } errCh := make(chan error, 1) + doneCh := make(chan struct{}) defer func() { close(errCh) }() reqCheck := func(req *http.Request) error { if req.Method != http.MethodConnect { @@ -174,26 +187,25 @@ func (s) TestGrpcDialWithProxyandResolution(t *testing.T) { return nil } fmt.Printf("backendaddr sent: %v\n", backendAddr) - p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, backendAddr) + p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, doneCh, backendAddr, false) t.Cleanup(func() { p.Stop() }) + // set proxy env variable + defer overwriteAndRestoreProxyEnv("proxyexample.com")() + // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { if req.URL.Host == unresProxyURI { return &url.URL{ Scheme: "https", - Host: "proxyURL.com", //emchandwani: return unresolved string here, and add manual resolver to check the correct resolution + Host: "proxyURL.com", //returning unresolved string here, and add manual resolver to check the correct resolution }, nil } return nil, nil } defer overwrite(hpfe)() - // mrTarget := setupDNS(t) + mrProxy := setupDNS(t) - // set proxy env variable - origHTTPSProxy := envconfig.HTTPSProxy - envconfig.HTTPSProxy = "proxyexample.com" - defer func() { envconfig.HTTPSProxy = origHTTPSProxy }() // Dial to proxy server. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -215,7 +227,7 @@ func (s) TestGrpcDialWithProxyandResolution(t *testing.T) { // Send an RPC to the backend through the proxy. client := testpb.NewTestServiceClient(conn) if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("EmptyCall failed: %v", err) + t.Errorf("EmptyCall failed: %v", err) } // Check if the proxy encountered any errors. select { @@ -229,22 +241,23 @@ func (s) TestGrpcDialWithProxyandResolution(t *testing.T) { // TestGrpcNewClientWithProxy tests grpc.NewClient with no resolver mentioned in targetURI with a proxy func (s) TestGrpcNewClientWithProxy(t *testing.T) { - //Create and start a backend server - backend := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, - } - if err := backend.StartServer(); err != nil { - t.Fatalf("Failed to start backend: %v", err) - } - t.Logf("Started TestService backend at: %q", backend.Address) - t.Cleanup(func() { backend.Stop() }) - backendAddr := backend.Address + // Set up a channel to receive signals from OnClientResolution. + resolutionCh := make(chan bool, 1) + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution // Access using package name + delegatingresolver.OnClientResolution = func(int) { // Access using package name + resolutionCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) // Restore using package name + //Create and start a backend server + backendAddr := createAndStartBackendServer(t) proxyLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } errCh := make(chan error, 1) + doneCh := make(chan struct{}) defer func() { close(errCh) }() reqCheck := func(req *http.Request) error { if req.Method != http.MethodConnect { @@ -256,9 +269,12 @@ func (s) TestGrpcNewClientWithProxy(t *testing.T) { return nil } fmt.Printf("backendaddr sent: %v\n", backendAddr) - p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, backendAddr) + p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, doneCh, backendAddr, false) t.Cleanup(func() { p.Stop() }) + // set proxy env variable + defer overwriteAndRestoreProxyEnv("proxyexample.com")() + // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { if req.URL.Host == unresProxyURI { @@ -270,10 +286,7 @@ func (s) TestGrpcNewClientWithProxy(t *testing.T) { return nil, nil } defer overwrite(hpfe)() - // set proxy env variable - origHTTPSProxy := envconfig.HTTPSProxy - envconfig.HTTPSProxy = "proxyexample.com" - defer func() { envconfig.HTTPSProxy = origHTTPSProxy }() + // Dial to proxy server. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() @@ -286,7 +299,103 @@ func (s) TestGrpcNewClientWithProxy(t *testing.T) { // Send an RPC to the backend through the proxy. client := testpb.NewTestServiceClient(conn) if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("EmptyCall failed: %v", err) + t.Errorf("EmptyCall failed: %v", err) //t.Errorf + } + // Check if the proxy encountered any errors. + //emchandwani : put in defer block where channel is created. + select { + case err := <-errCh: + t.Fatalf("proxy server encountered an error: %v", err) + default: + t.Logf("proxy server succeeded") + } + select { + case <-resolutionCh: + // Success: OnClientResolution was called. + t.Error("Client side resolution called") + default: + } + +} + +// TestGrpcNewClientWithProxyAndCustomResolver tests grpc.NewClient with +// resolver other than "dns" mentioned in targetURI with a proxy. +func (s) TestGrpcNewClientWithProxyAndCustomResolver(t *testing.T) { + // Set up a channel to receive signals from OnClientResolution. + resolutionCh := make(chan bool, 1) + + // Overwrite OnClientResolution to send a signal to the channel. + origOnClientResolution := delegatingresolver.OnClientResolution // Access using package name + delegatingresolver.OnClientResolution = func(int) { // Access using package name + resolutionCh <- true + } + t.Cleanup(func() { delegatingresolver.OnClientResolution = origOnClientResolution }) // Restore using package name + //Create and start a backend server + r := manual.NewBuilderWithScheme("whatever") + resolver.Register(r) + backendAddr := createAndStartBackendServer(t) + proxyLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + defer func() { close(errCh) }() + reqCheck := func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + if req.URL.Host != backendAddr { + return fmt.Errorf("unexpected URL.Host %q, want %q", req.URL.Host, backendAddr) + } + return nil + } + fmt.Printf("backendaddr sent: %v\n", backendAddr) + p := testutils.NewProxyServer(proxyLis, reqCheck, errCh, doneCh, backendAddr, true) + t.Cleanup(func() { p.Stop() }) + + // set proxy env variable + defer overwriteAndRestoreProxyEnv("proxyexample.com")() + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + fmt.Printf("req.URL.Host: %v\n", req.URL.Host) + if req.URL.Host == unresProxyURI { + return &url.URL{ + Scheme: "https", + Host: proxyLis.Addr().String(), + }, nil + } + return nil, nil + } + defer overwrite(hpfe)() + r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + // mrTarget := manual.NewBuilderWithScheme("whatever") + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + // grpc.WithResolvers(r), + } + // Dial to proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // At this point, the resolver has not returned any addresses to the channel. + // This RPC must block until the context expires. + + fmt.Printf("env config: %v\n", os.Getenv("HTTPS_PROXY")) + conn, err := grpc.NewClient(r.Scheme()+":///"+unresProxyURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient() failed: %v", err) + } + + t.Cleanup(func() { conn.Close() }) + + client := testgrpc.NewTestServiceClient(conn) + // sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + // defer sCancel() + // r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: backendAddr}}}) + + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Errorf("EmptyCall() failed: %v", err) } // Check if the proxy encountered any errors. select { @@ -295,4 +404,11 @@ func (s) TestGrpcNewClientWithProxy(t *testing.T) { default: t.Logf("proxy server succeeded") } + // Check if client-side resolution signal was sent to the channel. + select { + case <-resolutionCh: + // Success: OnClientResolution was called. + default: + t.Error("Client side resolution should be called but isn't") + } }