diff --git a/Makefile b/Makefile index 851dd4c47..67a708a1c 100644 --- a/Makefile +++ b/Makefile @@ -135,7 +135,7 @@ tidy: gosec: gosec -exclude-dir updater ./... # exclude the testdata dir; it contains a go program for testing. - cd updater; gosec -exclude-dir internal/install/testdata ./... + cd updater; gosec -exclude-dir internal/service/testdata ./... # This target performs all checks that CI will do (excluding the build itself) .PHONY: ci-checks diff --git a/opamp/observiq/observiq_client.go b/opamp/observiq/observiq_client.go index 33e080a5e..efab8ecdc 100644 --- a/opamp/observiq/observiq_client.go +++ b/opamp/observiq/observiq_client.go @@ -81,7 +81,7 @@ func NewClient(args *NewClientArgs) (opamp.Client, error) { downloadableFileManager: newDownloadableFileManager(clientLogger, args.TmpPath), collector: args.Collector, currentConfig: args.Config, - packagesStateProvider: newPackagesStateProvider(clientLogger, "package_statuses.json"), + packagesStateProvider: newPackagesStateProvider(clientLogger, packagestate.DefaultFileName), } // Parse URL to determin scheme diff --git a/packagestate/packages_state_manager.go b/packagestate/packages_state_manager.go index 287b3f198..757d4d845 100644 --- a/packagestate/packages_state_manager.go +++ b/packagestate/packages_state_manager.go @@ -28,6 +28,9 @@ import ( // CollectorPackageName is the name for the top level packages for this collector const CollectorPackageName = "observiq-otel-collector" +// DefaultFileName is the default name of the file use to store state +const DefaultFileName = "package_statuses.json" + // StateManager tracks Package states type StateManager interface { // LoadStatuses retrieves the previously saved PackagesStatuses. diff --git a/updater/cmd/updater/main.go b/updater/cmd/updater/main.go index cdbcfb512..be21574d9 100644 --- a/updater/cmd/updater/main.go +++ b/updater/cmd/updater/main.go @@ -15,13 +15,21 @@ package main import ( + "context" + "errors" "fmt" "log" "os" + "time" + "github.com/observiq/observiq-otel-collector/packagestate" "github.com/observiq/observiq-otel-collector/updater/internal/install" + "github.com/observiq/observiq-otel-collector/updater/internal/rollback" + "github.com/observiq/observiq-otel-collector/updater/internal/state" "github.com/observiq/observiq-otel-collector/updater/internal/version" + "github.com/open-telemetry/opamp-go/protobufs" "github.com/spf13/pflag" + "go.uber.org/zap" ) // Unimplemented @@ -43,12 +51,58 @@ func main() { os.Exit(1) } + // Create a monitor and load the package status file + // TODO replace nop logger with real one + monitor, err := state.NewCollectorMonitor(zap.NewNop()) + if err != nil { + log.Fatalln("Failed to create monitor:", err) + } + installer, err := install.NewInstaller(*tmpDir) if err != nil { log.Fatalf("Failed to create installer: %s", err) } - if err := installer.Install(); err != nil { - log.Fatalf("Failed to install: %s", err) + rb, err := rollback.NewRollbacker(*tmpDir) + if err != nil { + log.Fatalf("Failed to create rollbacker: %s", err) + } + + if err := rb.Backup(); err != nil { + log.Fatalf("Failed to backup: %s", err) } + + if err := installer.Install(rb); err != nil { + log.Default().Printf("Failed to install: %s", err) + + // Set the state to failed before rollback so collector knows it failed + if setErr := monitor.SetState(packagestate.DefaultFileName, protobufs.PackageStatus_InstallFailed, err); setErr != nil { + log.Println("Failed to set state on install failure:", setErr) + } + rb.Rollback() + log.Default().Fatalf("Rollback complete") + } + + // Create a context with timeout to wait for a success or failed status + checkCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Monitor the install state + if err := monitor.MonitorForSuccess(checkCtx, packagestate.DefaultFileName); err != nil { + log.Println("Failed to install:", err) + + // If this is not an error due to the collector setting a failed status we need to set a failed status + if !errors.Is(err, state.ErrFailedStatus) { + // Set the state to failed before rollback so collector knows it failed + if setErr := monitor.SetState(packagestate.DefaultFileName, protobufs.PackageStatus_InstallFailed, err); setErr != nil { + log.Println("Failed to set state on install failure:", setErr) + } + } + + rb.Rollback() + log.Fatalln("Rollback complete") + } + + // Successful update + log.Println("Update Complete") } diff --git a/updater/go.mod b/updater/go.mod index d113e4346..32c25c7ab 100644 --- a/updater/go.mod +++ b/updater/go.mod @@ -4,14 +4,22 @@ go 1.17 require ( github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 + github.com/observiq/observiq-otel-collector/packagestate v0.0.0-00010101000000-000000000000 + github.com/open-telemetry/opamp-go v0.2.0 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.7.2 - golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211 + github.com/stretchr/testify v1.8.0 + go.uber.org/zap v1.21.0 + golang.org/x/sys v0.0.0-20210510120138-977fb7262007 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.1.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.8.0 // indirect + google.golang.org/protobuf v1.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/observiq/observiq-otel-collector/packagestate => ../packagestate diff --git a/updater/go.sum b/updater/go.sum index 788615ce8..fdb019d02 100644 --- a/updater/go.sum +++ b/updater/go.sum @@ -1,19 +1,81 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/open-telemetry/opamp-go v0.2.0 h1:dV7wTkG5XNiorU62N1CJPr3f5dM0PGEtUUBtvK+LEG0= +github.com/open-telemetry/opamp-go v0.2.0/go.mod h1:IMdeuHGVc5CjKSu5/oNV0o+UmiXuahoHvoZ4GOmAI9M= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= -github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= -golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211 h1:9UQO31fZ+0aKQOFldThf7BKPMJTiBfWycGh/u3UoO88= -golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= +go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/updater/internal/action/action.go b/updater/internal/action/action.go new file mode 100644 index 000000000..6ddbd7845 --- /dev/null +++ b/updater/internal/action/action.go @@ -0,0 +1,21 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +// RollbackableAction is an interface to represents an install action that may be rolled back. +//go:generate mockery --name RollbackableAction --filename rollbackable_action.go +type RollbackableAction interface { + Rollback() error +} diff --git a/updater/internal/action/file_action.go b/updater/internal/action/file_action.go new file mode 100644 index 000000000..d62f3ab81 --- /dev/null +++ b/updater/internal/action/file_action.go @@ -0,0 +1,82 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" +) + +// CopyFileAction is an action that records a file being copied from FromPath to ToPath +type CopyFileAction struct { + // FromPathRel is the path where the file originated, relative to the "latest" + // directory + FromPathRel string + // ToPath is the path where the file was written. + ToPath string + // FileCreated is a bool that records whether this action had to create a new file or not + FileCreated bool + backupDir string + latestDir string +} + +var _ RollbackableAction = (*CopyFileAction)(nil) + +// NewCopyFileAction creates a new CopyFileAction that indicates a file was copied from +// fromPathRel into toPath. tmpDir is specified for rollback purposes. +// NOTE: This action MUST be created BEFORE the action actually takes place; This allows +// for previous existence of the file to be recorded. +func NewCopyFileAction(fromPathRel, toPath, tmpDir string) (*CopyFileAction, error) { + fileExists := true + _, err := os.Stat(toPath) + switch { + case errors.Is(err, os.ErrNotExist): + fileExists = false + case err != nil: + return nil, fmt.Errorf("unexpected error stat-ing file: %w", err) + } + + return &CopyFileAction{ + FromPathRel: fromPathRel, + ToPath: toPath, + // The file will be created if it doesn't already exist + FileCreated: !fileExists, + backupDir: path.BackupDirFromTempDir(tmpDir), + latestDir: path.LatestDirFromTempDir(tmpDir), + }, nil +} + +// Rollback will undo the file copy, by either deleting the file if the file did not originally exist, +// or it will copy the old file in the rollback dir if it already exists. +func (c CopyFileAction) Rollback() error { + if c.FileCreated { + // File did not exist before this action. + // We just need to delete this file. + return os.RemoveAll(c.ToPath) + } + + // join the relative path to the backup directory to get the location of the backup path + backupFilePath := filepath.Join(c.backupDir, c.FromPathRel) + if err := file.CopyFile(backupFilePath, c.ToPath, true); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} diff --git a/updater/internal/action/file_action_test.go b/updater/internal/action/file_action_test.go new file mode 100644 index 000000000..c61a74238 --- /dev/null +++ b/updater/internal/action/file_action_test.go @@ -0,0 +1,157 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewCopyFileAction(t *testing.T) { + t.Run("out file does not exist", func(t *testing.T) { + scratchDir := t.TempDir() + testTempDir := filepath.Join("testdata", "copyfileaction") + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testTempDir, "latest", "test.txt") + + a, err := NewCopyFileAction(inFile, outFile, testTempDir) + require.NoError(t, err) + + require.Equal(t, &CopyFileAction{ + FromPathRel: inFile, + ToPath: outFile, + FileCreated: true, + backupDir: filepath.Join(testTempDir, "rollback"), + latestDir: filepath.Join(testTempDir, "latest"), + }, a) + }) + + t.Run("out file exists", func(t *testing.T) { + scratchDir := t.TempDir() + testTempDir := filepath.Join("testdata", "copyfileaction") + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testTempDir, "latest", "test.txt") + + f, err := os.Create(outFile) + require.NoError(t, err) + require.NoError(t, f.Close()) + + a, err := NewCopyFileAction(inFile, outFile, testTempDir) + require.NoError(t, err) + + require.Equal(t, &CopyFileAction{ + FromPathRel: inFile, + ToPath: outFile, + FileCreated: false, + backupDir: filepath.Join(testTempDir, "rollback"), + latestDir: filepath.Join(testTempDir, "latest"), + }, a) + }) +} + +func TestCopyFileActionRollback(t *testing.T) { + t.Run("deletes out file if it does not exist", func(t *testing.T) { + scratchDir := t.TempDir() + testTempDir := filepath.Join("testdata", "copyfileaction") + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testTempDir, "latest", "test.txt") + + a, err := NewCopyFileAction(inFile, outFile, testTempDir) + require.NoError(t, err) + + inBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, inBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.NoError(t, err) + + require.NoFileExists(t, outFile) + }) + + t.Run("Rolls back out file when it exists", func(t *testing.T) { + scratchDir := t.TempDir() + testTempDir := filepath.Join("testdata", "copyfileaction") + outFile := filepath.Join(scratchDir, "test.txt") + inFileRel := "test.txt" + inFile := filepath.Join(testTempDir, "latest", inFileRel) + originalFile := filepath.Join(testTempDir, "rollback", "test.txt") + + originalBytes, err := os.ReadFile(originalFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, originalBytes, 0600) + require.NoError(t, err) + + a, err := NewCopyFileAction(inFileRel, outFile, testTempDir) + require.NoError(t, err) + + // Overwrite original file with latest file + inBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, inBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.NoError(t, err) + + require.FileExists(t, outFile) + + rolledBackBytes, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, originalBytes, rolledBackBytes) + }) + + t.Run("Fails if backup file doesn't exist", func(t *testing.T) { + scratchDir := t.TempDir() + testTempDir := filepath.Join("testdata", "copyfileaction") + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testTempDir, "latest", "not_in_backup.txt") + originalFile := filepath.Join(testTempDir, "rollback", "test.txt") + + // The latest file exists in the directory already, but for some reason is not copied to backup + originalBytes, err := os.ReadFile(originalFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, originalBytes, 0600) + require.NoError(t, err) + + a, err := NewCopyFileAction(inFile, outFile, testTempDir) + require.NoError(t, err) + + // Overwrite original file with latest file + latestBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, latestBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.ErrorContains(t, err, "failed to copy file") + require.FileExists(t, outFile) + + finalBytes, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, latestBytes, finalBytes) + }) + +} diff --git a/updater/internal/action/mocks/rollbackable_action.go b/updater/internal/action/mocks/rollbackable_action.go new file mode 100644 index 000000000..5ccf66995 --- /dev/null +++ b/updater/internal/action/mocks/rollbackable_action.go @@ -0,0 +1,39 @@ +// Code generated by mockery v2.13.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// RollbackableAction is an autogenerated mock type for the RollbackableAction type +type RollbackableAction struct { + mock.Mock +} + +// Rollback provides a mock function with given fields: +func (_m *RollbackableAction) Rollback() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewRollbackableAction interface { + mock.TestingT + Cleanup(func()) +} + +// NewRollbackableAction creates a new instance of RollbackableAction. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewRollbackableAction(t mockConstructorTestingTNewRollbackableAction) *RollbackableAction { + mock := &RollbackableAction{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/action/service_action.go b/updater/internal/action/service_action.go new file mode 100644 index 000000000..000dab426 --- /dev/null +++ b/updater/internal/action/service_action.go @@ -0,0 +1,80 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/service" +) + +// ServiceStopAction is an action that records that a service was stopped. +type ServiceStopAction struct { + svc service.Service +} + +var _ RollbackableAction = (*ServiceStopAction)(nil) + +// NewServiceStopAction creates a new ServiceStopAction +func NewServiceStopAction(svc service.Service) *ServiceStopAction { + return &ServiceStopAction{ + svc: svc, + } +} + +// Rollback rolls back the stop action (starts the service) +func (s ServiceStopAction) Rollback() error { + return s.svc.Start() +} + +// ServiceStartAction is an action that records that a service was started. +type ServiceStartAction struct { + svc service.Service +} + +var _ RollbackableAction = (*ServiceStartAction)(nil) + +// NewServiceStartAction creates a new ServiceStartAction +func NewServiceStartAction(svc service.Service) *ServiceStartAction { + return &ServiceStartAction{ + svc: svc, + } +} + +// Rollback rolls back the start action (stops the service) +func (s ServiceStartAction) Rollback() error { + return s.svc.Stop() +} + +// ServiceUpdateAction is an action that records that a service was updated. +type ServiceUpdateAction struct { + backupSvc service.Service +} + +var _ RollbackableAction = (*ServiceUpdateAction)(nil) + +// NewServiceUpdateAction creates a new ServiceUpdateAction +func NewServiceUpdateAction(tmpDir string) *ServiceUpdateAction { + return &ServiceUpdateAction{ + backupSvc: service.NewService( + "", // latestDir doesn't matter here + service.WithServiceFile(path.BackupServiceFile(path.ServiceFileDir(path.BackupDirFromTempDir(tmpDir)))), + ), + } +} + +// Rollback is an action that rolls back the service configuration to the one saved in the backup directory. +func (s ServiceUpdateAction) Rollback() error { + return s.backupSvc.Update() +} diff --git a/updater/internal/action/service_action_test.go b/updater/internal/action/service_action_test.go new file mode 100644 index 000000000..42127b427 --- /dev/null +++ b/updater/internal/action/service_action_test.go @@ -0,0 +1,53 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/require" +) + +func TestServiceStartAction(t *testing.T) { + svc := mocks.NewService(t) + ssa := NewServiceStartAction(svc) + + svc.On("Stop").Once().Return(nil) + + err := ssa.Rollback() + require.NoError(t, err) +} + +func TestServiceStopAction(t *testing.T) { + svc := mocks.NewService(t) + ssa := NewServiceStopAction(svc) + + svc.On("Start").Once().Return(nil) + + err := ssa.Rollback() + require.NoError(t, err) +} + +func TestServiceUpdateAction(t *testing.T) { + svc := mocks.NewService(t) + sua := NewServiceUpdateAction("./testdata") + sua.backupSvc = svc + + svc.On("Update").Once().Return(nil) + + err := sua.Rollback() + require.NoError(t, err) +} diff --git a/updater/internal/action/testdata/copyfileaction/latest/not_in_backup.txt b/updater/internal/action/testdata/copyfileaction/latest/not_in_backup.txt new file mode 100644 index 000000000..20f76d643 --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/latest/not_in_backup.txt @@ -0,0 +1 @@ +This file doesn't exist in backup diff --git a/updater/internal/action/testdata/copyfileaction/latest/test.txt b/updater/internal/action/testdata/copyfileaction/latest/test.txt new file mode 100644 index 000000000..6dfa057f0 --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/latest/test.txt @@ -0,0 +1 @@ +This is a new file diff --git a/updater/internal/action/testdata/copyfileaction/rollback/test.txt b/updater/internal/action/testdata/copyfileaction/rollback/test.txt new file mode 100644 index 000000000..684d5588a --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/rollback/test.txt @@ -0,0 +1 @@ +This is the old file diff --git a/updater/internal/file/file.go b/updater/internal/file/file.go new file mode 100644 index 000000000..468be246f --- /dev/null +++ b/updater/internal/file/file.go @@ -0,0 +1,71 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "fmt" + "io" + "log" + "os" + "path/filepath" +) + +// CopyFile copies the file from pathIn to pathOut. +// If the file does not exist, it is created. If the file does exist, it is truncated before writing. +func CopyFile(pathIn, pathOut string, overwrite bool) error { + pathInClean := filepath.Clean(pathIn) + + // Open the input file for reading. + inFile, err := os.Open(pathInClean) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer func() { + err := inFile.Close() + if err != nil { + log.Default().Printf("CopyFile: Failed to close input file: %s", err) + } + }() + + flags := os.O_CREATE | os.O_WRONLY + if overwrite { + // If we are OK to overwrite, we will truncate the file on open + flags |= os.O_TRUNC + } else { + // This flag will make OpenFile error if the file already exists + flags |= os.O_EXCL + } + + pathOutClean := filepath.Clean(pathOut) + // Open the output file, creating it if it does not exist and truncating it. + //#nosec G304 -- out file is cleaned; this is a general purpose copy function + outFile, err := os.OpenFile(pathOutClean, flags, 0600) + if err != nil { + return fmt.Errorf("failed to open output file: %w", err) + } + defer func() { + err := outFile.Close() + if err != nil { + log.Default().Printf("CopyFile: Failed to close output file: %s", err) + } + }() + + // Copy the input file to the output file. + if _, err := io.Copy(outFile, inFile); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} diff --git a/updater/internal/file/file_test.go b/updater/internal/file/file_test.go new file mode 100644 index 000000000..5db7bea0f --- /dev/null +++ b/updater/internal/file/file_test.go @@ -0,0 +1,161 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "io/fs" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCopyFile(t *testing.T) { + t.Run("Copies file when output does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFile(inFile, outFile, true) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, contentsIn, contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0600), fi.Mode()) + } + }) + + t.Run("Copies file when output already exists", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, []byte("This is a file that already exists"), 0640) + require.NoError(t, err) + + err = CopyFile(inFile, outFile, true) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, contentsIn, contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0640), fi.Mode()) + } + }) + + t.Run("Fails when input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFile(inFile, outFile, true) + require.ErrorContains(t, err, "failed to open input file") + require.NoFileExists(t, outFile) + }) + + t.Run("Does not truncate if input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0600) + require.NoError(t, err) + + err = CopyFile(inFile, outFile, true) + require.ErrorContains(t, err, "failed to open input file") + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, []byte("This is a file that already exists"), contentsOut) + }) + + t.Run("Fails to overwrite the output file if 'overwrite' false", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0640) + require.NoError(t, err) + + err = CopyFile(inFile, outFile, false) + require.ErrorContains(t, err, "failed to open output file") + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, []byte("This is a file that already exists"), contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0640), fi.Mode()) + } + }) + + t.Run("Copies file when output does not exist when 'overwrite' is false", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFile(inFile, outFile, false) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, contentsIn, contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0600), fi.Mode()) + } + }) +} diff --git a/updater/internal/file/testdata/test.txt b/updater/internal/file/testdata/test.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/updater/internal/file/testdata/test.txt @@ -0,0 +1 @@ +This is a test file diff --git a/updater/internal/install/install.go b/updater/internal/install/install.go index 2f129e931..18de2900c 100644 --- a/updater/internal/install/install.go +++ b/updater/internal/install/install.go @@ -16,11 +16,15 @@ package install import ( "fmt" - "io" "io/fs" - "log" "os" "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/action" + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/rollback" + "github.com/observiq/observiq-otel-collector/updater/internal/service" ) // Installer allows you to install files from latestDir into installDir, @@ -28,60 +32,60 @@ import ( type Installer struct { latestDir string installDir string - svc Service + tmpDir string + svc service.Service } // NewInstaller returns a new instance of an Installer. -func NewInstaller(tempDir string) (*Installer, error) { - latestDir := filepath.Join(tempDir, "latest") - installDirPath, err := installDir() +func NewInstaller(tmpDir string) (*Installer, error) { + latestDir := path.LatestDirFromTempDir(tmpDir) + installDirPath, err := path.InstallDir() if err != nil { return nil, fmt.Errorf("failed to determine install dir: %w", err) } return &Installer{ latestDir: latestDir, - svc: newService(latestDir), + svc: service.NewService(latestDir), installDir: installDirPath, + tmpDir: tmpDir, }, nil } -// Install installs the unpacked artifacts in latestDirPath to installDirPath, -// as well as installing the new service file using the provided Service interface -func (i Installer) Install() error { +// Install installs the unpacked artifacts in latestDir to installDir, +// as well as installing the new service file using the installer's Service interface +func (i Installer) Install(rb rollback.ActionAppender) error { // Stop service if err := i.svc.Stop(); err != nil { return fmt.Errorf("failed to stop service: %w", err) } + rb.AppendAction(action.NewServiceStopAction(i.svc)) // install files that go to installDirPath to their correct location, // excluding any config files (logging.yaml, config.yaml, manager.yaml) - if err := moveFiles(i.latestDir, i.installDir); err != nil { + if err := copyFiles(i.latestDir, i.installDir, i.tmpDir, rb); err != nil { return fmt.Errorf("failed to install new files: %w", err) } - // Uninstall previous service - if err := i.svc.Uninstall(); err != nil { - return fmt.Errorf("failed to uninstall service: %w", err) - } - - // Install new service - if err := i.svc.Install(); err != nil { - return fmt.Errorf("failed to install service: %w", err) + // Update old service config to new service config + if err := i.svc.Update(); err != nil { + return fmt.Errorf("failed to update service: %w", err) } + rb.AppendAction(action.NewServiceUpdateAction(i.tmpDir)) // Start service if err := i.svc.Start(); err != nil { return fmt.Errorf("failed to start service: %w", err) } + rb.AppendAction(action.NewServiceStartAction(i.svc)) return nil } -// moveFiles moves the file tree rooted at latestDirPath to installDirPath, -// skipping configuration files -func moveFiles(latestDirPath, installDirPath string) error { - err := filepath.WalkDir(latestDirPath, func(path string, d fs.DirEntry, err error) error { +// copyFiles moves the file tree rooted at latestDirPath to installDirPath, +// skipping configuration files. Appends CopyFileAction-s to the Rollbacker as it copies file. +func copyFiles(inputPath, outputPath, tmpDir string, rb rollback.ActionAppender) error { + err := filepath.WalkDir(inputPath, func(inPath string, d fs.DirEntry, err error) error { switch { case err != nil: // if there was an error walking the directory, we want to bail out. @@ -89,55 +93,40 @@ func moveFiles(latestDirPath, installDirPath string) error { case d.IsDir(): // Skip directories, we'll create them when we get a file in the directory. return nil - case skipFile(path): + case skipConfigFiles(inPath): // Found a config file that we should skip copying. return nil } - cleanPath := filepath.Clean(path) - // We want the path relative to the directory we are walking in order to calculate where the file should be // mirrored in the destination directory. - relPath, err := filepath.Rel(latestDirPath, cleanPath) + relPath, err := filepath.Rel(inputPath, inPath) if err != nil { return err } // use the relative path to get the outPath (where we should write the file), and // to get the out directory (which we will create if it does not exist). - outPath := filepath.Clean(filepath.Join(installDirPath, relPath)) + outPath := filepath.Join(outputPath, relPath) outDir := filepath.Dir(outPath) if err := os.MkdirAll(outDir, 0750); err != nil { return fmt.Errorf("failed to create dir: %w", err) } - // Open the output file, creating it if it does not exist and truncating it. - outFile, err := os.OpenFile(outPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + // We create the action record here, because we want to record whether the file exists or not before + // we open the file (which will end up creating the file). + cfa, err := action.NewCopyFileAction(relPath, outPath, tmpDir) if err != nil { - return fmt.Errorf("failed to open output file: %w", err) + return fmt.Errorf("failed to create copy file action: %w", err) } - defer func() { - err := outFile.Close() - if err != nil { - log.Default().Printf("installFiles: Failed to close output file: %s", err) - } - }() - - // Open the input file for reading. - inFile, err := os.Open(cleanPath) - if err != nil { - return fmt.Errorf("failed to open input file: %w", err) - } - defer func() { - err := inFile.Close() - if err != nil { - log.Default().Printf("installFiles: Failed to close input file: %s", err) - } - }() - - // Copy the input file to the output file. - if _, err := io.Copy(outFile, inFile); err != nil { + + // Record that we are performing copying the file. + // We record before we actually do the action here because the file may be partially written, + // and we will want to roll that back if that is the case. + rb.AppendAction(cfa) + + if err := file.CopyFile(inPath, outPath, true); err != nil { return fmt.Errorf("failed to copy file: %w", err) } @@ -151,9 +140,9 @@ func moveFiles(latestDirPath, installDirPath string) error { return nil } -// skipFile returns true if the given path is a special config file. +// skipConfigFiles returns true if the given path is a special config file. // These files should not be overwritten. -func skipFile(path string) bool { +func skipConfigFiles(path string) bool { var configFiles = []string{ "config.yaml", "logging.yaml", diff --git a/updater/internal/install/install_test.go b/updater/internal/install/install_test.go index b13a43c57..bf9b0e844 100644 --- a/updater/internal/install/install_test.go +++ b/updater/internal/install/install_test.go @@ -21,7 +21,10 @@ import ( "path/filepath" "testing" - "github.com/observiq/observiq-otel-collector/updater/internal/install/mocks" + "github.com/observiq/observiq-otel-collector/updater/internal/action" + rb_mocks "github.com/observiq/observiq-otel-collector/updater/internal/rollback/mocks" + "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -29,6 +32,8 @@ func TestInstallArtifacts(t *testing.T) { t.Run("Installs artifacts correctly", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) + installer := &Installer{ latestDir: filepath.Join("testdata", "example-install"), installDir: outDir, @@ -47,11 +52,16 @@ func TestInstallArtifacts(t *testing.T) { require.NoError(t, err) svc.On("Stop").Once().Return(nil) - svc.On("Uninstall").Once().Return(nil) - svc.On("Install").Once().Return(nil) + svc.On("Update").Once().Return(nil) svc.On("Start").Once().Return(nil) - err = installer.Install() + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) require.NoError(t, err) contentsEqual(t, outDirConfig, "# The original config file") @@ -64,11 +74,36 @@ func TestInstallArtifacts(t *testing.T) { contentsEqual(t, filepath.Join(outDir, "test.txt"), "This is a test file\n") contentsEqual(t, filepath.Join(outDir, "test-folder", "another-test.txt"), "This is a nested text file\n") + + copyTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + require.Equal(t, []action.RollbackableAction{ + action.NewServiceStopAction(svc), + copyNestedTestTxtAction, + copyTestTxtAction, + action.NewServiceUpdateAction(installer.tmpDir), + action.NewServiceStartAction(svc), + }, actions) }) t.Run("Stop fails", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) installer := &Installer{ latestDir: filepath.Join("testdata", "example-install"), installDir: outDir, @@ -77,13 +112,14 @@ func TestInstallArtifacts(t *testing.T) { svc.On("Stop").Once().Return(errors.New("stop failed")) - err := installer.Install() + err := installer.Install(rb) require.ErrorContains(t, err, "failed to stop service") }) - t.Run("Uninstall fails", func(t *testing.T) { + t.Run("Update fails", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) installer := &Installer{ latestDir: filepath.Join("testdata", "example-install"), installDir: outDir, @@ -91,32 +127,43 @@ func TestInstallArtifacts(t *testing.T) { } svc.On("Stop").Once().Return(nil) - svc.On("Uninstall").Once().Return(errors.New("uninstall failed")) - - err := installer.Install() - require.ErrorContains(t, err, "failed to uninstall service") - }) - - t.Run("Install fails", func(t *testing.T) { - outDir := t.TempDir() - svc := mocks.NewService(t) - installer := &Installer{ - latestDir: filepath.Join("testdata", "example-install"), - installDir: outDir, - svc: svc, - } + svc.On("Update").Once().Return(errors.New("uninstall failed")) + + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err := installer.Install(rb) + require.ErrorContains(t, err, "failed to update service") + copyTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true - svc.On("Stop").Once().Return(nil) - svc.On("Uninstall").Once().Return(nil) - svc.On("Install").Once().Return(errors.New("install failed")) + copyNestedTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true - err := installer.Install() - require.ErrorContains(t, err, "failed to install service") + require.Equal(t, []action.RollbackableAction{ + action.NewServiceStopAction(svc), + copyNestedTestTxtAction, + copyTestTxtAction, + }, actions) }) t.Run("Start fails", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) installer := &Installer{ latestDir: filepath.Join("testdata", "example-install"), installDir: outDir, @@ -124,17 +171,46 @@ func TestInstallArtifacts(t *testing.T) { } svc.On("Stop").Once().Return(nil) - svc.On("Uninstall").Once().Return(nil) - svc.On("Install").Once().Return(nil) + svc.On("Update").Once().Return(nil) svc.On("Start").Once().Return(errors.New("start failed")) - err := installer.Install() + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err := installer.Install(rb) require.ErrorContains(t, err, "failed to start service") + + copyTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + require.Equal(t, []action.RollbackableAction{ + action.NewServiceStopAction(svc), + copyNestedTestTxtAction, + copyTestTxtAction, + action.NewServiceUpdateAction(installer.tmpDir), + }, actions) }) t.Run("Latest dir does not exist", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) installer := &Installer{ latestDir: filepath.Join("testdata", "non-existent-dir"), installDir: outDir, @@ -143,13 +219,24 @@ func TestInstallArtifacts(t *testing.T) { svc.On("Stop").Once().Return(nil) - err := installer.Install() + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err := installer.Install(rb) require.ErrorContains(t, err, "failed to install new files") + + require.Equal(t, []action.RollbackableAction{ + action.NewServiceStopAction(svc), + }, actions) }) t.Run("An artifact exists already as a folder", func(t *testing.T) { outDir := t.TempDir() svc := mocks.NewService(t) + rb := rb_mocks.NewActionAppender(t) installer := &Installer{ latestDir: filepath.Join("testdata", "example-install"), installDir: outDir, @@ -172,8 +259,39 @@ func TestInstallArtifacts(t *testing.T) { svc.On("Stop").Once().Return(nil) - err = installer.Install() + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) require.ErrorContains(t, err, "failed to install new files") + t.Logf("Error: %s", err) + + copyTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = false + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.tmpDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + require.Equal(t, []action.RollbackableAction{ + action.NewServiceStopAction(svc), + copyNestedTestTxtAction, + // copyTestTxtAction is appended even though it failed; This is because we don't know WHY it failed, so we should keep it and try a rollback anyways, + // in case it was actually a partial write. + copyTestTxtAction, + }, actions) }) } diff --git a/updater/internal/path/path.go b/updater/internal/path/path.go new file mode 100644 index 000000000..fdd8be108 --- /dev/null +++ b/updater/internal/path/path.go @@ -0,0 +1,46 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import "path/filepath" + +const ( + latestDirFragment = "latest" + rollbackDirFragment = "rollback" + serviceFileDirFragment = "install" + serviceFileBackupFilename = "backup.service" +) + +// LatestDirFromTempDir gets the path to the "latest" dir, where the new artifacts are, +// from the temporary directory +func LatestDirFromTempDir(tmpDir string) string { + return filepath.Join(tmpDir, latestDirFragment) +} + +// BackupDirFromTempDir gets the path to the "rollback" dir, where current artifacts are backed up, +// from the temporary directory +func BackupDirFromTempDir(tmpDir string) string { + return filepath.Join(tmpDir, rollbackDirFragment) +} + +// ServiceFileDir gets the directory of the service file definitions from the install dir +func ServiceFileDir(installBaseDir string) string { + return filepath.Join(installBaseDir, serviceFileDirFragment) +} + +// BackupServiceFile returns the full path to the backup service file from the service file directory path +func BackupServiceFile(serviceFileDir string) string { + return filepath.Join(serviceFileDir, serviceFileBackupFilename) +} diff --git a/updater/internal/path/path_darwin.go b/updater/internal/path/path_darwin.go new file mode 100644 index 000000000..566791e7b --- /dev/null +++ b/updater/internal/path/path_darwin.go @@ -0,0 +1,23 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +// DarwinInstallDir is the path to the install directory on Darwin. +const DarwinInstallDir = "/opt/observiq-otel-collector" + +// InstallDir returns the filepath to the install directory +func InstallDir() (string, error) { + return DarwinInstallDir, nil +} diff --git a/updater/internal/path/path_linux.go b/updater/internal/path/path_linux.go new file mode 100644 index 000000000..8f07d49df --- /dev/null +++ b/updater/internal/path/path_linux.go @@ -0,0 +1,23 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +// LinuxInstallDir is the install directory of the collector on linux. +const LinuxInstallDir = "/opt/observiq-otel-collector" + +// InstallDir returns the filepath to the install directory +func InstallDir() (string, error) { + return LinuxInstallDir, nil +} diff --git a/updater/internal/path/path_test.go b/updater/internal/path/path_test.go new file mode 100644 index 000000000..2a9b9032f --- /dev/null +++ b/updater/internal/path/path_test.go @@ -0,0 +1,40 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLatestDirFromTempDir(t *testing.T) { + require.Equal(t, filepath.Join("tmp", "latest"), LatestDirFromTempDir("tmp")) +} + +func TestBackupDirFromTempDir(t *testing.T) { + require.Equal(t, filepath.Join("tmp", "rollback"), BackupDirFromTempDir("tmp")) +} + +func TestServiceFileDir(t *testing.T) { + installDir := filepath.Join("tmp", "rollback") + require.Equal(t, filepath.Join(installDir, "install"), ServiceFileDir(installDir)) +} + +func TestBackupServiceFile(t *testing.T) { + serviceFileDir := filepath.Join("tmp", "rollback", "install") + require.Equal(t, filepath.Join(serviceFileDir, "backup.service"), BackupServiceFile(serviceFileDir)) +} diff --git a/updater/internal/path/path_windows.go b/updater/internal/path/path_windows.go new file mode 100644 index 000000000..44fca1e81 --- /dev/null +++ b/updater/internal/path/path_windows.go @@ -0,0 +1,53 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import ( + "fmt" + "log" + + "golang.org/x/sys/windows/registry" +) + +const defaultProductName = "observIQ Distro for OpenTelemetry Collector" + +// InstallDirFromRegistry gets the installation dir of the given product from the Windows Registry +func InstallDirFromRegistry(productName string) (string, error) { + // this key is created when installing using the MSI installer + keyPath := fmt.Sprintf(`Software\Microsoft\Windows\CurrentVersion\Uninstall\%s`, productName) + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ) + if err != nil { + return "", fmt.Errorf("failed to open registry key: %w", err) + } + defer func() { + err := key.Close() + if err != nil { + log.Default().Printf("InstallDirFromRegistry: failed to close registry key") + } + }() + + // This value ("InstallLocation") contains the path to the install folder. + val, _, err := key.GetStringValue("InstallLocation") + if err != nil { + return "", fmt.Errorf("failed to read install dir: %w", err) + } + + return val, nil +} + +// InstallDir returns the filepath to the install directory +func InstallDir() (string, error) { + return InstallDirFromRegistry(defaultProductName) +} diff --git a/updater/internal/rollback/mocks/action_appender.go b/updater/internal/rollback/mocks/action_appender.go new file mode 100644 index 000000000..bc6226cef --- /dev/null +++ b/updater/internal/rollback/mocks/action_appender.go @@ -0,0 +1,33 @@ +// Code generated by mockery v2.13.1. DO NOT EDIT. + +package mocks + +import ( + action "github.com/observiq/observiq-otel-collector/updater/internal/action" + mock "github.com/stretchr/testify/mock" +) + +// ActionAppender is an autogenerated mock type for the ActionAppender type +type ActionAppender struct { + mock.Mock +} + +// AppendAction provides a mock function with given fields: _a0 +func (_m *ActionAppender) AppendAction(_a0 action.RollbackableAction) { + _m.Called(_a0) +} + +type mockConstructorTestingTNewActionAppender interface { + mock.TestingT + Cleanup(func()) +} + +// NewActionAppender creates a new instance of ActionAppender. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewActionAppender(t mockConstructorTestingTNewActionAppender) *ActionAppender { + mock := &ActionAppender{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/rollback/rollback.go b/updater/internal/rollback/rollback.go new file mode 100644 index 000000000..e94220dd8 --- /dev/null +++ b/updater/internal/rollback/rollback.go @@ -0,0 +1,155 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rollback + +import ( + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + + "github.com/observiq/observiq-otel-collector/updater/internal/action" + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/service" +) + +// ActionAppender is an interface that allows actions to be appended to it. +//go:generate mockery --name ActionAppender --filename action_appender.go +type ActionAppender interface { + AppendAction(action action.RollbackableAction) +} + +// Rollbacker is a struct that records rollback information, +// and can use that information to perform a rollback. +type Rollbacker struct { + originalSvc service.Service + backupDir string + installDir string + tmpDir string + actions []action.RollbackableAction +} + +// NewRollbacker returns a new Rollbacker +func NewRollbacker(tmpDir string) (*Rollbacker, error) { + installDir, err := path.InstallDir() + if err != nil { + return nil, fmt.Errorf("failed to determine install dir: %w", err) + } + + return &Rollbacker{ + backupDir: path.BackupDirFromTempDir(tmpDir), + installDir: installDir, + tmpDir: tmpDir, + originalSvc: service.NewService(path.LatestDirFromTempDir(tmpDir)), + }, nil +} + +// AppendAction records the action that was performed, so that it may be undone later. +func (r *Rollbacker) AppendAction(action action.RollbackableAction) { + r.actions = append(r.actions, action) +} + +// Backup backs up the installDir to the rollbackDir +func (r Rollbacker) Backup() error { + // Remove any pre-existing backup + if err := os.RemoveAll(r.backupDir); err != nil { + return fmt.Errorf("failed to remove previous backup: %w", err) + } + + // Copy all the files in the install directory to the backup directory + if err := copyFiles(r.installDir, r.backupDir, r.tmpDir); err != nil { + return fmt.Errorf("failed to copy files to backup dir: %w", err) + } + + // Backup the service configuration so we can reload it in case of rollback + if err := r.originalSvc.Backup(path.ServiceFileDir(r.backupDir)); err != nil { + return fmt.Errorf("failed to backup service configuration: %w", err) + } + + return nil +} + +// Rollback performs a rollback by undoing all recorded actions. +func (r Rollbacker) Rollback() { + // We need to loop through the actions slice backwards, to roll back the actions in the correct order. + // e.g. if StartService was called last, we need to stop the service first, then rollback previous actions. + for i := len(r.actions) - 1; i >= 0; i-- { + action := r.actions[i] + if err := action.Rollback(); err != nil { + log.Default().Printf("Failed to run rollback option: %s", err) + } + } +} + +// copyFiles copies files from inputPath to output path, skipping tmpDir. +func copyFiles(inputPath, outputPath, tmpDir string) error { + absTmpDir, err := filepath.Abs(tmpDir) + if err != nil { + return fmt.Errorf("failed to get absolute path for temporary directory: %w", err) + } + + err = filepath.WalkDir(inputPath, func(inPath string, d fs.DirEntry, err error) error { + + fullPath, absErr := filepath.Abs(inPath) + if absErr != nil { + return fmt.Errorf("failed to determine absolute path of file: %w", err) + } + + switch { + case err != nil: + // if there was an error walking the directory, we want to bail out. + return err + case d.IsDir() && strings.HasPrefix(fullPath, absTmpDir): + // If this is the "tmp" directory, we want to skip copying this directory, + // since this folder is only for temporary files (and is where this binary is running right now) + return filepath.SkipDir + case d.IsDir(): + // Skip directories, we'll create them when we get a file in the directory. + return nil + } + + // We want the path relative to the directory we are walking in order to calculate where the file should be + // mirrored in the output directory. + relPath, err := filepath.Rel(inputPath, inPath) + if err != nil { + return err + } + + // use the relative path to get the outPath (where we should write the file), and + // to get the out directory (which we will create if it does not exist). + outPath := filepath.Join(outputPath, relPath) + outDir := filepath.Dir(outPath) + + if err := os.MkdirAll(outDir, 0750); err != nil { + return fmt.Errorf("failed to create dir: %w", err) + } + + // Fail if copying the input file to the output file would fail + if err := file.CopyFile(inPath, outPath, false); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil + }) + + if err != nil { + return fmt.Errorf("failed to walk latest dir: %w", err) + } + + return nil +} diff --git a/updater/internal/rollback/rollback_test.go b/updater/internal/rollback/rollback_test.go new file mode 100644 index 000000000..0179b673e --- /dev/null +++ b/updater/internal/rollback/rollback_test.go @@ -0,0 +1,153 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rollback + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "testing" + + action_mocks "github.com/observiq/observiq-otel-collector/updater/internal/action/mocks" + service_mocks "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestRollbackerBackup(t *testing.T) { + t.Run("Successfully backs up everything", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + + svc := service_mocks.NewService(t) + svc.On("Backup", filepath.Join(outDir, "install")).Return(nil) + + rb := &Rollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + tmpDir: filepath.Join(installDir, "tmp-dir"), + } + + err := rb.Backup() + require.NoError(t, err) + + require.FileExists(t, filepath.Join(outDir, "some-file.txt")) + require.FileExists(t, filepath.Join(outDir, "plugins-dir", "plugin.txt")) + require.NoDirExists(t, filepath.Join(outDir, "tmp-dir")) + }) + + t.Run("Service backup fails", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + + svc := service_mocks.NewService(t) + svc.On("Backup", filepath.Join(outDir, "install")).Return(fmt.Errorf("invalid permissions")) + + rb := &Rollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + tmpDir: filepath.Join(installDir, "tmp-dir"), + } + + err := rb.Backup() + require.ErrorContains(t, err, "failed to backup service configuration") + }) + + t.Run("Removes pre-existing backup", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + leftoverFile := filepath.Join(outDir, "leftover-file.txt") + + svc := service_mocks.NewService(t) + svc.On("Backup", filepath.Join(outDir, "install")).Return(nil) + + err := os.WriteFile(leftoverFile, []byte("leftover file"), 0600) + require.NoError(t, err) + + rb := &Rollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + tmpDir: filepath.Join(installDir, "tmp-dir"), + } + + err = rb.Backup() + require.NoError(t, err) + + require.FileExists(t, filepath.Join(outDir, "some-file.txt")) + require.FileExists(t, filepath.Join(outDir, "plugins-dir", "plugin.txt")) + require.NoDirExists(t, filepath.Join(outDir, "tmp-dir")) + require.NoFileExists(t, leftoverFile) + }) +} + +func TestRollbackerRollback(t *testing.T) { + t.Run("Runs rollback actions in the correct order", func(t *testing.T) { + seq := 0 + + rb := &Rollbacker{} + + for i := 0; i < 10; i++ { + actionNum := i + action := action_mocks.NewRollbackableAction(t) + action.On("Rollback").Run(func(args mock.Arguments) { + // Rollback should be done in reverse order; So action 0 + // should be done last (10th action, seq == 9), while + // the last action (action 9) should be done first (seq == 0) + expectedSeq := 10 - actionNum - 1 + assert.Equal(t, expectedSeq, seq, "Expected action %d to occur at sequence %d", seq, expectedSeq) + seq++ + }).Return(nil) + + rb.AppendAction(action) + } + + rb.Rollback() + }) + + t.Run("Continues despite rollback errors", func(t *testing.T) { + seq := 0 + + rb := &Rollbacker{} + + for i := 0; i < 10; i++ { + actionNum := i + action := action_mocks.NewRollbackableAction(t) + + call := action.On("Rollback").Run(func(args mock.Arguments) { + // Rollback should be done in reverse order; So action 0 + // should be done last (10th action, seq == 9), while + // the last action (action 9) should be done first (seq == 0) + expectedSeq := 10 - actionNum - 1 + assert.Equal(t, expectedSeq, seq, "Expected action %d to occur at sequence %d", seq, expectedSeq) + seq++ + }) + + if actionNum == 5 { + call.Return(errors.New("failed to rollback")) + } else { + call.Return(nil) + } + + rb.AppendAction(action) + } + + rb.Rollback() + }) +} diff --git a/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt b/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt new file mode 100644 index 000000000..c47f0348d --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt @@ -0,0 +1 @@ +This is a test file for copying diff --git a/updater/internal/rollback/testdata/rollbacker/some-file.txt b/updater/internal/rollback/testdata/rollbacker/some-file.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/some-file.txt @@ -0,0 +1 @@ +This is a test file diff --git a/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt b/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt new file mode 100644 index 000000000..f594928cb --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt @@ -0,0 +1 @@ +This file should not be copied, because it is in the tmp-dir diff --git a/updater/internal/install/mocks/service.go b/updater/internal/service/mocks/service.go similarity index 81% rename from updater/internal/install/mocks/service.go rename to updater/internal/service/mocks/service.go index 001459bce..c00ed0921 100644 --- a/updater/internal/install/mocks/service.go +++ b/updater/internal/service/mocks/service.go @@ -9,13 +9,13 @@ type Service struct { mock.Mock } -// Install provides a mock function with given fields: -func (_m *Service) Install() error { - ret := _m.Called() +// Backup provides a mock function with given fields: outDir +func (_m *Service) Backup(outDir string) error { + ret := _m.Called(outDir) var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(outDir) } else { r0 = ret.Error(0) } @@ -51,8 +51,8 @@ func (_m *Service) Stop() error { return r0 } -// Uninstall provides a mock function with given fields: -func (_m *Service) Uninstall() error { +// Update provides a mock function with given fields: +func (_m *Service) Update() error { ret := _m.Called() var r0 error diff --git a/updater/internal/install/service.go b/updater/internal/service/service.go similarity index 82% rename from updater/internal/install/service.go rename to updater/internal/service/service.go index a9868fbfa..ae4a1e491 100644 --- a/updater/internal/install/service.go +++ b/updater/internal/service/service.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package install +package service import ( "bytes" @@ -20,6 +20,7 @@ import ( "path/filepath" ) +//go:generate mockery --name Service --filename service.go // Service represents a controllable service type Service interface { // Start the service @@ -28,11 +29,11 @@ type Service interface { // Stop the service Stop() error - // Installs the service - Install() error + // Updates the old service configuration to the new one + Update() error - // Uninstalls the service - Uninstall() error + // Backup backs the current service configuration to the given directory + Backup(outDir string) error } // replaceInstallDir replaces "[INSTALLDIR]" with the given installDir string. diff --git a/updater/internal/install/service_darwin.go b/updater/internal/service/service_darwin.go similarity index 67% rename from updater/internal/install/service_darwin.go rename to updater/internal/service/service_darwin.go index c8537ebf0..4883142fc 100644 --- a/updater/internal/install/service_darwin.go +++ b/updater/internal/service/service_darwin.go @@ -14,27 +14,45 @@ //go:build darwin -package install +package service import ( "fmt" "os" "os/exec" "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" ) const ( darwinServiceFilePath = "/Library/LaunchDaemons/com.observiq.collector.plist" - darwinInstallDir = "/opt/observiq-otel-collector" ) -// newService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. -func newService(latestPath string) Service { - return &darwinService{ - newServiceFilePath: filepath.Join(latestPath, "install", "com.observiq.collector.plist"), +// Option is an extra option for creating a Service +type Option func(darwinSvc *darwinService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(darwinSvc *darwinService) { + darwinSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(latestPath string, opts ...Option) Service { + darwinSvc := &darwinService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(latestPath), "com.observiq.collector.plist"), installedServiceFilePath: darwinServiceFilePath, - installDir: darwinInstallDir, + installDir: path.DarwinInstallDir, } + + for _, opt := range opts { + opt(darwinSvc) + } + + return darwinSvc } type darwinService struct { @@ -79,7 +97,7 @@ func (d darwinService) Stop() error { } // Installs the service -func (d darwinService) Install() error { +func (d darwinService) install() error { serviceFileBytes, err := os.ReadFile(d.newServiceFilePath) if err != nil { return fmt.Errorf("failed to open input file: %w", err) @@ -94,8 +112,7 @@ func (d darwinService) Install() error { } // Uninstalls the service -func (d darwinService) Uninstall() error { - //#nosec G204 -- installedServiceFilePath is not determined by user input +func (d darwinService) uninstall() error { if err := d.Stop(); err != nil { return err } @@ -107,7 +124,22 @@ func (d darwinService) Uninstall() error { return nil } -// InstallDir returns the filepath to the install directory -func installDir() (string, error) { - return darwinInstallDir, nil +func (d darwinService) Update() error { + if err := d.uninstall(); err != nil { + return fmt.Errorf("failed to uninstall old service: %w", err) + } + + if err := d.install(); err != nil { + return fmt.Errorf("failed to install new service: %w", err) + } + + return nil +} + +func (d darwinService) Backup(outDir string) error { + if err := file.CopyFile(d.installedServiceFilePath, path.BackupServiceFile(outDir), false); err != nil { + return fmt.Errorf("failed to copy service file: %w", err) + } + + return nil } diff --git a/updater/internal/install/service_darwin_test.go b/updater/internal/service/service_darwin_test.go similarity index 68% rename from updater/internal/install/service_darwin_test.go rename to updater/internal/service/service_darwin_test.go index 325e41881..10a38fb84 100644 --- a/updater/internal/install/service_darwin_test.go +++ b/updater/internal/service/service_darwin_test.go @@ -14,7 +14,7 @@ //go:build darwin && integration -package install +package service import ( "os" @@ -23,6 +23,7 @@ import ( "regexp" "testing" + "github.com/observiq/observiq-otel-collector/updater/internal/path" "github.com/stretchr/testify/require" ) @@ -37,14 +38,14 @@ func TestDarwinServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := d.Install() + err := d.install() require.NoError(t, err) require.FileExists(t, installedServicePath) // We want to check that the service was actually loaded requireServiceLoadedStatus(t, true) - err = d.Uninstall() + err = d.uninstall() require.NoError(t, err) require.NoFileExists(t, installedServicePath) @@ -63,7 +64,7 @@ func TestDarwinServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := d.Install() + err := d.install() require.NoError(t, err) require.FileExists(t, installedServicePath) @@ -80,7 +81,7 @@ func TestDarwinServiceInstall(t *testing.T) { requireServiceLoadedStatus(t, false) - err = d.Uninstall() + err = d.uninstall() require.NoError(t, err) require.NoFileExists(t, installedServicePath) @@ -98,7 +99,7 @@ func TestDarwinServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := d.Install() + err := d.install() require.ErrorContains(t, err, "failed to open input file") requireServiceLoadedStatus(t, false) }) @@ -113,7 +114,7 @@ func TestDarwinServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := d.Install() + err := d.install() require.ErrorContains(t, err, "failed to write service file") requireServiceLoadedStatus(t, false) }) @@ -128,7 +129,7 @@ func TestDarwinServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := d.Uninstall() + err := d.uninstall() require.ErrorContains(t, err, "failed to stat installed service file") requireServiceLoadedStatus(t, false) }) @@ -160,6 +161,87 @@ func TestDarwinServiceInstall(t *testing.T) { err := d.Stop() require.ErrorContains(t, err, "failed to stat installed service file") }) + + t.Run("Backup installed service succeeds", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + serviceFileContents, err := os.ReadFile(newServiceFile) + require.NoError(t, err) + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + } + + err = d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + require.NoError(t, d.Stop()) + + backupServiceDir := t.TempDir() + err = d.Backup(backupServiceDir) + require.NoError(t, err) + require.FileExists(t, path.BackupServiceFile(backupServiceDir)) + + backupServiceContents, err := os.ReadFile(path.BackupServiceFile(backupServiceDir)) + + require.Equal(t, serviceFileContents, backupServiceContents) + require.NoError(t, d.uninstall()) + }) + + t.Run("Backup installed service fails if not installed", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + } + + backupServiceDir := t.TempDir() + err := d.Backup(backupServiceDir) + require.ErrorContains(t, err, "failed to copy service file") + }) + + t.Run("Backup installed service fails if output file already exists", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + require.NoError(t, d.Stop()) + + backupServiceDir := t.TempDir() + // Write the backup file before creating it; Backup should + // not ever overwrite an existing file + os.WriteFile(path.BackupServiceFile(backupServiceDir), []byte("file exists"), 0600) + + err = d.Backup(backupServiceDir) + require.ErrorContains(t, err, "failed to copy service file") + }) } // uninstallService is a helper that uninstalls the service manually for test setup, in case it is somehow leftover. diff --git a/updater/internal/install/service_linux.go b/updater/internal/service/service_linux.go similarity index 70% rename from updater/internal/install/service_linux.go rename to updater/internal/service/service_linux.go index f96d0386f..5dd23f652 100644 --- a/updater/internal/install/service_linux.go +++ b/updater/internal/service/service_linux.go @@ -14,7 +14,7 @@ //go:build linux -package install +package service import ( "fmt" @@ -23,18 +23,37 @@ import ( "os" "os/exec" "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" ) const linuxServiceName = "observiq-otel-collector" const linuxServiceFilePath = "/usr/lib/systemd/system/observiq-otel-collector.service" -// newService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. -func newService(latestPath string) Service { - return linuxService{ - newServiceFilePath: filepath.Join(latestPath, "install", "observiq-otel-collector.service"), +// Option is an extra option for creating a Service +type Option func(linuxSvc *linuxService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(linuxSvc *linuxService) { + linuxSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(latestPath string, opts ...Option) Service { + linuxSvc := &linuxService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(latestPath), "observiq-otel-collector.service"), serviceName: linuxServiceName, installedServiceFilePath: linuxServiceFilePath, } + + for _, opt := range opts { + opt(linuxSvc) + } + + return linuxSvc } type linuxService struct { @@ -66,8 +85,8 @@ func (l linuxService) Stop() error { return nil } -// Installs the service -func (l linuxService) Install() error { +// installs the service +func (l linuxService) install() error { inFile, err := os.Open(l.newServiceFilePath) if err != nil { return fmt.Errorf("failed to open input file: %w", err) @@ -108,8 +127,8 @@ func (l linuxService) Install() error { return nil } -// Uninstalls the service -func (l linuxService) Uninstall() error { +// uninstalls the service +func (l linuxService) uninstall() error { //#nosec G204 -- serviceName is not determined by user input cmd := exec.Command("systemctl", "disable", l.serviceName) if err := cmd.Run(); err != nil { @@ -128,7 +147,22 @@ func (l linuxService) Uninstall() error { return nil } -// installDir returns the filepath to the install directory -func installDir() (string, error) { - return "/opt/observiq-otel-collector", nil +func (l linuxService) Update() error { + if err := l.uninstall(); err != nil { + return fmt.Errorf("failed to uninstall old service: %w", err) + } + + if err := l.install(); err != nil { + return fmt.Errorf("failed to install new service: %w", err) + } + + return nil +} + +func (l linuxService) Backup(outDir string) error { + if err := file.CopyFile(l.installedServiceFilePath, path.BackupServiceFile(outDir), false); err != nil { + return fmt.Errorf("failed to copy service file: %w", err) + } + + return nil } diff --git a/updater/internal/install/service_linux_test.go b/updater/internal/service/service_linux_test.go similarity index 70% rename from updater/internal/install/service_linux_test.go rename to updater/internal/service/service_linux_test.go index 99ef1f744..2ac5b69bb 100644 --- a/updater/internal/install/service_linux_test.go +++ b/updater/internal/service/service_linux_test.go @@ -15,7 +15,7 @@ // an elevated user is needed to run the service tests //go:build linux && integration -package install +package service import ( "os" @@ -23,6 +23,7 @@ import ( "path/filepath" "testing" + "github.com/observiq/observiq-otel-collector/updater/internal/path" "github.com/stretchr/testify/require" ) @@ -38,14 +39,14 @@ func TestLinuxServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := l.Install() + err := l.install() require.NoError(t, err) require.FileExists(t, installedServicePath) //We want to check that the service was actually loaded requireServiceLoadedStatus(t, true) - err = l.Uninstall() + err = l.uninstall() require.NoError(t, err) require.NoFileExists(t, installedServicePath) @@ -63,7 +64,7 @@ func TestLinuxServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := l.Install() + err := l.install() require.NoError(t, err) require.FileExists(t, installedServicePath) @@ -80,7 +81,7 @@ func TestLinuxServiceInstall(t *testing.T) { requireServiceRunningStatus(t, false) - err = l.Uninstall() + err = l.uninstall() require.NoError(t, err) require.NoFileExists(t, installedServicePath) @@ -98,7 +99,7 @@ func TestLinuxServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := l.Install() + err := l.install() require.ErrorContains(t, err, "failed to open input file") requireServiceLoadedStatus(t, false) }) @@ -113,7 +114,7 @@ func TestLinuxServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := l.Install() + err := l.install() require.ErrorContains(t, err, "failed to open output file") requireServiceLoadedStatus(t, false) }) @@ -128,7 +129,7 @@ func TestLinuxServiceInstall(t *testing.T) { installedServiceFilePath: installedServicePath, } - err := l.Uninstall() + err := l.uninstall() require.ErrorContains(t, err, "failed to disable unit") requireServiceLoadedStatus(t, false) }) @@ -160,6 +161,83 @@ func TestLinuxServiceInstall(t *testing.T) { err := l.Stop() require.ErrorContains(t, err, "running systemctl failed") }) + + t.Run("Backup installed service succeeds", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + serviceFileContents, err := os.ReadFile(newServiceFile) + require.NoError(t, err) + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + } + + err = d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + backupServiceDir := t.TempDir() + err = d.Backup(backupServiceDir) + require.NoError(t, err) + require.FileExists(t, path.BackupServiceFile(backupServiceDir)) + + backupServiceContents, err := os.ReadFile(path.BackupServiceFile(backupServiceDir)) + + require.Equal(t, serviceFileContents, backupServiceContents) + require.NoError(t, d.uninstall()) + }) + + t.Run("Backup installed service fails if not installed", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + } + + backupServiceDir := t.TempDir() + err := d.Backup(backupServiceDir) + require.ErrorContains(t, err, "failed to copy service file") + }) + + t.Run("Backup installed service fails if output file already exists", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + backupServiceDir := t.TempDir() + // Write the backup file before creating it; Backup should + // not ever overwrite an existing file + os.WriteFile(path.BackupServiceFile(backupServiceDir), []byte("file exists"), 0600) + + err = d.Backup(backupServiceDir) + require.ErrorContains(t, err, "failed to copy service file") + }) } // uninstallService is a helper that uninstalls the service manually for test setup, in case it is somehow leftover. diff --git a/updater/internal/install/service_test.go b/updater/internal/service/service_test.go similarity index 98% rename from updater/internal/install/service_test.go rename to updater/internal/service/service_test.go index 4cd8ce5a5..3d5be7a4e 100644 --- a/updater/internal/install/service_test.go +++ b/updater/internal/service/service_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package install +package service import ( "os" diff --git a/updater/internal/install/service_windows.go b/updater/internal/service/service_windows.go similarity index 58% rename from updater/internal/install/service_windows.go rename to updater/internal/service/service_windows.go index c65a80980..3defcad16 100644 --- a/updater/internal/install/service_windows.go +++ b/updater/internal/service/service_windows.go @@ -14,7 +14,7 @@ //go:build windows -package install +package service import ( "encoding/json" @@ -25,11 +25,11 @@ import ( "strings" "time" - "golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" "github.com/kballard/go-shellquote" + "github.com/observiq/observiq-otel-collector/updater/internal/path" ) const ( @@ -39,13 +39,29 @@ const ( serviceNotExistErrStr = "The specified service does not exist as an installed service." ) -// newService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. -func newService(latestPath string) Service { - return &windowsService{ - newServiceFilePath: filepath.Join(latestPath, "install", "windows_service.json"), +// Option is an extra option for creating a Service +type Option func(winSvc *windowsService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(winSvc *windowsService) { + winSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(latestPath string, opts ...Option) Service { + winSvc := &windowsService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(latestPath), "windows_service.json"), serviceName: defaultServiceName, productName: defaultProductName, } + + for _, opt := range opts { + opt(winSvc) + } + + return winSvc } type windowsService struct { @@ -100,7 +116,7 @@ func (w windowsService) Stop() error { } // Installs the service -func (w windowsService) Install() error { +func (w windowsService) install() error { // parse the service definition from disk wsc, err := readWindowsServiceConfig(w.newServiceFilePath) if err != nil { @@ -108,13 +124,13 @@ func (w windowsService) Install() error { } // fetch the install directory so that we can determine the binary path that we need to execute - iDir, err := installDirFromRegistry(w.productName) + iDir, err := path.InstallDirFromRegistry(w.productName) if err != nil { return fmt.Errorf("failed to get install dir: %w", err) } // expand the arguments to be properly formatted (expand [INSTALLDIR], clean '"' to be '"') - expandArguments(wsc, w.productName, iDir) + expandArguments(wsc, iDir) // Split the arguments; Arguments are "shell-like", in that they may contain spaces, and can be quoted to indicate that. splitArgs, err := shellquote.Split(wsc.Service.Arguments) @@ -123,7 +139,7 @@ func (w windowsService) Install() error { } // Get the start type - startType, delayed, err := startType(wsc.Service.Start) + startType, delayed, err := winapiStartType(wsc.Service.Start) if err != nil { return fmt.Errorf("failed to parse start type in service config: %w", err) } @@ -154,7 +170,7 @@ func (w windowsService) Install() error { } // Uninstalls the service -func (w windowsService) Uninstall() error { +func (w windowsService) uninstall() error { m, err := mgr.Connect() if err != nil { return fmt.Errorf("failed to connect to service manager: %w", err) @@ -202,6 +218,50 @@ func (w windowsService) Uninstall() error { return nil } +func (w windowsService) Update() error { + if err := w.uninstall(); err != nil { + return fmt.Errorf("failed to uninstall old service: %w", err) + } + + if err := w.install(); err != nil { + return fmt.Errorf("failed to install new service: %w", err) + } + + return nil +} + +func (w windowsService) Backup(outDir string) error { + + wsc, err := w.currentServiceConfig() + if err != nil { + return fmt.Errorf("failed to construct service config: %w", err) + } + + // Marshal config as json + wscBytes, err := json.Marshal(wsc) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + // Open with O_EXCL to fail if the file already exists + f, err := os.OpenFile(path.BackupServiceFile(outDir), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("failed to create backup service file: %w", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Default().Printf("windowsService.Backup: Failed to close backup service file: %s", err) + } + }() + + // finally, write the config out so we can rollback. + if _, err := f.Write(wscBytes); err != nil { + return fmt.Errorf("failed to write backup service config: %w", err) + } + + return nil +} + // windowsServiceConfig defines how the service should be configured, including the entrypoint for the service. type windowsServiceConfig struct { // Path is the file that will be executed for the service. It is relative to the install directory. @@ -226,7 +286,8 @@ type windowsServiceDefinitionConfig struct { // readWindowsServiceConfig reads the service config from the file at the given path func readWindowsServiceConfig(path string) (*windowsServiceConfig, error) { - b, err := os.ReadFile(path) + cleanPath := filepath.Clean(path) + b, err := os.ReadFile(cleanPath) if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } @@ -241,39 +302,94 @@ func readWindowsServiceConfig(path string) (*windowsServiceConfig, error) { } // expandArguments expands [INSTALLDIR] to the actual install directory and -// expands '"e;' to the literal '"' -func expandArguments(wsc *windowsServiceConfig, productName, installDir string) { +// expands '"' to the literal '"' +func expandArguments(wsc *windowsServiceConfig, installDir string) { wsc.Service.Arguments = string(replaceInstallDir([]byte(wsc.Service.Arguments), installDir)) wsc.Service.Arguments = strings.ReplaceAll(wsc.Service.Arguments, """, `"`) } -// installDirFromRegistry gets the installation dir of the given product from the Windows Registry -func installDirFromRegistry(productName string) (string, error) { - // this key is created when installing using the MSI installer - keyPath := fmt.Sprintf(`Software\Microsoft\Windows\CurrentVersion\Uninstall\%s`, productName) - key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ) +func (w windowsService) currentServiceConfig() (*windowsServiceConfig, error) { + m, err := mgr.Connect() if err != nil { - return "", fmt.Errorf("failed to open registry key: %w", err) + return nil, fmt.Errorf("failed to connect to service manager: %w", err) + } + defer m.Disconnect() + + s, err := m.OpenService(w.serviceName) + if err != nil { + return nil, fmt.Errorf("failed to open service: %w", err) + } + defer s.Close() + + // Get the current config of the service + conf, err := s.Config() + if err != nil { + return nil, fmt.Errorf("failed to get service config: %w", err) + } + + fullBinaryPath, argString, err := splitServiceBinaryName(conf.BinaryPathName) + if err != nil { + return nil, fmt.Errorf("failed to split service BinaryPathName: %w", err) + } + + iDir, err := path.InstallDirFromRegistry(w.productName) + if err != nil { + return nil, fmt.Errorf("failed to get install dir: %w", err) } - defer func() { - err := key.Close() - if err != nil { - log.Default().Printf("installDirFromRegistry: failed to close registry key") - } - }() - // This value ("InstallLocation") contains the path to the install folder. - val, _, err := key.GetStringValue("InstallLocation") + // In the original config, the Path is the main binary path, relative to the install directory. + binaryPath, err := filepath.Rel(iDir, fullBinaryPath) if err != nil { - return "", fmt.Errorf("failed to read install dir: %w", err) + return nil, fmt.Errorf("could not find service exe relative to install dir: %w", err) } - return val, nil + // Convert windows api start type to the config file service type + confStartType, err := configStartType(conf.StartType, conf.DelayedAutoStart) + if err != nil { + return nil, fmt.Errorf("failed to get start type: %w", err) + } + + // Construct the config + return &windowsServiceConfig{ + Path: binaryPath, + Service: windowsServiceDefinitionConfig{ + Start: confStartType, + DisplayName: conf.DisplayName, + Description: conf.Description, + Arguments: argString, + }, + }, nil } -// startType converts the start type from the windowsServiceConfig to a start type recognizable by the windows +func splitServiceBinaryName(binaryPathName string) (binaryPath, argString string, err error) { + // Split the service arguments into an array of arguments + args, err := shellquote.Split(binaryPathName) + if err != nil { + return "", "", fmt.Errorf("failed to split service config args: %w", err) + } + + // The first argument is always the binary name; If the length of the array is 0, we know this is an invalid argument list. + if len(args) < 1 { + return "", "", fmt.Errorf("no binary specified in service config") + } + + // The absolute path to the binary is the first argument + binaryPath = args[0] + + // Stored argument string doesn't include the binary path (first arg) + args = args[1:] + + // Args should end up being a string, where literal quotes are """ + argString = shellquote.Join(args...) + // shellquote uses ' to quote, so we convert those to """ + argString = strings.ReplaceAll(argString, "'", """) + + return binaryPath, argString, nil +} + +// winapiStartType converts the start type from the windowsServiceConfig to a start type recognizable by the windows // service API -func startType(cfgStartType string) (startType uint32, delayed bool, err error) { +func winapiStartType(cfgStartType string) (startType uint32, delayed bool, err error) { switch cfgStartType { case "auto": // Automatically starts on system bootup. @@ -294,7 +410,18 @@ func startType(cfgStartType string) (startType uint32, delayed bool, err error) return } -// installDir returns the filepath to the install directory -func installDir() (string, error) { - return installDirFromRegistry(defaultProductName) +func configStartType(winapiStartType uint32, delayed bool) (string, error) { + switch winapiStartType { + case mgr.StartAutomatic: + if delayed { + return "delayed", nil + } + return "auto", nil + case mgr.StartDisabled: + return "disabled", nil + case mgr.StartManual: + return "manual", nil + default: + return "", fmt.Errorf("invalid winapi start type: %d", winapiStartType) + } } diff --git a/updater/internal/install/service_windows_test.go b/updater/internal/service/service_windows_test.go similarity index 83% rename from updater/internal/install/service_windows_test.go rename to updater/internal/service/service_windows_test.go index d58c4a1f9..0ae0eadfa 100644 --- a/updater/internal/install/service_windows_test.go +++ b/updater/internal/service/service_windows_test.go @@ -15,7 +15,7 @@ // an elevated user is needed to run the service tests //go:build windows && integration -package install +package service import ( "fmt" @@ -28,6 +28,7 @@ import ( "golang.org/x/sys/windows/registry" + "github.com/observiq/observiq-otel-collector/updater/internal/path" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc" @@ -57,7 +58,7 @@ func TestWindowsServiceInstall(t *testing.T) { productName: testProductName, } - err = w.Install() + err = w.install() require.NoError(t, err) //We want to check that the service was actually loaded @@ -76,7 +77,7 @@ func TestWindowsServiceInstall(t *testing.T) { }, ) - err = w.Uninstall() + err = w.uninstall() require.NoError(t, err) //Make sure the service is no longer listed @@ -106,7 +107,7 @@ func TestWindowsServiceInstall(t *testing.T) { productName: testProductName, } - err = w.Install() + err = w.install() require.NoError(t, err) //We want to check that the service was actually loaded @@ -125,7 +126,7 @@ func TestWindowsServiceInstall(t *testing.T) { }, ) - err = w.Uninstall() + err = w.uninstall() require.NoError(t, err) //Make sure the service is no longer listed @@ -154,7 +155,7 @@ func TestWindowsServiceInstall(t *testing.T) { productName: testProductName, } - err = w.Install() + err = w.install() require.NoError(t, err) // We want to check that the service was actually loaded @@ -170,7 +171,7 @@ func TestWindowsServiceInstall(t *testing.T) { requireServiceRunningStatus(t, false) - err = w.Uninstall() + err = w.uninstall() require.NoError(t, err) // Make sure the service is no longer listed @@ -199,7 +200,7 @@ func TestWindowsServiceInstall(t *testing.T) { productName: testProductName, } - err = w.Install() + err = w.install() require.ErrorContains(t, err, "The system cannot find the file specified.") requireServiceLoadedStatus(t, false) }) @@ -216,7 +217,7 @@ func TestWindowsServiceInstall(t *testing.T) { productName: testProductName, } - err := w.Uninstall() + err := w.uninstall() require.ErrorContains(t, err, "failed to open service") requireServiceLoadedStatus(t, false) }) @@ -252,6 +253,77 @@ func TestWindowsServiceInstall(t *testing.T) { err := w.Stop() require.ErrorContains(t, err, "failed to open service") }) + + t.Run("Test backup works", func(t *testing.T) { + tempDir := t.TempDir() + installDir := filepath.Join(tempDir, "install directory") + backupDir := filepath.Join(tempDir, "backup") + + require.NoError(t, os.MkdirAll(installDir, 0775)) + require.NoError(t, os.MkdirAll(backupDir, 0775)) + + testProductName := "Test Product" + + serviceJSON := filepath.Join(installDir, "windows-service.json") + testServiceProgram := filepath.Join(installDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + defer uninstallService(t) + createInstallDirRegistryKey(t, testProductName, installDir) + defer deleteInstallDirRegistryKey(t, testProductName) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: "windows-service", + productName: testProductName, + } + + err = w.install() + require.NoError(t, err) + + //We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + requireServiceConfigMatches(t, + testServiceProgram, + "windows-service", + mgr.StartAutomatic, + "Test Windows Service", + "This is a windows service to test", + true, + []string{ + "--config", + filepath.Join(installDir, "test.yaml"), + }, + ) + + // Take a backup; Assert the backup makes sense. + // It will not be the same as the original service file due to expansion of INSTALLDIR + // which is OK and expected. + err = w.Backup(backupDir) + require.NoError(t, err) + + backupSvcFile := path.BackupServiceFile(backupDir) + + svcCfg, err := readWindowsServiceConfig(backupSvcFile) + require.NoError(t, err) + + assert.Equal(t, &windowsServiceConfig{ + Path: "windows-service.exe", + Service: windowsServiceDefinitionConfig{ + Start: "delayed", + DisplayName: "Test Windows Service", + Description: "This is a windows service to test", + Arguments: fmt.Sprintf("--config "%s"", filepath.Join(installDir, "test.yaml")), + }, + }, svcCfg) + + err = w.uninstall() + }) } func TestStartType(t *testing.T) { @@ -289,7 +361,7 @@ func TestStartType(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("cfgStartType: %s", tc.cfgStartType), func(t *testing.T) { - st, d, err := startType(tc.cfgStartType) + st, d, err := winapiStartType(tc.cfgStartType) if tc.expectedErr != "" { require.ErrorContains(t, err, tc.expectedErr) } else { diff --git a/updater/internal/install/testdata/darwin-service.plist b/updater/internal/service/testdata/darwin-service.plist similarity index 100% rename from updater/internal/install/testdata/darwin-service.plist rename to updater/internal/service/testdata/darwin-service.plist diff --git a/updater/internal/install/testdata/linux-service.service b/updater/internal/service/testdata/linux-service.service similarity index 100% rename from updater/internal/install/testdata/linux-service.service rename to updater/internal/service/testdata/linux-service.service diff --git a/updater/internal/install/testdata/test-windows-service.go b/updater/internal/service/testdata/test-windows-service.go similarity index 100% rename from updater/internal/install/testdata/test-windows-service.go rename to updater/internal/service/testdata/test-windows-service.go diff --git a/updater/internal/install/testdata/windows-service.json b/updater/internal/service/testdata/windows-service.json similarity index 100% rename from updater/internal/install/testdata/windows-service.json rename to updater/internal/service/testdata/windows-service.json diff --git a/updater/internal/state/mocks/mock_monitor.go b/updater/internal/state/mocks/mock_monitor.go new file mode 100644 index 000000000..5d772548e --- /dev/null +++ b/updater/internal/state/mocks/mock_monitor.go @@ -0,0 +1,55 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + protobufs "github.com/open-telemetry/opamp-go/protobufs" + mock "github.com/stretchr/testify/mock" + + testing "testing" +) + +// MockMonitor is an autogenerated mock type for the Monitor type +type MockMonitor struct { + mock.Mock +} + +// MonitorForSuccess provides a mock function with given fields: ctx, packageName +func (_m *MockMonitor) MonitorForSuccess(ctx context.Context, packageName string) error { + ret := _m.Called(ctx, packageName) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, packageName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetState provides a mock function with given fields: packageName, status, statusErr +func (_m *MockMonitor) SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error { + ret := _m.Called(packageName, status, statusErr) + + var r0 error + if rf, ok := ret.Get(0).(func(string, protobufs.PackageStatus_Status, error) error); ok { + r0 = rf(packageName, status, statusErr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockMonitor creates a new instance of MockMonitor. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockMonitor(t testing.TB) *MockMonitor { + mock := &MockMonitor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/state/monitor.go b/updater/internal/state/monitor.go new file mode 100644 index 000000000..11ef1fd98 --- /dev/null +++ b/updater/internal/state/monitor.go @@ -0,0 +1,138 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package state contains structures to monitor and update the state of the collector in the package status +package state + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "time" + + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +var ( + // ErrFailedStatus is the error when the Package status indicates a failure + ErrFailedStatus = errors.New("package status indicates failure") +) + +// Monitor allows checking and setting state of active install +type Monitor interface { + // SetState sets the state for the package. + // If passed in statusErr is not nil it will record the error as the message + SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error + + // MonitorForSuccess will periodically check the state of the package. It will keep checking until the context is canceled or a failed/success state is detected. + // It will return an error if status is Failed or if the context times out. + MonitorForSuccess(ctx context.Context, packageName string) error +} + +// CollectorMonitor implements Monitor interface for monitoring the Collector Package Status file +type CollectorMonitor struct { + stateManager packagestate.StateManager + currentStatus *protobufs.PackageStatuses +} + +// NewCollectorMonitor create a new Monitor specifically for the collector +func NewCollectorMonitor(logger *zap.Logger) (Monitor, error) { + // Get install directory + installDir, err := path.InstallDir() + if err != nil { + return nil, fmt.Errorf("failed to determine install directory: %w", err) + } + + // Create a collector monitor + packageStatusPath := filepath.Join(installDir, packagestate.DefaultFileName) + collectorMonitor := &CollectorMonitor{ + stateManager: packagestate.NewFileStateManager(logger, packageStatusPath), + } + + // Load the current status to ensure the package status file exists + collectorMonitor.currentStatus, err = collectorMonitor.stateManager.LoadStatuses() + if err != nil { + return nil, fmt.Errorf("failed to load package statues: %w", err) + } + + return collectorMonitor, nil + +} + +// SetState sets the status on the specified package and saves it to the package status file +func (c *CollectorMonitor) SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error { + // Verify we have package by that name + targetPackage, ok := c.currentStatus.GetPackages()[packageName] + if !ok { + return fmt.Errorf("no package for name %s", packageName) + } + + // Update the status + targetPackage.Status = status + + // If that passed in error is not nil set it as the error message + if statusErr != nil { + targetPackage.ErrorMessage = statusErr.Error() + } + + c.currentStatus.GetPackages()[packageName] = targetPackage + + // Save to updated status to disk + return c.stateManager.SaveStatuses(c.currentStatus) +} + +// MonitorForSuccess intermittently checks the package status file for either an install failed or success status. +// If an InstallFailed status is read this returns ErrFailedStatus error. +// If the context is canceled the context error will be returned. +func (c *CollectorMonitor) MonitorForSuccess(ctx context.Context, packageName string) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + packageStatus, err := c.stateManager.LoadStatuses() + switch { + // If there is any error we'll just continue. Some valid reasons we could error and should retry: + // - File was deleted by new collector before it's rewritten + // - File is being written to while we're reading it so we'll get invalid JSON + case err != nil: + continue + default: + targetPackage, ok := packageStatus.GetPackages()[packageName] + // Target package might not exist yet so continue + if !ok { + continue + } + + switch targetPackage.GetStatus() { + case protobufs.PackageStatus_InstallFailed: + return ErrFailedStatus + case protobufs.PackageStatus_Installed: + // Install successful + return nil + default: + // Collector may still be starting up or we may have read the file while it's being written + continue + } + } + } + } +} diff --git a/updater/internal/state/monitor_test.go b/updater/internal/state/monitor_test.go new file mode 100644 index 000000000..82d5c0dbf --- /dev/null +++ b/updater/internal/state/monitor_test.go @@ -0,0 +1,367 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/observiq/observiq-otel-collector/packagestate/mocks" + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" +) + +func TestCollectorMonitorSetState(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Package not in current status", + testFunc: func(*testing.T) { + mockStateManger := mocks.NewMockStateManager(t) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: make(map[string]*protobufs.PackageStatus), + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.Error(t, err) + }, + }, + { + desc: "Sets Status no error", + testFunc: func(*testing.T) { + pgkName := "my_package" + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + { + desc: "Sets Status w/error", + testFunc: func(*testing.T) { + pgkName := "my_package" + statusErr := errors.New("some error") + + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallFailed, + ErrorMessage: statusErr.Error(), + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_InstallFailed, statusErr) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + { + desc: "StateManager fails to save", + testFunc: func(*testing.T) { + pgkName := "my_package" + expectedErr := errors.New("bad") + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(expectedErr) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestCollectorMonitorMonitorForSuccess(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Context is canceled", + testFunc: func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + mockStateManger := mocks.NewMockStateManager(t) + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(ctx, "my_package") + assert.ErrorIs(t, err, context.Canceled) + }, + }, + { + desc: "Package Status Indicates Failed Install", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_InstallFailed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.ErrorIs(t, err, ErrFailedStatus) + }, + }, + { + desc: "Package Status Indicates Successful install", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "File does not exist at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(nil, os.ErrNotExist) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Error reading file at first first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(nil, errors.New("bad")) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Package is not present at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + firstStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{}, + } + secondStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(firstStatus, nil) + mockStateManger.On("LoadStatuses").Return(secondStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Package is still marked as Installing at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + firstStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_InstallPending, + }, + }, + } + secondStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(firstStatus, nil) + mockStateManger.On("LoadStatuses").Return(secondStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +}