diff --git a/cmd/dowload_test.go b/cmd/dowload_test.go index 0e8e4cd..6540ff1 100644 --- a/cmd/dowload_test.go +++ b/cmd/dowload_test.go @@ -4,10 +4,16 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "github.com/cisco-open/grabit/internal" + "net/http" + "net/http/httptest" + "os" + "path/filepath" "testing" "github.com/cisco-open/grabit/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func getSha256Integrity(content string) string { @@ -124,6 +130,56 @@ func TestRunDownloadFailsIntegrityTest(t *testing.T) { assert.Contains(t, err.Error(), "integrity mismatch") } +func TestOptimization(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("test content")) + if err != nil { + return + } + })) + defer ts.Close() + + t.Run("Valid_File_Not_Redownloaded", func(t *testing.T) { + tmpDir := test.TmpDir(t) + testUrl := ts.URL + "/valid_test.txt" + + lockPath := test.TmpFile(t, "") + lock, err := internal.NewLock(lockPath, true) + require.NoError(t, err) + + err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "valid_test.txt") + require.NoError(t, err) + + // Update the Download call to match the new signature + err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument + require.NoError(t, err) + }) + + t.Run("Invalid_File_Redownloaded", func(t *testing.T) { + tmpDir := test.TmpDir(t) + testUrl := ts.URL + "/invalid_test.txt" + + lockPath := test.TmpFile(t, "") + lock, err := internal.NewLock(lockPath, true) + require.NoError(t, err) + + err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "invalid_test.txt") + require.NoError(t, err) + + err = lock.Save() + require.NoError(t, err) + + invalidPath := filepath.Join(tmpDir, "invalid_test.txt") + err = os.WriteFile(invalidPath, []byte("corrupted"), 0644) + require.NoError(t, err) + + // Update the Download call to match the new signature + err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument + require.Error(t, err) + assert.Contains(t, err.Error(), "integrity mismatch") + }) +} + func TestRunDownloadTriesAllUrls(t *testing.T) { content := `abcdef` contentIntegrity := getSha256Integrity(content) diff --git a/internal/resource.go b/internal/resource.go index a377e87..b89b7bb 100644 --- a/internal/resource.go +++ b/internal/resource.go @@ -92,15 +92,7 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e for _, u := range l.Urls { // Download file in the target directory so that the call to // os.Rename is atomic. - lpath, err := GetUrlToDir(u, dir, ctx) - if err != nil { - downloadError = err - continue - } - err = checkIntegrityFromFile(lpath, algo, l.Integrity, u) - if err != nil { - return err - } + log.Debug().Str("URL", u).Msg("Downloading") localName := "" if l.Filename != "" { @@ -109,27 +101,61 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e localName = path.Base(u) } resPath := filepath.Join(dir, localName) - err = os.Rename(lpath, resPath) + + // Check existing file first + if _, err := os.Stat(resPath); err == nil { + // File exists, validate its integrity + if !ValidateLocalFile(resPath, l.Integrity) { + return fmt.Errorf("integrity mismatch for '%s'", resPath) + } + // Set file permissions if needed + if mode != NoFileMode { + if err := os.Chmod(resPath, mode.Perm()); err != nil { + return err + } + } + ok = true + continue + } else if !os.IsNotExist(err) { + // Handle other potential errors from os.Stat + return fmt.Errorf("failed to stat file '%s': %v", resPath, err) + } + + // Download new file + lpath, err := GetUrlToDir(u, dir, ctx) if err != nil { - return err + downloadError = fmt.Errorf("failed to download '%s': %v", u, err) + continue + } + + // Validate downloaded file + if err := checkIntegrityFromFile(lpath, algo, l.Integrity, u); err != nil { + os.Remove(lpath) + downloadError = err + continue } + + // Move to final location + if err := os.Rename(lpath, resPath); err != nil { + os.Remove(lpath) + downloadError = err + continue + } + if mode != NoFileMode { - err = os.Chmod(resPath, mode.Perm()) - if err != nil { + if err := os.Chmod(resPath, mode.Perm()); err != nil { return err } } ok = true + break } + if !ok { - if downloadError != nil { - return downloadError - } - return err + return downloadError } return nil } - func (l *Resource) Contains(url string) bool { for _, u := range l.Urls { if u == url { @@ -138,3 +164,21 @@ func (l *Resource) Contains(url string) bool { } return false } + +func ValidateLocalFile(filePath string, expectedIntegrity string) bool { + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return false + } + + algo, err := getAlgoFromIntegrity(expectedIntegrity) + if err != nil { + return false + } + + fileIntegrity, err := getIntegrityFromFile(filePath, algo) + if err != nil { + return false + } + + return fileIntegrity == expectedIntegrity +}