Skip to content

Commit

Permalink
Merge pull request #160 from mackerelio/ctx-1
Browse files Browse the repository at this point in the history
add context variants of requestInternal and related functions
  • Loading branch information
lufia authored Jul 26, 2024
2 parents 563c3b6 + eb55e14 commit e95beaf
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 41 deletions.
50 changes: 9 additions & 41 deletions mackerel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mackerel

import (
"bytes"
"context"
"encoding/json"
"io"
"log"
Expand Down Expand Up @@ -131,68 +132,35 @@ func (c *Client) Request(req *http.Request) (resp *http.Response, err error) {
}

func requestGet[T any](client *Client, path string) (*T, error) {
return requestNoBody[T](client, http.MethodGet, path, nil)
return requestGetContext[T](context.Background(), client, path)
}

func requestGetWithParams[T any](client *Client, path string, params url.Values) (*T, error) {
return requestNoBody[T](client, http.MethodGet, path, params)
return requestGetWithParamsContext[T](context.Background(), client, path, params)
}

func requestGetAndReturnHeader[T any](client *Client, path string) (*T, http.Header, error) {
return requestInternal[T](client, http.MethodGet, path, nil, nil)
return requestGetAndReturnHeaderContext[T](context.Background(), client, path)
}

func requestPost[T any](client *Client, path string, payload any) (*T, error) {
return requestJSON[T](client, http.MethodPost, path, payload)
return requestPostContext[T](context.Background(), client, path, payload)
}

func requestPut[T any](client *Client, path string, payload any) (*T, error) {
return requestJSON[T](client, http.MethodPut, path, payload)
return requestPutContext[T](context.Background(), client, path, payload)
}

func requestDelete[T any](client *Client, path string) (*T, error) {
return requestNoBody[T](client, http.MethodDelete, path, nil)
return requestDeleteContext[T](context.Background(), client, path)
}

func requestJSON[T any](client *Client, method, path string, payload any) (*T, error) {
var body bytes.Buffer
err := json.NewEncoder(&body).Encode(payload)
if err != nil {
return nil, err
}
data, _, err := requestInternal[T](client, method, path, nil, &body)
return data, err
}

func requestNoBody[T any](client *Client, method, path string, params url.Values) (*T, error) {
data, _, err := requestInternal[T](client, method, path, params, nil)
return data, err
return requestJSONContext[T](context.Background(), client, method, path, payload)
}

func requestInternal[T any](client *Client, method, path string, params url.Values, body io.Reader) (*T, http.Header, error) {
req, err := http.NewRequest(method, client.urlFor(path, params).String(), body)
if err != nil {
return nil, nil, err
}
if body != nil || method != http.MethodGet {
req.Header.Add("Content-Type", "application/json")
}

resp, err := client.Request(req)
if err != nil {
return nil, nil, err
}
defer func() {
io.Copy(io.Discard, resp.Body) // nolint
resp.Body.Close()
}()

var data T
err = json.NewDecoder(resp.Body).Decode(&data)
if err != nil {
return nil, nil, err
}
return &data, resp.Header, nil
return requestInternalContext[T](context.Background(), client, method, path, params, body)
}

func (c *Client) compatRequestJSON(method string, path string, payload interface{}) (*http.Response, error) {
Expand Down
75 changes: 75 additions & 0 deletions mackerel_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package mackerel

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
)

func requestGetContext[T any](ctx context.Context, client *Client, path string) (*T, error) {
return requestNoBodyContext[T](ctx, client, http.MethodGet, path, nil)
}

func requestGetWithParamsContext[T any](ctx context.Context, client *Client, path string, params url.Values) (*T, error) {
return requestNoBodyContext[T](ctx, client, http.MethodGet, path, params)
}

func requestGetAndReturnHeaderContext[T any](ctx context.Context, client *Client, path string) (*T, http.Header, error) {
return requestInternalContext[T](ctx, client, http.MethodGet, path, nil, nil)
}

func requestPostContext[T any](ctx context.Context, client *Client, path string, payload any) (*T, error) {
return requestJSONContext[T](ctx, client, http.MethodPost, path, payload)
}

func requestPutContext[T any](ctx context.Context, client *Client, path string, payload any) (*T, error) {
return requestJSONContext[T](ctx, client, http.MethodPut, path, payload)
}

func requestDeleteContext[T any](ctx context.Context, client *Client, path string) (*T, error) {
return requestNoBodyContext[T](ctx, client, http.MethodDelete, path, nil)
}

func requestJSONContext[T any](ctx context.Context, client *Client, method, path string, payload any) (*T, error) {
var body bytes.Buffer
err := json.NewEncoder(&body).Encode(payload)
if err != nil {
return nil, err
}
data, _, err := requestInternalContext[T](ctx, client, method, path, nil, &body)
return data, err
}

func requestNoBodyContext[T any](ctx context.Context, client *Client, method, path string, params url.Values) (*T, error) {
data, _, err := requestInternalContext[T](context.Background(), client, method, path, params, nil)
return data, err
}

func requestInternalContext[T any](ctx context.Context, client *Client, method, path string, params url.Values, body io.Reader) (*T, http.Header, error) {
req, err := http.NewRequestWithContext(ctx, method, client.urlFor(path, params).String(), body)
if err != nil {
return nil, nil, err
}
if body != nil || method != http.MethodGet {
req.Header.Add("Content-Type", "application/json")
}

resp, err := client.Request(req)
if err != nil {
return nil, nil, err
}
defer func() {
io.Copy(io.Discard, resp.Body) // nolint
resp.Body.Close()
}()

var data T
err = json.NewDecoder(resp.Body).Decode(&data)
if err != nil {
return nil, nil, err
}
return &data, resp.Header, nil
}
21 changes: 21 additions & 0 deletions mackerel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mackerel

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
Expand All @@ -11,6 +13,7 @@ import (
"os"
"strings"
"testing"
"time"
)

func TestRequest(t *testing.T) {
Expand Down Expand Up @@ -77,6 +80,24 @@ func Test_requestInternal(t *testing.T) {
}
}

func Test_requestInternalContext_cancel(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
time.Sleep(1 * time.Second)
fmt.Fprint(res, "ok")
}))
t.Cleanup(ts.Close)

client, _ := NewClientWithOptions("dummy-key", ts.URL, false)
ctx, cancel := context.WithCancelCause(context.Background())
expectedErr := errors.New("expected error")
cancel(expectedErr)

_, _, err := requestInternalContext[struct{}](ctx, client, "GET", "/", nil, nil)
if cause := context.Cause(ctx); err == nil || !errors.Is(cause, expectedErr) {
t.Errorf("got %v; want %v", cause, expectedErr)
}
}

func TestUrlFor(t *testing.T) {
client, _ := NewClientWithOptions("dummy-key", "https://example.com/with/ignored/path", false)
expected := "https://example.com/some/super/endpoint"
Expand Down

0 comments on commit e95beaf

Please sign in to comment.