diff --git a/supernode/daemon/mgr/cdn/downloader.go b/supernode/daemon/mgr/cdn/downloader.go index 7da96b55f..d2fa15b58 100644 --- a/supernode/daemon/mgr/cdn/downloader.go +++ b/supernode/daemon/mgr/cdn/downloader.go @@ -35,7 +35,7 @@ import ( // Body which the caller is expected to close. func (cm *Manager) download(ctx context.Context, taskID, url string, headers map[string]string, startPieceNum int, httpFileLength int64, pieceContSize int32) (*http.Response, error) { - var checkCode = http.StatusOK | http.StatusPartialContent + checkCode := []int{http.StatusOK, http.StatusPartialContent} if startPieceNum > 0 { breakRange, err := util.CalculateBreakRange(startPieceNum, int(pieceContSize), httpFileLength) @@ -50,9 +50,20 @@ func (cm *Manager) download(ctx context.Context, taskID, url string, headers map if _, ok := headers["Range"]; !ok { headers["Range"] = httputils.ConstructRangeStr(breakRange) } - checkCode = http.StatusPartialContent + checkCode = []int{http.StatusPartialContent} } logrus.Infof("start to download for taskId(%s) with fileUrl: %s header: %v checkCode: %d", taskID, url, headers, checkCode) - return cm.originClient.Download(url, headers, checkCode) + return cm.originClient.Download(url, headers, checkStatusCode(checkCode)) +} + +func checkStatusCode(statusCode []int) func(int) bool { + return func(status int) bool { + for _, s := range statusCode { + if status == s { + return true + } + } + return false + } } diff --git a/supernode/daemon/mgr/cdn/downloader_test.go b/supernode/daemon/mgr/cdn/downloader_test.go index 7ba4c0e74..61ad3308a 100644 --- a/supernode/daemon/mgr/cdn/downloader_test.go +++ b/supernode/daemon/mgr/cdn/downloader_test.go @@ -22,6 +22,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "reflect" "testing" "github.com/dragonflyoss/Dragonfly/pkg/errortypes" @@ -114,3 +115,48 @@ func (s *CDNDownloadTestSuite) TestDownload(c *check.C) { c.Check(string(result), check.Equals, string(v.exceptedBody)) } } + +func Test_checkStatusCode(t *testing.T) { + type args struct { + statusCode []int + targetStatusCode int + } + tests := []struct { + name string + args args + statusCode int + want bool + }{ + { + name: "200", + args: args{ + statusCode: []int{http.StatusOK}, + targetStatusCode: 200, + }, + want: true, + }, + { + name: "200|206", + args: args{ + statusCode: []int{http.StatusOK, http.StatusPartialContent}, + targetStatusCode: 206, + }, + want: true, + }, + { + name: "204", + args: args{ + statusCode: []int{http.StatusOK, http.StatusPartialContent}, + targetStatusCode: 204, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := checkStatusCode(tt.args.statusCode)(tt.args.targetStatusCode); !reflect.DeepEqual(got, tt.want) { + t.Errorf("checkStatusCode() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/supernode/httpclient/mock/mock_origin_http_client.go b/supernode/httpclient/mock/mock_origin_http_client.go index 8000928b1..87168be56 100644 --- a/supernode/httpclient/mock/mock_origin_http_client.go +++ b/supernode/httpclient/mock/mock_origin_http_client.go @@ -10,6 +10,8 @@ import ( strfmt "github.com/go-openapi/strfmt" gomock "github.com/golang/mock/gomock" + + "github.com/dragonflyoss/Dragonfly/supernode/httpclient" ) // MockOriginHTTPClient is a mock of OriginHTTPClient interface @@ -94,7 +96,7 @@ func (mr *MockOriginHTTPClientMockRecorder) IsExpired(url, headers, lastModified } // Download mocks base method -func (m *MockOriginHTTPClient) Download(url string, headers map[string]string, checkCode int) (*http.Response, error) { +func (m *MockOriginHTTPClient) Download(url string, headers map[string]string, checkCode httpclient.StatusCodeChecker) (*http.Response, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Download", url, headers, checkCode) ret0, _ := ret[0].(*http.Response) diff --git a/supernode/httpclient/origin_http_client.go b/supernode/httpclient/origin_http_client.go index 9a957bc0a..8956873b3 100644 --- a/supernode/httpclient/origin_http_client.go +++ b/supernode/httpclient/origin_http_client.go @@ -35,13 +35,15 @@ import ( "github.com/pkg/errors" ) +type StatusCodeChecker func(int) bool + // OriginHTTPClient supply apis that interact with the source. type OriginHTTPClient interface { RegisterTLSConfig(rawURL string, insecure bool, caBlock []strfmt.Base64) GetContentLength(url string, headers map[string]string) (int64, int, error) IsSupportRange(url string, headers map[string]string) (bool, error) IsExpired(url string, headers map[string]string, lastModified int64, eTag string) (bool, error) - Download(url string, headers map[string]string, checkCode int) (*http.Response, error) + Download(url string, headers map[string]string, checkCode StatusCodeChecker) (*http.Response, error) } // OriginClient is an implementation of the interface of OriginHTTPClient. @@ -156,14 +158,14 @@ func (client *OriginClient) IsExpired(url string, headers map[string]string, las } // Download downloads the file from the original address -func (client *OriginClient) Download(url string, headers map[string]string, checkCode int) (*http.Response, error) { +func (client *OriginClient) Download(url string, headers map[string]string, checkCode StatusCodeChecker) (*http.Response, error) { // TODO: add timeout resp, err := client.HTTPWithHeaders("GET", url, headers, 0) if err != nil { return nil, err } - if (resp.StatusCode & checkCode) == resp.StatusCode { + if checkCode(resp.StatusCode) { return resp, nil } return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)