From 3f21120878cbf744fc55de774d026a7ef299f328 Mon Sep 17 00:00:00 2001 From: rabadin Date: Sat, 7 Dec 2024 10:58:52 +0100 Subject: [PATCH] Adjust code and test - Refactor resource.go code to limit code duplication - Refactor test code to use the same format as the existing tests - Move test code to resource_test.go (since it's testing resource.go) - Improve test coverage Drive-by: fix typo in test file name. --- cmd/{dowload_test.go => download_test.go} | 72 ++------------------ internal/resource.go | 82 +++++++++-------------- internal/resource_test.go | 34 ++++++++++ test/utils.go | 9 +++ 4 files changed, 77 insertions(+), 120 deletions(-) rename cmd/{dowload_test.go => download_test.go} (64%) diff --git a/cmd/dowload_test.go b/cmd/download_test.go similarity index 64% rename from cmd/dowload_test.go rename to cmd/download_test.go index 6540ff1..66d6094 100644 --- a/cmd/dowload_test.go +++ b/cmd/download_test.go @@ -1,30 +1,16 @@ package cmd import ( - "crypto/sha256" - "encoding/base64" "fmt" - "github.com/cisco-open/grabit/internal" - "net/http" - "net/http/httptest" - "os" - "path/filepath" "testing" "github.com/cisco-open/grabit/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func getSha256Integrity(content string) string { - hasher := sha256.New() - hasher.Write([]byte(content)) - return fmt.Sprintf("sha256-%s", base64.StdEncoding.EncodeToString(hasher.Sum(nil))) -} - func TestRunDownload(t *testing.T) { content := `abcdef` - contentIntegrity := getSha256Integrity(content) + contentIntegrity := test.GetSha256Integrity(content) port := test.TestHttpHandler(content, t) testfilepath := test.TmpFile(t, fmt.Sprintf(` [[Resource]] @@ -47,7 +33,7 @@ func TestRunDownload(t *testing.T) { func TestRunDownloadWithTags(t *testing.T) { content := `abcdef` - contentIntegrity := getSha256Integrity(content) + contentIntegrity := test.GetSha256Integrity(content) port := test.TestHttpHandler(content, t) testfilepath := test.TmpFile(t, fmt.Sprintf(` [[Resource]] @@ -72,7 +58,7 @@ func TestRunDownloadWithTags(t *testing.T) { func TestRunDownloadWithoutTags(t *testing.T) { content := `abcdef` - contentIntegrity := getSha256Integrity(content) + contentIntegrity := test.GetSha256Integrity(content) port := test.TestHttpHandler(content, t) testfilepath := test.TmpFile(t, fmt.Sprintf(` [[Resource]] @@ -130,59 +116,9 @@ func TestRunDownloadFailsIntegrityTest(t *testing.T) { assert.Contains(t, err.Error(), "integrity mismatch") } -func TestOptimization(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("test content")) - if err != nil { - return - } - })) - defer ts.Close() - - t.Run("Valid_File_Not_Redownloaded", func(t *testing.T) { - tmpDir := test.TmpDir(t) - testUrl := ts.URL + "/valid_test.txt" - - lockPath := test.TmpFile(t, "") - lock, err := internal.NewLock(lockPath, true) - require.NoError(t, err) - - err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "valid_test.txt") - require.NoError(t, err) - - // Update the Download call to match the new signature - err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument - require.NoError(t, err) - }) - - t.Run("Invalid_File_Redownloaded", func(t *testing.T) { - tmpDir := test.TmpDir(t) - testUrl := ts.URL + "/invalid_test.txt" - - lockPath := test.TmpFile(t, "") - lock, err := internal.NewLock(lockPath, true) - require.NoError(t, err) - - err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "invalid_test.txt") - require.NoError(t, err) - - err = lock.Save() - require.NoError(t, err) - - invalidPath := filepath.Join(tmpDir, "invalid_test.txt") - err = os.WriteFile(invalidPath, []byte("corrupted"), 0644) - require.NoError(t, err) - - // Update the Download call to match the new signature - err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument - require.Error(t, err) - assert.Contains(t, err.Error(), "integrity mismatch") - }) -} - func TestRunDownloadTriesAllUrls(t *testing.T) { content := `abcdef` - contentIntegrity := getSha256Integrity(content) + contentIntegrity := test.GetSha256Integrity(content) port := test.TestHttpHandler(content, t) testfilepath := test.TmpFile(t, fmt.Sprintf(` [[Resource]] diff --git a/internal/resource.go b/internal/resource.go index b89b7bb..d40993e 100644 --- a/internal/resource.go +++ b/internal/resource.go @@ -88,12 +88,9 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e if err != nil { return err } + var downloadError error = nil for _, u := range l.Urls { - // Download file in the target directory so that the call to - // os.Rename is atomic. - log.Debug().Str("URL", u).Msg("Downloading") - localName := "" if l.Filename != "" { localName = l.Filename @@ -102,60 +99,59 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e } resPath := filepath.Join(dir, localName) - // Check existing file first - if _, err := os.Stat(resPath); err == nil { - // File exists, validate its integrity - if !ValidateLocalFile(resPath, l.Integrity) { - return fmt.Errorf("integrity mismatch for '%s'", resPath) + // Check if the destination file already exists and has the correct integrity. + _, err := os.Stat(resPath) + if err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("error checking destination file presence '%s': '%v'", resPath, err) + } + } else { + err = checkIntegrityFromFile(resPath, algo, l.Integrity, u) + if err != nil { + return fmt.Errorf("existing file at '%s' with incorrect integrity: '%v'", resPath, err) } - // Set file permissions if needed if mode != NoFileMode { - if err := os.Chmod(resPath, mode.Perm()); err != nil { + err = os.Chmod(resPath, mode.Perm()) + if err != nil { return err } } - ok = true - continue - } else if !os.IsNotExist(err) { - // Handle other potential errors from os.Stat - return fmt.Errorf("failed to stat file '%s': %v", resPath, err) + return nil } - // Download new file + // Download file in the target directory so that the call to + // os.Rename is atomic. lpath, err := GetUrlToDir(u, dir, ctx) if err != nil { - downloadError = fmt.Errorf("failed to download '%s': %v", u, err) - continue - } - - // Validate downloaded file - if err := checkIntegrityFromFile(lpath, algo, l.Integrity, u); err != nil { - os.Remove(lpath) downloadError = err continue } - - // Move to final location - if err := os.Rename(lpath, resPath); err != nil { - os.Remove(lpath) - downloadError = err - continue + err = checkIntegrityFromFile(lpath, algo, l.Integrity, u) + if err != nil { + return err + } + err = os.Rename(lpath, resPath) + if err != nil { + return err } - if mode != NoFileMode { - if err := os.Chmod(resPath, mode.Perm()); err != nil { + err = os.Chmod(resPath, mode.Perm()) + if err != nil { return err } } ok = true break } - if !ok { - return downloadError + if downloadError != nil { + return downloadError + } + return err } return nil } + func (l *Resource) Contains(url string) bool { for _, u := range l.Urls { if u == url { @@ -164,21 +160,3 @@ func (l *Resource) Contains(url string) bool { } return false } - -func ValidateLocalFile(filePath string, expectedIntegrity string) bool { - if _, err := os.Stat(filePath); os.IsNotExist(err) { - return false - } - - algo, err := getAlgoFromIntegrity(expectedIntegrity) - if err != nil { - return false - } - - fileIntegrity, err := getIntegrityFromFile(filePath, algo) - if err != nil { - return false - } - - return fileIntegrity == expectedIntegrity -} diff --git a/internal/resource_test.go b/internal/resource_test.go index 80a11e9..5e9b69e 100644 --- a/internal/resource_test.go +++ b/internal/resource_test.go @@ -4,8 +4,11 @@ package internal import ( + "context" "fmt" "net/http" + "os" + "path/filepath" "testing" "github.com/cisco-open/grabit/test" @@ -50,3 +53,34 @@ func TestNewResourceFromUrl(t *testing.T) { } } } + +func TestResourceDownloadWithValidFileAlreadyPresent(t *testing.T) { + content := `abcdef` + contentIntegrity := test.GetSha256Integrity(content) + port := 33 // unused because the file is already present. + testFileName := "test.html" + resource := Resource{Urls: []string{fmt.Sprintf("http://localhost:%d/%s", port, testFileName)}, Integrity: contentIntegrity, Tags: []string{}, Filename: ""} + outputDir := test.TmpDir(t) + err := os.WriteFile(filepath.Join(outputDir, testFileName), []byte(content), 0644) + assert.Nil(t, err) + err = resource.Download(outputDir, 0644, context.Background()) + assert.Nil(t, err) + for _, file := range []string{testFileName} { + test.AssertFileContains(t, fmt.Sprintf("%s/%s", outputDir, file), content) + } +} + +func TestResourceDownloadWithInValidFileAlreadyPresent(t *testing.T) { + content := `abcdef` + contentIntegrity := test.GetSha256Integrity(content) + port := 33 // unused because the file, although invalid, is already present. + testFileName := "test.html" + resource := Resource{Urls: []string{fmt.Sprintf("http://localhost:%d/%s", port, testFileName)}, Integrity: contentIntegrity, Tags: []string{}, Filename: ""} + outputDir := test.TmpDir(t) + err := os.WriteFile(filepath.Join(outputDir, testFileName), []byte("invalid"), 0644) + assert.Nil(t, err) + err = resource.Download(outputDir, 0644, context.Background()) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "integrity mismatch") + assert.Contains(t, err.Error(), "existing file") +} diff --git a/test/utils.go b/test/utils.go index 71f5b64..3653641 100644 --- a/test/utils.go +++ b/test/utils.go @@ -4,6 +4,9 @@ package test import ( + "crypto/sha256" + "encoding/base64" + "fmt" "log" "net" "net/http" @@ -14,6 +17,12 @@ import ( "github.com/stretchr/testify/assert" ) +func GetSha256Integrity(content string) string { + hasher := sha256.New() + hasher.Write([]byte(content)) + return fmt.Sprintf("sha256-%s", base64.StdEncoding.EncodeToString(hasher.Sum(nil))) +} + func TmpFile(t *testing.T, content string) string { f, err := os.CreateTemp(t.TempDir(), "test") if err != nil {