diff --git a/cmd/software-update/main.go b/cmd/software-update/main.go index 4eb1aa7..ba3bd14 100644 --- a/cmd/software-update/main.go +++ b/cmd/software-update/main.go @@ -34,6 +34,11 @@ func main() { loggerOut := logger.SetupLogger(logConfig) defer loggerOut.Close() + if err := suConfig.Validate(); err != nil { + logger.Errorf("failed to validate script-based software updatable configuration: %v\n", err) + os.Exit(1) + } + // Create new Script-Based software updatable edgeCtr, err := feature.InitScriptBasedSU(suConfig) if err != nil { diff --git a/internal/duration.go b/internal/duration.go new file mode 100644 index 0000000..8950d7a --- /dev/null +++ b/internal/duration.go @@ -0,0 +1,55 @@ +// Copyright (c) 2022 Contributors to the Eclipse Foundation +// +// See the NOTICE file(s) distributed with this work for additional +// information regarding copyright ownership. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0 +// +// SPDX-License-Identifier: EPL-2.0 + +package feature + +import ( + "encoding/json" + "errors" + "time" +) + +//durationTime is custom type of type time.Duration in order to add json unmarshal support +type durationTime time.Duration + +//UnmarshalJSON unmarshal durationTime type +func (d *durationTime) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + + case string: + duration, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = durationTime(duration) + default: + return errors.New("invalid duration") + } + return nil +} + +//Set durationTime from string, used for flag set +func (d *durationTime) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + err = errors.New("parse error") + } + *d = durationTime(v) + return err +} + +func (d durationTime) String() string { + return time.Duration(d).String() +} diff --git a/internal/feature.go b/internal/feature.go index 3cf6a77..7a4cb0e 100644 --- a/internal/feature.go +++ b/internal/feature.go @@ -12,6 +12,7 @@ package feature import ( + "fmt" "sync" "time" @@ -39,28 +40,32 @@ type operationFunc func() bool // ScriptBasedSoftwareUpdatableConfig provides the Script-Based SoftwareUpdatable configuration. type ScriptBasedSoftwareUpdatableConfig struct { - Broker string - Username string - Password string - StorageLocation string - FeatureID string - ModuleType string - ArtifactType string - ServerCert string - InstallCommand command + Broker string + Username string + Password string + StorageLocation string + FeatureID string + ModuleType string + ArtifactType string + ServerCert string + DownloadRetryCount int + DownloadRetryInterval durationTime + InstallCommand command } // ScriptBasedSoftwareUpdatable is the Script-Based SoftwareUpdatable actual implementation. type ScriptBasedSoftwareUpdatable struct { - lock sync.Mutex - queue chan operationFunc - store *storage.Storage - su *hawkbit.SoftwareUpdatable - dittoClient *ditto.Client - mqttClient MQTT.Client - artifactType string - serverCert string - installCommand *command + lock sync.Mutex + queue chan operationFunc + store *storage.Storage + su *hawkbit.SoftwareUpdatable + dittoClient *ditto.Client + mqttClient MQTT.Client + artifactType string + serverCert string + downloadRetryCount int + downloadRetryInterval time.Duration + installCommand *command } // InitScriptBasedSU creates a new Script-Based SoftwareUpdatable instance, listening for edge configuration. @@ -78,7 +83,12 @@ func InitScriptBasedSU(scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig) (*Ed store: localStorage, // Build install script command installCommand: &scriptSUPConfig.InstallCommand, - serverCert: scriptSUPConfig.ServerCert, + // Server download certificate + serverCert: scriptSUPConfig.ServerCert, + // Number of download reattempts + downloadRetryCount: scriptSUPConfig.DownloadRetryCount, + // Interval between download reattempts + downloadRetryInterval: time.Duration(scriptSUPConfig.DownloadRetryInterval), // Define the module artifact(s) type: archive or plane artifactType: scriptSUPConfig.ArtifactType, // Create queue with size 10 @@ -133,3 +143,11 @@ func (f *ScriptBasedSoftwareUpdatable) Disconnect(closeStorage bool) { logger.Info("ditto client disconnected") f.dittoClient.Disconnect() } + +// Validate the software updatable configuration +func (scriptSUPConfig *ScriptBasedSoftwareUpdatableConfig) Validate() error { + if scriptSUPConfig.DownloadRetryCount < 0 { + return fmt.Errorf("negative download retry count value - %d", scriptSUPConfig.DownloadRetryCount) + } + return nil +} diff --git a/internal/feature_download.go b/internal/feature_download.go index 03cc80a..c1958a7 100644 --- a/internal/feature_download.go +++ b/internal/feature_download.go @@ -123,7 +123,7 @@ Started: Downloading: if opError = f.store.DownloadModule(toDir, module, func(percent int) { setLastOS(su, newOS(cid, module, hawkbit.StatusDownloading).WithProgress(percent)) - }, f.serverCert); opError != nil { + }, f.serverCert, f.downloadRetryCount, f.downloadRetryInterval); opError != nil { opErrorMsg = errDownload return opError == storage.ErrCancel } diff --git a/internal/feature_install.go b/internal/feature_install.go index 3e3f14a..8b12d44 100644 --- a/internal/feature_install.go +++ b/internal/feature_install.go @@ -130,7 +130,7 @@ Started: Downloading: if opError = f.store.DownloadModule(dir, module, func(progress int) { setLastOS(su, newOS(cid, module, hawkbit.StatusDownloading).WithProgress(progress)) - }, f.serverCert); opError != nil { + }, f.serverCert, f.downloadRetryCount, f.downloadRetryInterval); opError != nil { opErrorMsg = errDownload return opError == storage.ErrCancel } diff --git a/internal/flags.go b/internal/flags.go index fbeae10..b29f173 100644 --- a/internal/flags.go +++ b/internal/flags.go @@ -43,6 +43,8 @@ const ( flagLogFileSize = "logFileSize" flagLogFileCount = "logFileCount" flagLogFileMaxAge = "logFileMaxAge" + flagRetryCount = "downloadRetryCount" + flagRetryInterval = "downloadRetryInterval" ) var ( @@ -51,20 +53,22 @@ var ( ) type cfg struct { - Broker string `json:"broker" def:"tcp://localhost:1883" descr:"Local MQTT broker address"` - Username string `json:"username" descr:"Username for authorized local client"` - Password string `json:"password" descr:"Password for authorized local client"` - StorageLocation string `json:"storageLocation" def:"." descr:"Location of the storage"` - FeatureID string `json:"featureId" def:"SoftwareUpdatable" descr:"Feature identifier of SoftwareUpdatable"` - ModuleType string `json:"moduleType" def:"software" descr:"Module type of SoftwareUpdatable"` - ArtifactType string `json:"artifactType" def:"archive" descr:"Defines the module artifact type: archive or plane"` - Install []string `json:"install" descr:"Defines the absolute path to install script"` - ServerCert string `json:"serverCert" descr:"A PEM encoded certificate \"file\" for secure artifact download"` - LogFile string `json:"logFile" def:"log/software-update.log" descr:"Log file location in storage directory"` - LogLevel string `json:"logLevel" def:"INFO" descr:"Log levels are ERROR, WARN, INFO, DEBUG, TRACE"` - LogFileSize int `json:"logFileSize" def:"2" descr:"Log file size in MB before it gets rotated"` - LogFileCount int `json:"logFileCount" def:"5" descr:"Log file max rotations count"` - LogFileMaxAge int `json:"logFileMaxAge" def:"28" descr:"Log file rotations max age in days"` + Broker string `json:"broker" def:"tcp://localhost:1883" descr:"Local MQTT broker address"` + Username string `json:"username" descr:"Username for authorized local client"` + Password string `json:"password" descr:"Password for authorized local client"` + StorageLocation string `json:"storageLocation" def:"." descr:"Location of the storage"` + FeatureID string `json:"featureId" def:"SoftwareUpdatable" descr:"Feature identifier of SoftwareUpdatable"` + ModuleType string `json:"moduleType" def:"software" descr:"Module type of SoftwareUpdatable"` + ArtifactType string `json:"artifactType" def:"archive" descr:"Defines the module artifact type: archive or plane"` + Install []string `json:"install" descr:"Defines the absolute path to install script"` + ServerCert string `json:"serverCert" descr:"A PEM encoded certificate \"file\" for secure artifact download"` + DownloadRetryCount int `json:"downloadRetryCount" def:"0" descr:"Number of retries, in case of a failed download.\n By default no retries are supported."` + DownloadRetryInterval durationTime `json:"downloadRetryInterval" def:"5s" descr:"Interval between retries, in case of a failed download.\n Should be a sequence of decimal numbers, each with optional fraction and a unit suffix, such as '300ms', '1.5h', '10m30s', etc. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'."` + LogFile string `json:"logFile" def:"log/software-update.log" descr:"Log file location in storage directory"` + LogLevel string `json:"logLevel" def:"INFO" descr:"Log levels are ERROR, WARN, INFO, DEBUG, TRACE"` + LogFileSize int `json:"logFileSize" def:"2" descr:"Log file size in MB before it gets rotated"` + LogFileCount int `json:"logFileCount" def:"5" descr:"Log file max rotations count"` + LogFileMaxAge int `json:"logFileMaxAge" def:"28" descr:"Log file rotations max age in days"` } // InitFlags tries to initialize Script-Based SoftwareUpdatable and Log configurations. @@ -80,7 +84,6 @@ func InitFlags(version string) (*ScriptBasedSoftwareUpdatableConfig, *logger.Log initFlagsWithDefaultValues(flgConfig) flag.Parse() - if *printVersion { fmt.Println(version) os.Exit(0) @@ -112,6 +115,13 @@ func initFlagsWithDefaultValues(config interface{}) { log.Printf("error parsing integer argument %v with value %v", fieldType.Name, defaultValue) } flag.IntVar(pointer.(*int), flagName, value, description) + case durationTime: + v, ok := pointer.(flag.Value) + if ok { + flag.Var(v, flagName, description) + } else { + log.Println("custom type Duration must implement reflect.Value interface") + } } } } @@ -123,6 +133,8 @@ func loadDefaultValues() *cfg { for i := 0; i < typeOf.NumField(); i++ { fieldType := typeOf.Field(i) defaultValue := fieldType.Tag.Get("def") + fieldValue := valueOf.FieldByName(fieldType.Name) + pointer := fieldValue.Addr().Interface() if len(defaultValue) > 0 { fieldValue := valueOf.FieldByName(fieldType.Name) switch fieldValue.Interface().(type) { @@ -134,6 +146,17 @@ func loadDefaultValues() *cfg { log.Printf("error parsing integer argument %v with value %v", fieldType.Name, defaultValue) } fieldValue.Set(reflect.ValueOf(value)) + case durationTime: + v, ok := pointer.(flag.Value) + if ok { + if err := v.Set(defaultValue); err == nil { + fieldValue.Set(reflect.ValueOf(v).Elem()) + } else { + log.Printf("error parsing argument %v with value %v - %v", fieldType.Name, defaultValue, err) + } + } else { + log.Println("custom type Duration must implement reflect.Value interface") + } } } @@ -163,7 +186,7 @@ func applyFlags(flagsConfig interface{}) { func applyConfigurationFile(configFile string) error { def := loadDefaultValues() - // Load configuration file (if posible) + // Load configuration file (if possible) if len(configFile) > 0 { if jf, err := os.Open(configFile); err == nil { defer jf.Close() diff --git a/internal/flags_test.go b/internal/flags_test.go index 2e32fcc..91a3d6d 100644 --- a/internal/flags_test.go +++ b/internal/flags_test.go @@ -112,6 +112,8 @@ func TestFlagsHasHigherPriority(t *testing.T) { expectedFeatureID := "TestFeature" expectedInstall := "TestInstall" expectedServerCert := "TestCert" + expectedDownloadRetryCount := 3 + expectedDownloadRetryInterval := "5s" expectedLogFile := "" expectedLogFileCount := 4 expectedLogFileMaxAge := 13 @@ -129,6 +131,8 @@ func TestFlagsHasHigherPriority(t *testing.T) { c(flagFeatureID, expectedFeatureID), c(flagInstall, expectedInstall), c(flagCert, expectedServerCert), + c(flagRetryCount, strconv.Itoa(expectedDownloadRetryCount)), + c(flagRetryInterval, expectedDownloadRetryInterval), c(flagLogFile, expectedLogFile), c(flagLogFileCount, strconv.Itoa(expectedLogFileCount)), c(flagLogFileMaxAge, strconv.Itoa(expectedLogFileMaxAge)), diff --git a/internal/storage/download.go b/internal/storage/download.go index 97f33ba..9096173 100644 --- a/internal/storage/download.go +++ b/internal/storage/download.go @@ -28,6 +28,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/eclipse-kanto/software-update/internal/logger" ) @@ -37,20 +38,21 @@ const prefix = "_temporary-" var secureCiphers = supportedCipherSuites() // downloadArtifact tries to resume previous download operation or perform a new download. -func downloadArtifact(to string, artifact *Artifact, progress progressBytes, serverCert string, done chan struct{}) error { - logger.Infof("Download [%s] to file [%s]", artifact.Link, to) +func downloadArtifact(to string, artifact *Artifact, progress progressBytes, serverCert string, retryCount int, retryInterval time.Duration, + done chan struct{}) error { + logger.Infof("download [%s] to file [%s]", artifact.Link, to) // Check for available file. if _, err := os.Stat(to); !os.IsNotExist(err) { - logger.Debugf("File exists, check its checksum: %s", to) + logger.Debugf("file exists, check its checksum: %s", to) if err = validate(to, artifact.HashType, artifact.HashValue); err == nil { - logger.Debugf("File already available: %s", to) + logger.Debugf("file already available: %s", to) if progress != nil { progress(int64(artifact.Size)) } return nil } - logger.Debugf("Available file with wrong checksum, remove it: %s", to) + logger.Debugf("available file with wrong checksum, remove it: %s", to) if err := os.Remove(to); err != nil { return err } @@ -69,30 +71,25 @@ func downloadArtifact(to string, artifact *Artifact, progress progressBytes, ser // Try to remove failed download file. if _, err := os.Stat(tmp); !os.IsNotExist(err) { if err = os.Remove(tmp); err != nil { - logger.Debugf("Failed to remove failed download file: %v", err) + logger.Debugf("failed to remove failed download file: %v", err) } } }() if stat, err := os.Stat(tmp); !os.IsNotExist(err) { // Try to resume previous download. - if dError = resume(tmp, stat.Size(), artifact, progress, serverCert, done); dError != nil { + if _, dError = resume(tmp, stat.Size(), artifact, progress, serverCert, retryCount, retryInterval, done); dError != nil { return dError } } else { // No available previous download, perform a full download. - response, err := requestDownload(artifact.Link, 0, serverCert) + response, remainingRetries, err := robustDownload(artifact.Link, 0, serverCert, retryCount, retryInterval) if err != nil { return err } defer response.Body.Close() - // HTTP Status code is NOT in the 2xx range - if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices { - return fmt.Errorf("http status code is not in the 2xx range: %v", response.StatusCode) - } - - if dError = download(tmp, response.Body, artifact, progress, done); dError != nil { + if _, dError = download(tmp, response.Body, artifact, progress, serverCert, remainingRetries, retryInterval, done); dError != nil { return dError } } @@ -101,51 +98,110 @@ func downloadArtifact(to string, artifact *Artifact, progress progressBytes, ser return os.Rename(tmp, to) } -func resume(to string, offset int64, artifact *Artifact, progress progressBytes, serverCert string, done chan struct{}) error { +func resume(to string, offset int64, artifact *Artifact, progress progressBytes, serverCert string, retryCount int, + retryInterval time.Duration, done chan struct{}) (int64, error) { // Send the HTTP request and get its response. - response, err := requestDownload(artifact.Link, offset, serverCert) + response, remainingRetries, err := robustDownload(artifact.Link, offset, serverCert, retryCount, retryInterval) if err != nil { - return err + return 0, err } defer response.Body.Close() - // HTTP Status code is NOT in the 2xx range - if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices { - return fmt.Errorf("http status code is not in the 2xx range: %v", response.StatusCode) - } - // Check if HTTP server support Range header. If not, delete existing file and perform regular download if response.Header.Get("Accept-Ranges") != "bytes" || response.Header.Get("Content-Range") == "" { - logger.Infof("Resume is not supported, remove previous file: %s", to) + logger.Infof("resume is not supported, remove previous file: %s", to) if err := os.Remove(to); err != nil { - return err + logger.Errorf("error removing partially downloaded file %s", to) + return 0, err } - return download(to, response.Body, artifact, progress, done) + return download(to, response.Body, artifact, progress, serverCert, remainingRetries, retryInterval, done) } // Download the rest of the file. - logger.Debugf("Resume previous download of file: %s", to) + logger.Debugf("resume previous download of file: %s", to) file, err := os.OpenFile(to, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0755) if err != nil { - return err + return 0, err } defer file.Close() if progress != nil { progress(offset) } - if w, err := copy(file, response.Body, int64(artifact.Size)-offset, progress, done); err != nil { - logger.Debugf("Written bytes: %v", w) - return err + return downloadFile(file, response.Body, to, offset, artifact, progress, serverCert, remainingRetries, retryInterval, done) +} + +func downloadFile(file *os.File, input io.ReadCloser, to string, offset int64, artifact *Artifact, + progress progressBytes, serverCert string, retryCount int, retryInterval time.Duration, done chan struct{}) (int64, error) { + w, err := copy(file, input, int64(artifact.Size)-offset, progress, done) + if err == nil { + err = validate(to, artifact.HashType, artifact.HashValue) + offset = 0 // in case of error, re-download the file + w = 0 + } else { + logger.Debugf("written bytes: %v", w) + offset += w + } + if err == nil { + return w, nil } - return validate(to, artifact.HashType, artifact.HashValue) + retryCount-- + for retryCount >= 0 { + var deltaBytes int64 + logger.Errorf("error copying artifact %s, remaining attempts - %d, cause: %v", file.Name(), retryCount, err) + logger.Infof("%v timeout until next attempt", retryInterval) + file.Close() + time.Sleep(time.Duration(retryInterval)) + logger.Infof("retrying to download artifact %s, current bytes written - %d", file.Name(), offset) + deltaBytes, err = resume(to, offset, artifact, progress, serverCert, 0, 0, done) + if err == nil { + break + } + offset += deltaBytes + retryCount-- + } + return w, err +} + +func robustDownload(link string, offset int64, serverCert string, retryCount int, retryInterval time.Duration) (*http.Response, int, error) { + var err error + var resp *http.Response + for retryCount >= 0 { + resp, err = attemptDownload(link, offset, serverCert) + if err == nil { + logger.Debugf("download response for artifact %s - %v", link, resp) + return resp, retryCount, nil + } + retryCount-- + if retryCount > 0 { + logger.Errorf("error downloading artifact %s, remaining attempts - %d, cause: %v", link, retryCount, err) + logger.Infof("%v timeout until next attempt", retryInterval) + if retryInterval > 0 { + time.Sleep(retryInterval) + } + } + } + return nil, 0, err +} + +func attemptDownload(link string, offset int64, serverCert string) (*http.Response, error) { + response, err := requestDownload(link, offset, serverCert) + if err != nil { + return nil, err + } + + // HTTP Status code is NOT in the 2xx range + if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("http status code is not in the 2xx range: %v", response.StatusCode) + } + return response, nil } func requestDownload(link string, offset int64, serverCert string) (*http.Response, error) { // Create new HTTP request with Range header. request, err := http.NewRequest(http.MethodGet, link, nil) if err != nil { - logger.Errorf("Error doing http(s) request to %s", link) + logger.Errorf("error doing http(s) request to %s", link) return nil, err } if offset > 0 { @@ -157,7 +213,7 @@ func requestDownload(link string, offset int64, serverCert string) (*http.Respon if len(serverCert) > 0 { caCert, err := ioutil.ReadFile(serverCert) if err != nil { - logger.Errorf("Error reading CA certificate file - \"%s\"", serverCert) + logger.Errorf("error reading CA certificate file - \"%s\"", serverCert) return nil, err } caCertPool = x509.NewCertPool() @@ -183,18 +239,15 @@ func requestDownload(link string, offset int64, serverCert string) (*http.Respon return client.Do(request) } -func download(to string, in io.ReadCloser, artifact *Artifact, progress progressBytes, done chan struct{}) error { +func download(to string, in io.ReadCloser, artifact *Artifact, progress progressBytes, + serverCert string, retryCount int, retryInterval time.Duration, done chan struct{}) (int64, error) { file, err := os.OpenFile(to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) if err != nil { - return err + return 0, err } defer file.Close() - if _, err = copy(file, in, int64(artifact.Size), progress, done); err != nil { - return err - } - - return validate(to, artifact.HashType, artifact.HashValue) + return downloadFile(file, in, to, 0, artifact, progress, serverCert, retryCount, retryInterval, done) } func copy(dst io.Writer, src io.Reader, size int64, progress progressBytes, done chan struct{}) (w int64, err error) { diff --git a/internal/storage/download_test.go b/internal/storage/download_test.go index 32f3d93..9ba4fad 100644 --- a/internal/storage/download_test.go +++ b/internal/storage/download_test.go @@ -166,7 +166,7 @@ func testDownloadToFile(arts []*Artifact, certFile, certKey string, t *testing.T // 1. Resume downlaod of corrupted temporary file. WriteLn(filepath.Join(dir, prefix+art.FileName), "wrong start") - if err := downloadArtifact(name, art, nil, certFile, make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, certFile, 0, 0, make(chan struct{})); err == nil { t.Fatal("downlaod of corrupted temporary file must fail") } @@ -175,7 +175,7 @@ func testDownloadToFile(arts []*Artifact, certFile, certKey string, t *testing.T callback := func(bytes int64) { close(done) } - if err := downloadArtifact(name, art, callback, certFile, done); err != ErrCancel { + if err := downloadArtifact(name, art, callback, certFile, 0, 0, done); err != ErrCancel { t.Fatalf("failed to cancel download operation: %v", err) } if _, err := os.Stat(filepath.Join(dir, prefix+art.FileName)); os.IsNotExist(err) { @@ -184,13 +184,13 @@ func testDownloadToFile(arts []*Artifact, certFile, certKey string, t *testing.T // 3. Resume previous download operation. callback = func(bytes int64) { /* Do nothing. */ } - if err := downloadArtifact(name, art, callback, certFile, make(chan struct{})); err != nil { + if err := downloadArtifact(name, art, callback, certFile, 0, 0, make(chan struct{})); err != nil { t.Fatalf("failed to download artifact: %v", err) } check(name, art.Size, t) // 4. Download available file. - if err := downloadArtifact(name, art, callback, certFile, make(chan struct{})); err != nil { + if err := downloadArtifact(name, art, callback, certFile, 0, 0, make(chan struct{})); err != nil { t.Fatalf("failed to download artifact: %v", err) } check(name, art.Size, t) @@ -203,14 +203,14 @@ func testDownloadToFile(arts []*Artifact, certFile, certKey string, t *testing.T // 5. Try to resume with file bigger than expected. WriteLn(filepath.Join(dir, prefix+art.FileName), "1111111111111") art.Size -= 10 - if err := downloadArtifact(name, art, nil, certFile, make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, certFile, 0, 0, make(chan struct{})); err == nil { t.Fatal("validate resume with file bigger than expected") } // 6. Try to resume from missing link. WriteLn(filepath.Join(dir, prefix+art.FileName), "1111111111111") art.Link = "http://localhost:43234/test-missing.txt" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("failed to validate with missing link") } @@ -242,33 +242,33 @@ func TestDownloadToFileError(t *testing.T) { // 1. Resume is not supported. WriteLn(filepath.Join(dir, prefix+art.FileName), "1111") - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err != nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err != nil { t.Fatalf("failed to download file artifact: %v", err) } check(name, art.Size, t) // 2. Try with missing checksum. art.HashValue = "" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("validated with missing checksum") } // 3. Try with missing link. art.Link = "http://localhost:43234/test-missing.txt" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("failed to validate with missing link") } // 4. Try with wrong checksum type. art.Link = "http://localhost:43234/test-simple.txt" art.HashType = "" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("validate with wrong checksum type") } // 5. Try with wrong checksum format. art.HashValue = ";;" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("validate with wrong checksum format") } @@ -276,7 +276,7 @@ func TestDownloadToFileError(t *testing.T) { art.HashType = "MD5" art.HashValue = "ab2ce340d36bbaafe17965a3a2c6ed5b" art.Size -= 10 - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatal("validate with file bigger than expected") } @@ -310,22 +310,22 @@ func TestDownloadToFileSecureError(t *testing.T) { // 1. Server uses expired certificate art.Link = "https://localhost:43234/test.txt" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatalf("download must fail(client uses no certificate, server uses expired): %v", err) } - if err := downloadArtifact(name, art, nil, expiredCert, make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, expiredCert, 0, 0, make(chan struct{})); err == nil { t.Fatalf("download must fail(client and server use expired certificate): %v", err) } // 2. Server uses untrusted certificate art.Link = "https://localhost:43235/test.txt" - if err := downloadArtifact(name, art, nil, "", make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, "", 0, 0, make(chan struct{})); err == nil { t.Fatalf("download must fail(client uses no certificate, server uses untrusted): %v", err) } // 3. Server uses valid certificate art.Link = "https://localhost:43236/test.txt" - if err := downloadArtifact(name, art, nil, untrustedCert, make(chan struct{})); err == nil { + if err := downloadArtifact(name, art, nil, untrustedCert, 0, 0, make(chan struct{})); err == nil { t.Fatalf("download must fail(client uses untrusted certificate, server uses valid): %v", err) } } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 3ac40ac..fa1bcf2 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -20,6 +20,7 @@ import ( "path/filepath" "sort" "strconv" + "time" "github.com/eclipse-kanto/software-update/hawkbit" "github.com/eclipse-kanto/software-update/internal/logger" @@ -211,7 +212,8 @@ func (st *Storage) ArchiveModule(dir string) error { } // DownloadModule artifacts to local storage. -func (st *Storage) DownloadModule(toDir string, module *Module, progress Progress, serverCert string) (err error) { +func (st *Storage) DownloadModule(toDir string, module *Module, progress Progress, serverCert string, + retryCount int, retryInterval time.Duration) (err error) { logger.Debugf("Download module to directory: [%s]", toDir) logger.Tracef("Module: %v", module) if err = os.MkdirAll(toDir, 0755); err != nil { @@ -240,7 +242,8 @@ func (st *Storage) DownloadModule(toDir string, module *Module, progress Progres } for _, sa := range module.Artifacts { - if err = downloadArtifact(filepath.Join(toDir, sa.FileName), sa, callback, serverCert, st.done); err != nil { + if err = downloadArtifact(filepath.Join(toDir, sa.FileName), sa, callback, serverCert, retryCount, retryInterval, + st.done); err != nil { return err } } diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go index 298021a..15e22c0 100644 --- a/internal/storage/storage_test.go +++ b/internal/storage/storage_test.go @@ -234,7 +234,7 @@ func TestDownloadArchiveModule(t *testing.T) { // 1. Download module without progress. path := filepath.Join(store.DownloadPath, "0", "0") m := &Module{Name: "name1", Version: "1", Artifacts: []*Artifact{art}} - if err := store.DownloadModule(path, m, nil, ""); err != nil { + if err := store.DownloadModule(path, m, nil, "", 0, 0); err != nil { t.Fatalf("fail to download module [Hash: %s, File: %s]: %v", art.HashValue, hex.EncodeToString(srv.data), err) } existence(filepath.Join(path, art.FileName), true, "[initial download]", t) @@ -251,7 +251,7 @@ func TestDownloadArchiveModule(t *testing.T) { // 3. Download previous module with progress. path = filepath.Join(store.DownloadPath, "0", "1") progress := func(percent int) { /* Do nothing. */ } - if err := store.DownloadModule(path, m, progress, ""); err != nil { + if err := store.DownloadModule(path, m, progress, "", 0, 0); err != nil { t.Errorf("fail to download module: %v", err) } existence(filepath.Join(store.ModulesPath, "0", art.FileName), false, "[archive]", t)