Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checked HTTP client #369

Merged
merged 14 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions checked_http_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package main

import (
"fmt"
"net/http"
"net/url"
)

// checkedHttpClient is an httpClient decorator: its Get method forwards the
// request to the wrapped client and then validates the response status code
// against acceptedStatusCodes.
type checkedHttpClient struct {
// client is the wrapped client that actually performs the request.
client httpClient
// acceptedStatusCodes is queried with Contains to decide whether a
// response's status code is acceptable.
acceptedStatusCodes statusCodeSet
}

// newCheckedHttpClient wraps c in a checkedHttpClient that uses
// acceptedStatusCodes to validate response status codes. The result is
// returned as the httpClient interface so it can be stacked with other
// client decorators.
func newCheckedHttpClient(c httpClient, acceptedStatusCodes statusCodeSet) httpClient {
	checked := &checkedHttpClient{
		client:              c,
		acceptedStatusCodes: acceptedStatusCodes,
	}

	return checked
}

// Get sends the request through the wrapped client and validates the status
// code of the response.
//
// An error from the wrapped client is returned unchanged. If the response's
// status code is not contained in the accepted set, the response is dropped
// and an error whose message is the bare status code (e.g. "404") is
// returned instead. Otherwise the response is passed through as-is.
func (c *checkedHttpClient) Get(u *url.URL, header http.Header) (httpResponse, error) {
	res, err := c.client.Get(u, header)
	if err != nil {
		return nil, err
	}

	status := res.StatusCode()
	if c.acceptedStatusCodes.Contains(status) {
		return res, nil
	}

	return nil, fmt.Errorf("%v", status)
}
44 changes: 44 additions & 0 deletions checked_http_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

// TestCheckedHttpClientFailWithValidStatusCode checks that a response with
// status code 200 passes through a checked client whose accepted set is built
// from the key {200, 201}: Get must return a non-nil response and no error.
//
// NOTE(review): despite "Fail" in the name, this test asserts the success
// path; consider renaming it in a follow-up.
func TestCheckedHttpClientFailWithValidStatusCode(t *testing.T) {
	target, parseErr := url.Parse(testUrl)
	assert.Nil(t, parseErr)

	// Fake inner client that always answers with a 200 response.
	inner := newFakeHttpClient(
		func(*url.URL) (*fakeHttpResponse, error) {
			return newFakeHttpResponse(200, testUrl, nil, nil), nil
		},
	)
	accepted := statusCodeSet{{200, 201}: {}}

	response, err := newCheckedHttpClient(inner, accepted).Get(target, nil)

	assert.Nil(t, err)
	assert.NotNil(t, response)
}

// TestCheckedHttpClientFailWithInvalidStatusCode checks that a response with
// status code 404 is rejected by a checked client whose accepted set is built
// from the key {200, 201}: Get must return a nil response and an error whose
// message is the status code.
func TestCheckedHttpClientFailWithInvalidStatusCode(t *testing.T) {
	u, err := url.Parse(testUrl)

	assert.Nil(t, err)

	r, err := newCheckedHttpClient(
		newFakeHttpClient(
			// Fake inner client that always answers with a 404 response.
			func(u *url.URL) (*fakeHttpResponse, error) {
				return newFakeHttpResponse(404, testUrl, nil, nil), nil
			},
		),
		statusCodeSet{{200, 201}: {}},
	).Get(u, nil)

	assert.Nil(t, r)
	// EqualError reports a normal assertion failure when err is nil, whereas
	// calling err.Error() directly would panic with a nil dereference. It
	// also keeps the expected/actual roles straight in failure output.
	assert.EqualError(t, err, "404")
}
34 changes: 18 additions & 16 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,26 @@ func (c *command) runWithError(ss []string) (bool, error) {
return true, nil
}

