From b341b96fb646e3439a582bd84d666f2727ad74c8 Mon Sep 17 00:00:00 2001 From: Kevin McDermott Date: Thu, 28 Mar 2024 10:05:59 +0000 Subject: [PATCH] Handle transport errors in retryablehttp. Currently, if there's an error trying to connect to the upstream API, the actual error message is dropped, and it falls back to the underlying Go client which warns that the Transport returned no error and no response. Co-authored-by: Tom Bamford --- msgraph/client.go | 4 +++ msgraph/client_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 msgraph/client_test.go diff --git a/msgraph/client.go b/msgraph/client.go index 984ab655..b9119c21 100644 --- a/msgraph/client.go +++ b/msgraph/client.go @@ -58,6 +58,10 @@ type Uri struct { // RetryableErrorHandler ensures that the response is returned after exhausting retries for a request // We can't return an error here, or net/http will not return the response func RetryableErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) { + if resp == nil { + return nil, err + } + return resp, nil } diff --git a/msgraph/client_test.go b/msgraph/client_test.go new file mode 100644 index 00000000..e0fed930 --- /dev/null +++ b/msgraph/client_test.go @@ -0,0 +1,72 @@ +package msgraph + +import ( + "context" + "fmt" + "log" + "net" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" +) + +func TestClient_GetWithError(t *testing.T) { + // This creates a listener on a random available port. + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + port := l.Addr().(*net.TCPAddr).Port + l.Close() + + hc := NewClient(VersionBeta) + hc.Endpoint = fmt.Sprintf("https://localhost:%d/", port) + hc.RetryableClient.RetryMax = 2 + + _, _, _, err = hc.Get(context.Background(), GetHttpRequestInput{ + ConsistencyFailureFunc: RetryOn404ConsistencyFailureFunc, + ValidStatusCodes: []int{http.StatusOK}, + Uri: Uri{ + Entity: "/users/test", + }, + }) + if err == nil { + t.Error("expected to get an error, got nil") + } + if msg := err.Error(); !strings.Contains(msg, "connect: connection refused") { + log.Fatalf("got %s, want message with 'connection refused'", msg) + } +} + +func TestClient_GetWithResponseAndError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/beta/users/test" { + n, _ := strconv.Atoi(r.FormValue("n")) + if n < 15 { + http.Redirect(w, r, fmt.Sprintf("%s?n=%d", r.URL.Path, 1), http.StatusTemporaryRedirect) + return + } + } + })) + defer ts.Close() + + hc := NewClient(VersionBeta) + hc.Endpoint = ts.URL + hc.RetryableClient.RetryMax = 2 + + _, _, _, err := hc.Get(context.Background(), GetHttpRequestInput{ + ConsistencyFailureFunc: RetryOn404ConsistencyFailureFunc, + ValidStatusCodes: []int{http.StatusOK}, + Uri: Uri{ + Entity: "/users/test", + }, + }) + if err == nil { + t.Error("expected to get an error, got nil") + } + if msg := err.Error(); !strings.Contains(msg, "stopped after 10 redirects") { + log.Fatalf("got %s, want message with 'stopped after 10 redirects'", msg) + } +}