Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eshitachandwani committed Nov 26, 2024
1 parent 445664c commit 2e13c40
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 97 deletions.
11 changes: 11 additions & 0 deletions internal/resolver/delegatingresolver/delegatingresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,16 @@ func mapAddress(address string) (*url.URL, error) {
return url, nil
}

// OnClientResolution is a no-op function in non-test code. In tests, it can
// be overwritten to send a signal to a channel, indicating that client-side
// name resolution was triggered. This enables tests to verify that resolution
// is bypassed when a proxy is in use.
var OnClientResolution = func(int) { /* no-op */ }

var OnDelegatingResolverCalled = func(int) { /* no-op */ }

func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions, targetResolverBuilder resolver.Builder, targetResolutionEnabled bool) (resolver.Resolver, error) {
OnDelegatingResolverCalled(1)
r := &delegatingResolver{
target: target,
cc: cc,
Expand All @@ -84,6 +93,7 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
if r.proxyURL == nil {
// Add an info log here.
logger.Info("No proxy URL detected")
OnClientResolution(1)
return targetResolverBuilder.Build(target, cc, opts)
}

Expand All @@ -95,6 +105,7 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti
} else {
// targetResolverBuilder must not be nil. If it is nil, the channel should
// error out and not call this function.
OnClientResolution(1)
r.targetResolver, err = targetResolverBuilder.Build(target, &wrappingClientConn{r, "target"}, opts)
if err != nil {
return nil, fmt.Errorf("delegating_resolver: unable to build the target resolver : %v", err)
Expand Down
114 changes: 63 additions & 51 deletions internal/testutils/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,56 +38,8 @@ type ProxyServer struct {
requestCheck func(*http.Request) error
}

func (p *ProxyServer) run(errCh chan error, backendAddr string) {
fmt.Printf("run proxy server")
in, err := p.lis.Accept()
if err != nil {
return
}
p.in = in
func (p *ProxyServer) run(errCh chan error, doneCh chan struct{}, backendAddr string) {

req, err := http.ReadRequest(bufio.NewReader(in))
if err != nil {
errCh <- fmt.Errorf("failed to read CONNECT req: %v", err)
// p.t.Errorf("failed to read CONNECT req: %v", err)
return
}
fmt.Printf("request sent : %+v\n", req)
if err := p.requestCheck(req); err != nil {
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
resp.Write(p.in)
p.in.Close()
errCh <- fmt.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
// p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
return
}
var out net.Conn
//To mimick name resolution on proxy, if proxy gets unresolved address in
// the CONNECT request, dial to the backend address directly.
if net.ParseIP(req.URL.Host) == nil {
fmt.Printf("dialing to backend add: %v from proxy", backendAddr)
out, err = net.Dial("tcp", backendAddr)
} else {
//If proxy gets resolved address in CONNECT req,dial to the actual server
// with address sent in the connect request.
fmt.Printf("\ndialing with host in connect req\n")
out, err = net.Dial("tcp", req.URL.Host)
}
if err != nil {
errCh <- fmt.Errorf("failed to dial to server: %v", err)
// p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
//response OK to client
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
var buf bytes.Buffer
resp.Write(&buf)
p.in.Write(buf.Bytes())
p.out = out
// perform the proxy function, i.e pass the data from client to server and server to client.
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
}

func (p *ProxyServer) Stop() {
Expand All @@ -100,12 +52,72 @@ func (p *ProxyServer) Stop() {
}
}

func NewProxyServer(lis net.Listener, requestCheck func(*http.Request) error, errCh chan error, backendAddr string) *ProxyServer {
func NewProxyServer(lis net.Listener, requestCheck func(*http.Request) error, errCh chan error, doneCh chan struct{}, backendAddr string, resolutionOnClient bool) *ProxyServer {
fmt.Printf("starting proxy server")
p := &ProxyServer{
lis: lis,
requestCheck: requestCheck,
}
go p.run(errCh, backendAddr)
go func() {
fmt.Printf("run proxy server")
in, err := p.lis.Accept()
if err != nil {
return
}
p.in = in

req, err := http.ReadRequest(bufio.NewReader(in))
if err != nil {
errCh <- fmt.Errorf("failed to read CONNECT req: %v", err)
// p.t.Errorf("failed to read CONNECT req: %v", err)
return
}
fmt.Printf("request sent : %+v\n", req)
if err := p.requestCheck(req); err != nil {
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
resp.Write(p.in)
p.in.Close()
errCh <- fmt.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
// p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
return
}
var out net.Conn
//To mimick name resolution on proxy, if proxy gets unresolved address in
// the CONNECT request, dial to the backend address directly.
// if net.ParseIP(req.URL.Host) == nil {
// fmt.Printf("dialing to backend add: %v from proxy", backendAddr)
// out, err = net.Dial("tcp", backendAddr)
// } else {
// //If proxy gets resolved address in CONNECT req,dial to the actual server
// // with address sent in the connect request.
// fmt.Printf("\ndialing with host in connect req\n")
// out, err = net.Dial("tcp", req.URL.Host)
// }

if resolutionOnClient {
fmt.Printf("resolving host in connect req\n")
out, err = net.Dial("tcp", req.URL.Host)
} else {
fmt.Printf("not resolving host in connect req\n")
out, err = net.Dial("tcp",backendAddr)
}

if err != nil {
errCh <- fmt.Errorf("failed to dial to server: %v", err)
// p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
//response OK to client
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
var buf bytes.Buffer
resp.Write(&buf)
p.in.Write(buf.Bytes())
p.out = out
// perform the proxy function, i.e pass the data from client to server and server to client.
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
close(doneCh)
}()
return p
}
Loading

0 comments on commit 2e13c40

Please sign in to comment.