client := newRedirectHttpClient(
newThrottledHttpClient(
c.httpClientFactory.Create(
httpClientOptions{
MaxConnectionsPerHost: args.MaxConnectionsPerHost,
MaxResponseBodySize: args.MaxResponseBodySize,
BufferSize: args.BufferSize,
Proxy: args.Proxy,
SkipTLSVerification: args.SkipTLSVerification,
Timeout: time.Duration(args.Timeout) * time.Second,
Header: args.Header,
},
client := newCheckedHttpClient(
newRedirectHttpClient(
newThrottledHttpClient(
c.httpClientFactory.Create(
httpClientOptions{
MaxConnectionsPerHost: args.MaxConnectionsPerHost,
MaxResponseBodySize: args.MaxResponseBodySize,
BufferSize: args.BufferSize,
Proxy: args.Proxy,
SkipTLSVerification: args.SkipTLSVerification,
Timeout: time.Duration(args.Timeout) * time.Second,
Header: args.Header,
},
),
args.RateLimit,
args.MaxConnections,
args.MaxConnectionsPerHost,
),
args.RateLimit,
args.MaxConnections,
args.MaxConnectionsPerHost,
args.MaxRedirections,
),
args.MaxRedirections,
args.AcceptedStatusCodes,
)

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/raviqqe/muffet/v2

go 1.19
go 1.22.0

require (
github.com/andybalholm/brotli v1.1.0
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHY
github.com/yhat/scrape v0.0.0-20161128144610-24b7890b0945 h1:6Ju8pZBYFTN9FaV/JvNBiIHcsgEmP4z4laciqjfjY8E=
github.com/yhat/scrape v0.0.0-20161128144610-24b7890b0945/go.mod h1:4vRFPPNYllgCacoj+0FoKOjTW68rUhEfqPLiEJaK2w8=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/ratelimit v0.3.0 h1:IdZd9wqvFXnvLvSEBo0KPcGfkoBGNkpTHlrE3Rcjkjw=
go.uber.org/ratelimit v0.3.0/go.mod h1:So5LG7CV1zWpY1sHe+DXTJqQvOx+FFPFaAs2SnoyBaI=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
Expand Down
61 changes: 20 additions & 41 deletions redirect_http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import (
)

type redirectHttpClient struct {
client httpClient
maxRedirections int
acceptedStatusCodes statusCodeSet
client httpClient
maxRedirections int
}

func newRedirectHttpClient(c httpClient, maxRedirections int, acceptedStatusCodes statusCodeSet) httpClient {
return &redirectHttpClient{c, maxRedirections, acceptedStatusCodes}
func newRedirectHttpClient(c httpClient, maxRedirections int) httpClient {
return &redirectHttpClient{c, maxRedirections}
}

func (c *redirectHttpClient) Get(u *url.URL, header http.Header) (httpResponse, error) {
Expand All @@ -24,59 +23,39 @@ func (c *redirectHttpClient) Get(u *url.URL, header http.Header) (httpResponse,
}

cj, err := cookiejar.New(nil)

if err != nil {
return nil, err
}

i := 0

for {
for i := range c.maxRedirections + 1 {
for _, c := range cj.Cookies(u) {
header.Add("cookie", c.String())
}

r, err := c.client.Get(u, header)
if err != nil {
return nil, c.formatError(err, i, u)
}

code := r.StatusCode()

if c.acceptedStatusCodes.Contains(code) {
if err != nil && i == 0 {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("%w (following redirect %v)", err, u.String())
} else if c := r.StatusCode(); c < 300 || c >= 400 {
return r, nil
} else if code >= 300 && code <= 399 {
i++

if i > c.maxRedirections {
return nil, errors.New("too many redirections")
}

s := r.Header("Location")

if len(s) == 0 {
return nil, errors.New("location header not set")
}
}

u, err = u.Parse(s)
s := r.Header("Location")

if err != nil {
return nil, err
}
if len(s) == 0 {
return nil, errors.New("location header not set")
}

cj.SetCookies(u, parseCookies(r.Header("set-cookie")))
} else {
return nil, c.formatError(fmt.Errorf("%v", code), i, u)
u, err = u.Parse(s)
if err != nil {
return nil, err
}
}
}

func (*redirectHttpClient) formatError(err error, redirections int, u *url.URL) error {
if redirections == 0 {
return err
cj.SetCookies(u, parseCookies(r.Header("set-cookie")))
}

return fmt.Errorf("%w (following redirect %v)", err, u.String())
return nil, errors.New("too many redirections")
}

func parseCookies(s string) []*http.Cookie {
Expand Down
Loading
Loading