Skip to content

Commit

Permalink
Adjust code and test
Browse files Browse the repository at this point in the history
- Refactor resource.go code to limit code duplication
- Refactor test code to use the same format as the existing tests
- Move test code to resource_test.go (since it's testing resource.go)
- Improve test coverage

Drive-by: fix typo in test file name.
  • Loading branch information
rabadin committed Dec 11, 2024
1 parent 97f38ea commit 3f21120
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 120 deletions.
72 changes: 4 additions & 68 deletions cmd/dowload_test.go → cmd/download_test.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,16 @@
package cmd

import (
"crypto/sha256"
"encoding/base64"
"fmt"
"github.com/cisco-open/grabit/internal"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

"github.com/cisco-open/grabit/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func getSha256Integrity(content string) string {
hasher := sha256.New()
hasher.Write([]byte(content))
return fmt.Sprintf("sha256-%s", base64.StdEncoding.EncodeToString(hasher.Sum(nil)))
}

func TestRunDownload(t *testing.T) {
content := `abcdef`
contentIntegrity := getSha256Integrity(content)
contentIntegrity := test.GetSha256Integrity(content)
port := test.TestHttpHandler(content, t)
testfilepath := test.TmpFile(t, fmt.Sprintf(`
[[Resource]]
Expand All @@ -47,7 +33,7 @@ func TestRunDownload(t *testing.T) {

func TestRunDownloadWithTags(t *testing.T) {
content := `abcdef`
contentIntegrity := getSha256Integrity(content)
contentIntegrity := test.GetSha256Integrity(content)
port := test.TestHttpHandler(content, t)
testfilepath := test.TmpFile(t, fmt.Sprintf(`
[[Resource]]
Expand All @@ -72,7 +58,7 @@ func TestRunDownloadWithTags(t *testing.T) {

func TestRunDownloadWithoutTags(t *testing.T) {
content := `abcdef`
contentIntegrity := getSha256Integrity(content)
contentIntegrity := test.GetSha256Integrity(content)
port := test.TestHttpHandler(content, t)
testfilepath := test.TmpFile(t, fmt.Sprintf(`
[[Resource]]
Expand Down Expand Up @@ -130,59 +116,9 @@ func TestRunDownloadFailsIntegrityTest(t *testing.T) {
assert.Contains(t, err.Error(), "integrity mismatch")
}

func TestOptimization(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("test content"))
if err != nil {
return
}
}))
defer ts.Close()

t.Run("Valid_File_Not_Redownloaded", func(t *testing.T) {
tmpDir := test.TmpDir(t)
testUrl := ts.URL + "/valid_test.txt"

lockPath := test.TmpFile(t, "")
lock, err := internal.NewLock(lockPath, true)
require.NoError(t, err)

err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "valid_test.txt")
require.NoError(t, err)

// Update the Download call to match the new signature
err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument
require.NoError(t, err)
})

t.Run("Invalid_File_Redownloaded", func(t *testing.T) {
tmpDir := test.TmpDir(t)
testUrl := ts.URL + "/invalid_test.txt"

lockPath := test.TmpFile(t, "")
lock, err := internal.NewLock(lockPath, true)
require.NoError(t, err)

err = lock.AddResource([]string{testUrl}, internal.RecommendedAlgo, nil, "invalid_test.txt")
require.NoError(t, err)

err = lock.Save()
require.NoError(t, err)

invalidPath := filepath.Join(tmpDir, "invalid_test.txt")
err = os.WriteFile(invalidPath, []byte("corrupted"), 0644)
require.NoError(t, err)

// Update the Download call to match the new signature
err = lock.Download(tmpDir, nil, nil, "", false) // Added 'false' for the new boolean argument
require.Error(t, err)
assert.Contains(t, err.Error(), "integrity mismatch")
})
}

func TestRunDownloadTriesAllUrls(t *testing.T) {
content := `abcdef`
contentIntegrity := getSha256Integrity(content)
contentIntegrity := test.GetSha256Integrity(content)
port := test.TestHttpHandler(content, t)
testfilepath := test.TmpFile(t, fmt.Sprintf(`
[[Resource]]
Expand Down
82 changes: 30 additions & 52 deletions internal/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,9 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e
if err != nil {
return err
}

var downloadError error = nil
for _, u := range l.Urls {
// Download file in the target directory so that the call to
// os.Rename is atomic.
log.Debug().Str("URL", u).Msg("Downloading")

localName := ""
if l.Filename != "" {
localName = l.Filename
Expand All @@ -102,60 +99,59 @@ func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) e
}
resPath := filepath.Join(dir, localName)

// Check existing file first
if _, err := os.Stat(resPath); err == nil {
// File exists, validate its integrity
if !ValidateLocalFile(resPath, l.Integrity) {
return fmt.Errorf("integrity mismatch for '%s'", resPath)
// Check if the destination file already exists and has the correct integrity.
_, err := os.Stat(resPath)
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("error checking destination file presence '%s': '%v'", resPath, err)
}
} else {
err = checkIntegrityFromFile(resPath, algo, l.Integrity, u)
if err != nil {
return fmt.Errorf("existing file at '%s' with incorrect integrity: '%v'", resPath, err)
}
// Set file permissions if needed
if mode != NoFileMode {
if err := os.Chmod(resPath, mode.Perm()); err != nil {
err = os.Chmod(resPath, mode.Perm())
if err != nil {
return err
}
}
ok = true
continue
} else if !os.IsNotExist(err) {
// Handle other potential errors from os.Stat
return fmt.Errorf("failed to stat file '%s': %v", resPath, err)
return nil
}

// Download new file
// Download file in the target directory so that the call to
// os.Rename is atomic.
lpath, err := GetUrlToDir(u, dir, ctx)
if err != nil {
downloadError = fmt.Errorf("failed to download '%s': %v", u, err)
continue
}

// Validate downloaded file
if err := checkIntegrityFromFile(lpath, algo, l.Integrity, u); err != nil {
os.Remove(lpath)
downloadError = err
continue
}

// Move to final location
if err := os.Rename(lpath, resPath); err != nil {
os.Remove(lpath)
downloadError = err
continue
err = checkIntegrityFromFile(lpath, algo, l.Integrity, u)
if err != nil {
return err
}
err = os.Rename(lpath, resPath)
if err != nil {
return err
}

if mode != NoFileMode {
if err := os.Chmod(resPath, mode.Perm()); err != nil {
err = os.Chmod(resPath, mode.Perm())
if err != nil {
return err
}
}
ok = true
break
}

if !ok {
return downloadError
if downloadError != nil {
return downloadError
}
return err
}
return nil
}

func (l *Resource) Contains(url string) bool {
for _, u := range l.Urls {
if u == url {
Expand All @@ -164,21 +160,3 @@ func (l *Resource) Contains(url string) bool {
}
return false
}

func ValidateLocalFile(filePath string, expectedIntegrity string) bool {
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return false
}

algo, err := getAlgoFromIntegrity(expectedIntegrity)
if err != nil {
return false
}

fileIntegrity, err := getIntegrityFromFile(filePath, algo)
if err != nil {
return false
}

return fileIntegrity == expectedIntegrity
}
34 changes: 34 additions & 0 deletions internal/resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
package internal

import (
"context"
"fmt"
"net/http"
"os"
"path/filepath"
"testing"

"github.com/cisco-open/grabit/test"
Expand Down Expand Up @@ -50,3 +53,34 @@ func TestNewResourceFromUrl(t *testing.T) {
}
}
}

func TestResourceDownloadWithValidFileAlreadyPresent(t *testing.T) {
content := `abcdef`
contentIntegrity := test.GetSha256Integrity(content)
port := 33 // unused because the file is already present.
testFileName := "test.html"
resource := Resource{Urls: []string{fmt.Sprintf("http://localhost:%d/%s", port, testFileName)}, Integrity: contentIntegrity, Tags: []string{}, Filename: ""}
outputDir := test.TmpDir(t)
err := os.WriteFile(filepath.Join(outputDir, testFileName), []byte(content), 0644)
assert.Nil(t, err)
err = resource.Download(outputDir, 0644, context.Background())
assert.Nil(t, err)
for _, file := range []string{testFileName} {
test.AssertFileContains(t, fmt.Sprintf("%s/%s", outputDir, file), content)
}
}

func TestResourceDownloadWithInValidFileAlreadyPresent(t *testing.T) {
content := `abcdef`
contentIntegrity := test.GetSha256Integrity(content)
port := 33 // unused because the file, although invalid, is already present.
testFileName := "test.html"
resource := Resource{Urls: []string{fmt.Sprintf("http://localhost:%d/%s", port, testFileName)}, Integrity: contentIntegrity, Tags: []string{}, Filename: ""}
outputDir := test.TmpDir(t)
err := os.WriteFile(filepath.Join(outputDir, testFileName), []byte("invalid"), 0644)
assert.Nil(t, err)
err = resource.Download(outputDir, 0644, context.Background())
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "integrity mismatch")
assert.Contains(t, err.Error(), "existing file")
}
9 changes: 9 additions & 0 deletions test/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package test

import (
"crypto/sha256"
"encoding/base64"
"fmt"
"log"
"net"
"net/http"
Expand All @@ -14,6 +17,12 @@ import (
"github.com/stretchr/testify/assert"
)

func GetSha256Integrity(content string) string {
hasher := sha256.New()
hasher.Write([]byte(content))
return fmt.Sprintf("sha256-%s", base64.StdEncoding.EncodeToString(hasher.Sum(nil)))
}

func TmpFile(t *testing.T, content string) string {
f, err := os.CreateTemp(t.TempDir(), "test")
if err != nil {
Expand Down

0 comments on commit 3f21120

Please sign in to comment.