diff --git a/supernode/daemon/mgr/cdn/downloader.go b/supernode/daemon/mgr/cdn/downloader.go index cf8fde97b..076f06b37 100644 --- a/supernode/daemon/mgr/cdn/downloader.go +++ b/supernode/daemon/mgr/cdn/downloader.go @@ -26,6 +26,7 @@ import ( errorType "github.com/dragonflyoss/Dragonfly/pkg/errortypes" "github.com/dragonflyoss/Dragonfly/pkg/httputils" "github.com/dragonflyoss/Dragonfly/pkg/rangeutils" + "github.com/dragonflyoss/Dragonfly/supernode/httpclient" ) // download downloads the file from the original address and @@ -43,20 +44,29 @@ func (cm *Manager) download(ctx context.Context, taskID, url string, headers map return nil, errors.Wrapf(errorType.ErrInvalidValue, "failed to calculate the breakRange: %v", err) } - if headers == nil { - headers = make(map[string]string) - } // check if Range in header? if Range already in Header, use this range directly - if _, ok := headers["Range"]; !ok { - headers["Range"] = httputils.ConstructRangeStr(breakRange) + if !hasRange(headers) { + headers = httpclient.CopyHeader( + map[string]string{"Range": httputils.ConstructRangeStr(breakRange)}, + headers) + } checkCode = []int{http.StatusPartialContent} } - logrus.Infof("start to download for taskId(%s) with fileUrl: %s header: %v checkCode: %d", taskID, url, headers, checkCode) + logrus.Infof("start to download for taskId(%s) with fileUrl: %s"+ + " header: %v checkCode: %d", taskID, url, headers, checkCode) return cm.originClient.Download(url, headers, checkStatusCode(checkCode)) } +func hasRange(headers map[string]string) bool { + if headers == nil { + return false + } + _, ok := headers["Range"] + return ok +} + func checkStatusCode(statusCode []int) func(int) bool { return func(status int) bool { for _, s := range statusCode { diff --git a/supernode/daemon/mgr/cdn/downloader_test.go b/supernode/daemon/mgr/cdn/downloader_test.go index 36deb814d..f9101de87 100644 --- a/supernode/daemon/mgr/cdn/downloader_test.go +++ b/supernode/daemon/mgr/cdn/downloader_test.go @@ -105,7 +105,9 @@ func (s *CDNDownloadTestSuite) TestDownload(c *check.C) { } for _, v := range cases { + headers := cloneMap(v.headers) resp, err := cm.download(context.TODO(), "", ts.URL, v.headers, v.startPieceNum, v.httpFileLength, v.pieceContSize) + c.Check(headers, check.DeepEquals, v.headers) c.Check(v.errCheck(err), check.Equals, true) c.Check(resp.StatusCode, check.Equals, v.exceptedStatusCode) @@ -160,3 +162,17 @@ func Test_checkStatusCode(t *testing.T) { }) } } + +// ---------------------------------------------------------------------------- +// helper + +func cloneMap(src map[string]string) map[string]string { + if src == nil { + return nil + } + target := make(map[string]string) + for k, v := range src { + target[k] = v + } + return target +} diff --git a/supernode/httpclient/origin_http_client.go b/supernode/httpclient/origin_http_client.go index 2b3fc1600..5f6ee4a11 100644 --- a/supernode/httpclient/origin_http_client.go +++ b/supernode/httpclient/origin_http_client.go @@ -132,18 +132,16 @@ func (client *OriginClient) GetContentLength(url string, headers map[string]stri // IsSupportRange checks if the source url support partial requests. func (client *OriginClient) IsSupportRange(url string, headers map[string]string) (bool, error) { - // set headers - if headers == nil { - headers = make(map[string]string) - } - headers["Range"] = "bytes=0-0" + // set headers: headers is a reference to map, should not change it + copied := CopyHeader(nil, headers) + copied["Range"] = "bytes=0-0" // send request - resp, err := client.HTTPWithHeaders(http.MethodGet, url, headers, 4*time.Second) + resp, err := client.HTTPWithHeaders(http.MethodGet, url, copied, 4*time.Second) if err != nil { return false, err } - resp.Body.Close() + _ = resp.Body.Close() if resp.StatusCode == http.StatusPartialContent { return true, nil @@ -157,20 +155,18 @@ func (client *OriginClient) IsExpired(url string, headers map[string]string, las return true, nil } - // set headers - if headers == nil { - headers = make(map[string]string) - } + // set headers: headers is a reference to map, should not change it + copied := CopyHeader(nil, headers) if lastModified > 0 { lastModifiedStr, _ := netutils.ConvertTimeIntToString(lastModified) - headers["If-Modified-Since"] = lastModifiedStr + copied["If-Modified-Since"] = lastModifiedStr } if !stringutils.IsEmptyStr(eTag) { - headers["If-None-Match"] = eTag + copied["If-None-Match"] = eTag } // send request - resp, err := client.HTTPWithHeaders(http.MethodGet, url, headers, 4*time.Second) + resp, err := client.HTTPWithHeaders(http.MethodGet, url, copied, 4*time.Second) if err != nil { return false, err } @@ -222,3 +218,14 @@ func (client *OriginClient) HTTPWithHeaders(method, url string, headers map[stri } return httpClient.Do(req) } + +// CopyHeader copies the src to dst and return a non-nil dst map. +func CopyHeader(dst, src map[string]string) map[string]string { + if dst == nil { + dst = make(map[string]string) + } + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/supernode/httpclient/origin_http_client_test.go b/supernode/httpclient/origin_http_client_test.go index 971f04466..5b066a35f 100644 --- a/supernode/httpclient/origin_http_client_test.go +++ b/supernode/httpclient/origin_http_client_test.go @@ -98,3 +98,16 @@ func (s *OriginHTTPClientTestSuite) TestRegisterTLSConfig(c *check.C) { c.Assert(resp, check.NotNil) c.Assert(resp.ContentLength, check.Equals, int64(-1)) } + +func (s *OriginHTTPClientTestSuite) TestCopyHeader(c *check.C) { + dst := CopyHeader(nil, nil) + c.Check(dst, check.NotNil) + c.Check(len(dst), check.Equals, 0) + + src := map[string]string{"test": "1"} + dst = CopyHeader(nil, src) + c.Check(dst, check.DeepEquals, src) + + dst["test"] = "2" + c.Check(src["test"], check.Equals, "1") +}