From 2632a298e69c22dbfaf2cc4caf67c3dcb4b5845b Mon Sep 17 00:00:00 2001 From: Dmitry Lyfar Date: Sun, 23 Jul 2023 23:03:11 +1200 Subject: [PATCH] Merge snapcore/snapd/state changes as of ffa162ef6b State: - Add WaitStatus support that allows a task to wait until further action to continue its execution. The WaitStatus is treated mostly as DoneStatus, except it is not a ready status. - Add CopyState function. - Add Change.AbortUnreadyLanes. - Add Change.CheckTaskDependencies to check if tasks have circular dependencies. - Add task and change callbacks invoked on a status change. - Implement timings.GetSaver interface in state.State. Note, however, the snapcore/snapd/timings package that provides the interface is not included in this PR. pebble has a similar "timing" package. It might be reasonable to introduce "timings" separately to replace the "timing" package to keep concerns (state and timings) separated. State engine: - Add StateStarterUp interface to perform possible expensive initialisation in a separate StartUp method. Overlord: - Implement the StateStarterUp interface Daemon: - Call Overlord StartUp during the start process to complete its initialisation. Note, that there is no obvious requirement to introduce this separation for pebble. It can still initialise Overlord in New(). The separation was mimicked to maintain better compatibility with possible future mergers of the snapcore/snapd/overlord/* changes. Misc: - Add jsonutil package required by the new state's CopyState function. - Rename osutil.CanStat to osutil.FileExist. - Rename task.FakeTime to task.MockTime. - Add testutil.ErrorIs and testutil.DeepUnsortedMatches checkers. --- internals/cli/cmd_run.go | 5 +- internals/daemon/api_changes_test.go | 16 +- internals/daemon/api_files_test.go | 12 +- internals/daemon/api_test.go | 1 + internals/daemon/daemon.go | 17 +- internals/daemon/daemon_test.go | 52 +- internals/jsonutil/json.go | 66 ++ internals/jsonutil/json_test.go | 90 +++ internals/jsonutil/safejson/safejson.go | 202 +++++ internals/jsonutil/safejson/safejson_test.go | 148 ++++ internals/osutil/io_test.go | 4 +- internals/osutil/squashfs/fstype.go | 2 +- internals/osutil/stat.go | 4 +- internals/osutil/stat_test.go | 6 +- internals/overlord/export_test.go | 8 + internals/overlord/overlord.go | 63 +- internals/overlord/overlord_test.go | 200 ++++- internals/overlord/patch/patch.go | 9 +- internals/overlord/restart/restart.go | 4 +- internals/overlord/restart/restart_test.go | 3 +- internals/overlord/state/change.go | 341 ++++++++- internals/overlord/state/change_test.go | 761 ++++++++++++++++++- internals/overlord/state/copy.go | 141 ++++ internals/overlord/state/copy_test.go | 146 ++++ internals/overlord/state/export_test.go | 10 +- internals/overlord/state/state.go | 155 +++- internals/overlord/state/state_test.go | 360 ++++++++- internals/overlord/state/task.go | 202 +++-- internals/overlord/state/task_test.go | 74 +- internals/overlord/state/taskrunner.go | 51 +- internals/overlord/state/taskrunner_test.go | 438 ++++++++++- internals/overlord/state/warning.go | 2 +- internals/overlord/state/warning_test.go | 4 +- internals/overlord/stateengine.go | 80 +- internals/overlord/stateengine_test.go | 95 ++- internals/systemd/sdnotify.go | 2 +- internals/systemd/systemd.go | 2 +- internals/systemd/systemd_test.go | 4 +- internals/testutil/containschecker.go | 160 +++- internals/testutil/containschecker_test.go | 180 ++++- internals/testutil/errorischecker.go | 53 ++ internals/testutil/errorischecker_test.go | 72 ++ 42 files changed, 3947 insertions(+), 298 deletions(-) create mode 100644 internals/jsonutil/json.go create mode 100644 internals/jsonutil/json_test.go create mode 100644 internals/jsonutil/safejson/safejson.go create mode 100644 internals/jsonutil/safejson/safejson_test.go create mode 100644 internals/overlord/state/copy.go create mode 100644 internals/overlord/state/copy_test.go create mode 100644 internals/testutil/errorischecker.go create mode 100644 internals/testutil/errorischecker_test.go diff --git a/internals/cli/cmd_run.go b/internals/cli/cmd_run.go index 2c1f2083..62d513a5 100644 --- a/internals/cli/cmd_run.go +++ b/internals/cli/cmd_run.go @@ -182,7 +182,10 @@ func runDaemon(rcmd *cmdRun, ch chan os.Signal, ready chan<- func()) error { } d.Version = cmd.Version - d.Start() + err = d.Start() + if err != nil { + return err + } watchdog, err := runWatchdog(d) if err != nil { diff --git a/internals/daemon/api_changes_test.go b/internals/daemon/api_changes_test.go index 0f5f6228..f52788ff 100644 --- a/internals/daemon/api_changes_test.go +++ b/internals/daemon/api_changes_test.go @@ -46,7 +46,7 @@ func setupChanges(st *state.State) []string { } func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -75,7 +75,7 @@ func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) { } func (s *apiSuite) TestStateChangesInProgress(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -104,7 +104,7 @@ func (s *apiSuite) TestStateChangesInProgress(c *check.C) { } func (s *apiSuite) TestStateChangesAll(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -133,7 +133,7 @@ func (s *apiSuite) TestStateChangesAll(c *check.C) { } func (s *apiSuite) TestStateChangesReady(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -161,7 +161,7 @@ func (s *apiSuite) TestStateChangesReady(c *check.C) { } func (s *apiSuite) TestStateChangesForServiceName(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -192,7 +192,7 @@ func (s *apiSuite) TestStateChangesForServiceName(c *check.C) { } func (s *apiSuite) TestStateChange(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -261,7 +261,7 @@ func (s *apiSuite) TestStateChange(c *check.C) { } func (s *apiSuite) TestStateChangeAbort(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() soon := 0 @@ -334,7 +334,7 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { } func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup diff --git a/internals/daemon/api_files_test.go b/internals/daemon/api_files_test.go index ab41554d..26157f2b 100644 --- a/internals/daemon/api_files_test.go +++ b/internals/daemon/api_files_test.go @@ -625,7 +625,7 @@ func (s *filesSuite) TestRemoveSingle(c *C) { c.Check(r.Result, HasLen, 1) checkFileResult(c, r.Result[0], tmpDir+"/file", "", "") - c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false) + c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false) } func (s *filesSuite) TestRemoveMultiple(c *C) { @@ -667,7 +667,7 @@ func (s *filesSuite) TestRemoveMultiple(c *C) { checkFileResult(c, r.Result[2], tmpDir+"/non-empty", "generic-file-error", ".*directory not empty") checkFileResult(c, r.Result[3], tmpDir+"/recursive", "", "") - c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false) + c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false) c.Check(osutil.IsDir(tmpDir+"/empty"), Equals, false) c.Check(osutil.IsDir(tmpDir+"/non-empty"), Equals, true) c.Check(osutil.IsDir(tmpDir+"/recursive"), Equals, false) @@ -1186,10 +1186,10 @@ group not found checkFileResult(c, r.Result[4], pathUserNotFound, "generic-file-error", ".*unknown user.*") checkFileResult(c, r.Result[5], pathGroupNotFound, "generic-file-error", ".*unknown group.*") - c.Check(osutil.CanStat(pathNoContent), Equals, false) - c.Check(osutil.CanStat(pathNotAbsolute), Equals, false) - c.Check(osutil.CanStat(pathNotFound), Equals, false) - c.Check(osutil.CanStat(pathPermissionDenied), Equals, false) + c.Check(osutil.FileExists(pathNoContent), Equals, false) + c.Check(osutil.FileExists(pathNotAbsolute), Equals, false) + c.Check(osutil.FileExists(pathNotFound), Equals, false) + c.Check(osutil.FileExists(pathPermissionDenied), Equals, false) } func assertFile(c *C, path string, perm os.FileMode, content string) { diff --git a/internals/daemon/api_test.go b/internals/daemon/api_test.go index ed7ebd7f..38735743 100644 --- a/internals/daemon/api_test.go +++ b/internals/daemon/api_test.go @@ -58,6 +58,7 @@ func (s *apiSuite) daemon(c *check.C) *Daemon { } d, err := New(&Options{Dir: s.pebbleDir}) c.Assert(err, check.IsNil) + c.Assert(d.Overlord().StartUp(), check.IsNil) d.addRoutes() s.d = d return d diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index b7405ee0..dc46030f 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -385,6 +385,10 @@ func (d *Daemon) Init() error { return nil } +func (d *Daemon) Overlord() *overlord.Overlord { + return d.overlord +} + // SetDegradedMode puts the daemon into a degraded mode which will the // error given in the "err" argument for commands that are not marked // as readonlyOK. @@ -452,13 +456,13 @@ func (d *Daemon) initStandbyHandling() { d.standbyOpinions.Start() } -func (d *Daemon) Start() { +func (d *Daemon) Start() error { if d.rebootIsMissing { // we need to schedule and wait for a system restart d.tomb.Kill(nil) // avoid systemd killing us again while we wait systemdSdNotify("READY=1") - return + return nil } if d.overlord == nil { panic("internal error: no Overlord") @@ -466,6 +470,10 @@ func (d *Daemon) Start() { d.StartTime = time.Now() + // now perform expensive overlord/manages initialization + if err := d.overlord.StartUp(); err != nil { + return err + } d.connTracker = &connTracker{conns: make(map[net.Conn]struct{})} d.serve = &http.Server{ Handler: logit(d.router), @@ -505,6 +513,7 @@ func (d *Daemon) Start() { // notify systemd that we are ready systemdSdNotify("READY=1") + return nil } // HandleRestart implements overlord.RestartBehavior. @@ -673,7 +682,7 @@ func (d *Daemon) rebootDelay() (time.Duration, error) { // see whether a reboot had already been scheduled var rebootAt time.Time err := d.state.Get("daemon-system-restart-at", &rebootAt) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return 0, err } rebootDelay := 1 * time.Minute @@ -758,7 +767,7 @@ var errExpectedReboot = errors.New("expected reboot did not happen") func (d *Daemon) RebootIsMissing(st *state.State) error { var nTentative int err := st.Get("daemon-system-restart-tentative", &nTentative) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } nTentative++ diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index c10d04d9..5e6c8f1e 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -103,7 +103,8 @@ func (s *daemonSuite) TestExplicitPaths(c *C) { d := s.newDaemon(c) d.Init() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) info, err := os.Stat(s.socketPath) @@ -459,7 +460,8 @@ func (s *daemonSuite) TestStartStop(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l2, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) generalDone := make(chan struct{}) go func() { @@ -500,7 +502,8 @@ func (s *daemonSuite) TestRestartWiring(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) generalDone := make(chan struct{}) @@ -567,7 +570,8 @@ func (s *daemonSuite) TestGracefulStop(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: untrustedL, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) generalAccepting := make(chan struct{}) go func() { @@ -634,7 +638,8 @@ func (s *daemonSuite) TestRestartSystemWiring(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) st := d.overlord.State() @@ -781,7 +786,8 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) st := d.overlord.State() st.Lock() @@ -791,7 +797,7 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) { ch := make(chan os.Signal, 2) ch <- syscall.SIGTERM // stop will check if we got a sigterm in between (which we did) - err := d.Stop(ch) + err = d.Stop(ch) c.Assert(err, check.IsNil) } @@ -813,7 +819,8 @@ func (s *daemonSuite) TestRestartShutdown(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) st := d.overlord.State() st.Lock() @@ -860,7 +867,8 @@ func (s *daemonSuite) TestRestartExpectedRebootIsMissing(c *check.C) { c.Check(err, check.IsNil) c.Check(n, check.Equals, 1) - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) c.Check(s.notified, check.DeepEquals, []string{"READY=1"}) @@ -896,8 +904,8 @@ func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) { @@ -920,9 +928,9 @@ func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-tentative", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { @@ -936,7 +944,8 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) // pretend some ensure happened for i := 0; i < 5; i++ { @@ -955,7 +964,7 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { case <-time.After(15 * time.Second): c.Errorf("daemon did not stop after 15s") } - err := d.Stop(nil) + err = d.Stop(nil) c.Check(err, check.Equals, ErrRestartSocket) c.Check(d.restartSocket, check.Equals, true) } @@ -972,7 +981,8 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { st := d.overlord.State() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) // pretend some ensure happened for i := 0; i < 5; i++ { d.overlord.StateEngine().Ensure() @@ -998,7 +1008,7 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { c.Errorf("daemon did not stop after 5s") } // when the daemon got a pending change it just restarts - err := d.Stop(nil) + err = d.Stop(nil) c.Check(err, check.IsNil) c.Check(d.restartSocket, check.Equals, false) } @@ -1058,7 +1068,8 @@ func (s *daemonSuite) TestHTTPAPI(c *check.C) { s.httpAddress = ":0" // Go will choose port (use listener.Addr() to find it) d := s.newDaemon(c) d.Init() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) port := d.httpListener.Addr().(*net.TCPAddr).Port request, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d/v1/health", port), nil) @@ -1101,7 +1112,8 @@ services: d := s.newDaemon(c) err := d.Init() c.Assert(err, IsNil) - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) // Start the test service. payload := bytes.NewBufferString(`{"action": "start", "services": ["test1"]}`) diff --git a/internals/jsonutil/json.go b/internals/jsonutil/json.go new file mode 100644 index 00000000..056ab0a4 --- /dev/null +++ b/internals/jsonutil/json.go @@ -0,0 +1,66 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2017 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package jsonutil + +import ( + "encoding/json" + "fmt" + "io" + "reflect" + "strings" + + "github.com/canonical/x-go/strutil" +) + +// DecodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved +// via json.Number data type. It errors out on invalid json or any excess input. +func DecodeWithNumber(r io.Reader, value interface{}) error { + dec := json.NewDecoder(r) + dec.UseNumber() + if err := dec.Decode(&value); err != nil { + return err + } + if dec.More() { + return fmt.Errorf("cannot parse json value") + } + return nil +} + +// StructFields takes a pointer to a struct and a list of exceptions, +// and returns a list of the fields in the struct that are JSON-tagged +// and whose tag is not in the list of exceptions. +// The struct can be nil. +func StructFields(s interface{}, exceptions ...string) []string { + st := reflect.TypeOf(s).Elem() + num := st.NumField() + fields := make([]string, 0, num) + for i := 0; i < num; i++ { + tag := st.Field(i).Tag.Get("json") + idx := strings.IndexByte(tag, ',') + if idx > -1 { + tag = tag[:idx] + } + if tag != "" && !strutil.ListContains(exceptions, tag) { + fields = append(fields, tag) + } + } + + return fields +} diff --git a/internals/jsonutil/json_test.go b/internals/jsonutil/json_test.go new file mode 100644 index 00000000..2fb35c0b --- /dev/null +++ b/internals/jsonutil/json_test.go @@ -0,0 +1,90 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2017 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package jsonutil_test + +import ( + "encoding/json" + "strings" + "testing" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/jsonutil" +) + +func Test(t *testing.T) { TestingT(t) } + +type utilSuite struct{} + +var _ = Suite(&utilSuite{}) + +func (s *utilSuite) TestDecodeError(c *C) { + input := "{]" + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, `invalid character ']' looking for beginning of object key string`) +} + +func (s *utilSuite) TestDecodeErrorOnExcessData(c *C) { + input := "1000000000[1,2]" + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, `cannot parse json value`) +} + +func (s *utilSuite) TestDecodeSuccess(c *C) { + input := `{"a":1000000000, "b": 1.2, "c": "foo", "d":null}` + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, IsNil) + c.Assert(output, DeepEquals, map[string]interface{}{ + "a": json.Number("1000000000"), + "b": json.Number("1.2"), + "c": "foo", + "d": nil, + }) +} + +func (utilSuite) TestStructFields(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int `json:"potato,stuff"` + } + c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello", "potato"}) +} + +func (utilSuite) TestStructFieldsExcept(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int `json:"potato,stuff"` + } + c.Assert(jsonutil.StructFields((*aStruct)(nil), "potato"), DeepEquals, []string{"hello"}) + c.Assert(jsonutil.StructFields((*aStruct)(nil), "hello"), DeepEquals, []string{"potato"}) +} + +func (utilSuite) TestStructFieldsSurvivesNoTag(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int + } + c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello"}) +} diff --git a/internals/jsonutil/safejson/safejson.go b/internals/jsonutil/safejson/safejson.go new file mode 100644 index 00000000..298e6945 --- /dev/null +++ b/internals/jsonutil/safejson/safejson.go @@ -0,0 +1,202 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package safejson + +import ( + "fmt" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" + + "github.com/canonical/x-go/strutil" +) + +// String accepts any valid JSON string. Its Clean method will remove +// characters that aren't expected in a short descriptive text. +// I.e.: Cc, Co, Cf, Cs, noncharacters, and � (U+FFFD, the replacement +// character) are removed. +type String struct { + s string +} + +func (str *String) UnmarshalJSON(in []byte) (err error) { + str.s, err = unmarshal(in, uOpt{}) + return +} + +// Clean returns the string, with Cc, Co, Cf, Cs, noncharacters, +// and � (U+FFFD) removed. +func (str String) Clean() string { + return str.s +} + +// Paragraph accepts any valid JSON string. Its Clean method will remove +// characters that aren't expected in a long descriptive text. +// I.e.: Cc (except for \n), Co, Cf, Cs, noncharacters, and � (U+FFFD, +// the replacement character) are removed. +type Paragraph struct { + s string +} + +func (par *Paragraph) UnmarshalJSON(in []byte) (err error) { + par.s, err = unmarshal(in, uOpt{nlOK: true}) + return +} + +// Clean returns the string, with Cc minus \n, Co, Cf, Cs, noncharacters, +// and � (U+FFFD) removed. +func (par Paragraph) Clean() string { + return par.s +} + +func unescapeUCS2(in []byte) (rune, bool) { + if len(in) < 6 || in[0] != '\\' || in[1] != 'u' { + return -1, false + } + u, err := strconv.ParseUint(string(in[2:6]), 16, 32) + if err != nil { + return -1, false + } + return rune(u), true +} + +type uOpt struct { + nlOK bool + simple bool +} + +func unmarshal(in []byte, o uOpt) (string, error) { + // heavily based on (inspired by?) unquoteBytes from encoding/json + + if len(in) < 2 || in[0] != '"' || in[len(in)-1] != '"' { + // maybe it's a null and that's alright + if len(in) == 4 && in[0] == 'n' && in[1] == 'u' && in[2] == 'l' && in[3] == 'l' { + return "", nil + } + return "", fmt.Errorf("missing string delimiters: %q", in) + } + + // prune the quotes + in = in[1 : len(in)-1] + i := 0 + // try the fast track + for i < len(in) { + // 0x00..0x19 is the first of Cc + // 0x20..0x7e is all of printable ASCII (minus control chars) + if in[i] < 0x20 || in[i] > 0x7e || in[i] == '\\' || in[i] == '"' { + break + } + i++ + } + if i == len(in) { + // wee + return string(in), nil + } + if o.simple { + return "", fmt.Errorf("character %q in string %q unsupported for this value", in[i], in) + } + // in[i] is the first problematic one + out := make([]byte, i, len(in)+2*utf8.UTFMax) + copy(out, in) + var r, r2 rune + var n int + var c byte + var ubuf [utf8.UTFMax]byte + var ok bool + for i < len(in) { + c = in[i] + switch { + case c == '"': + return "", fmt.Errorf("unexpected unescaped quote at %d in \"%s\"", i, in) + case c < 0x20: + return "", fmt.Errorf("unexpected control character at %d in %q", i, in) + case c == '\\': + // handle escapes + i++ + if i == len(in) { + return "", fmt.Errorf("unexpected end of string (trailing backslash) in \"%s\"", in) + } + switch in[i] { + case 'u': + // oh dear, a unicode wotsit + r, ok = unescapeUCS2(in[i-1:]) + if !ok { + x := in[i-1:] + if len(x) > 6 { + x = x[:6] + } + return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in) + } + i += 5 + if utf16.IsSurrogate(r) { + // sigh + r2, ok = unescapeUCS2(in[i:]) + if !ok { + x := in[i:] + if len(x) > 6 { + x = x[:6] + } + return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in) + } + i += 6 + r = utf16.DecodeRune(r, r2) + } + if r <= 0x9f { + // otherwise, it's Cc (both halves, as we're looking at runes) + if (o.nlOK && r == '\n') || (r >= 0x20 && r <= 0x7e) { + out = append(out, byte(r)) + } + } else if r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) { + n = utf8.EncodeRune(ubuf[:], r) + out = append(out, ubuf[:n]...) + } + case 'b', 'f', 'r', 't': + // do nothing + i++ + case 'n': + if o.nlOK { + out = append(out, '\n') + } + i++ + case '"', '/', '\\': + // the spec says just ", / and \ can be backslash-escaped + // but go adds ' to the list (in unquoteBytes) + out = append(out, in[i]) + i++ + default: + return "", fmt.Errorf(`unknown escape '%c' at %d of "%s"`, in[i], i, in) + } + case c <= 0x7e: + // printable ASCII, except " or \ + out = append(out, c) + i++ + default: + r, n = utf8.DecodeRune(in[i:]) + j := i + n + if r > 0x9f && r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) { + out = append(out, in[i:j]...) + } + i = j + } + } + + return string(out), nil +} diff --git a/internals/jsonutil/safejson/safejson_test.go b/internals/jsonutil/safejson/safejson_test.go new file mode 100644 index 00000000..aae12aea --- /dev/null +++ b/internals/jsonutil/safejson/safejson_test.go @@ -0,0 +1,148 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package safejson_test + +import ( + "encoding/json" + "testing" + + "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/jsonutil/safejson" +) + +func Test(t *testing.T) { check.TestingT(t) } + +type escapeSuite struct{} + +var _ = check.Suite(escapeSuite{}) + +var table = map[string]string{ + "null": "", + `"hello"`: "hello", + `"árbol"`: "árbol", + `"\u0020"`: " ", + `"\uD83D\uDE00"`: "😀", + `"a\b\r\tb"`: "ab", + `"\\\""`: `\"`, + // escape sequences (NOTE just the control char is stripped) + `"\u001b[3mhello\u001b[m"`: "[3mhello[m", + `"a\u0080z"`: "az", + "\"a\u0080z\"": "az", + "\"a\u007fz\"": "az", + "\"a\u009fz\"": "az", + // replacement char + `"a\uFFFDb"`: "ab", + // private unicode chars + `"a\uE000b"`: "ab", + `"a\uDB80\uDC00b"`: "ab", +} + +func (escapeSuite) TestStrings(c *check.C) { + var u safejson.String + for j, s := range table { + comm := check.Commentf(j) + c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + + c.Assert(u.UnmarshalJSON([]byte(j)), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + } +} + +func (escapeSuite) TestBadStrings(c *check.C) { + var u1 safejson.String + + cc0 := make([][]byte, 0x20) + for i := range cc0 { + cc0[i] = []byte{'"', byte(i), '"'} + } + badesc := make([][]byte, 0, 0x7f-0x21-9) + for c := byte('!'); c <= '~'; c++ { + switch c { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u': + continue + default: + badesc = append(badesc, []byte{'"', '\\', c, '"'}) + } + } + + table := map[string][][]byte{ + // these are from json itself (so we're not checking them): + "invalid character '.+' in string literal": cc0, + "invalid character '.+' in string escape code": badesc, + `invalid character '.+' in \\u .*`: {[]byte(`"\u02"`), []byte(`"\u02zz"`)}, + "invalid character '\"' after top-level value": {[]byte(`"""`)}, + "unexpected end of JSON input": {[]byte(`"\"`)}, + } + + for e, js := range table { + for _, j := range js { + comm := check.Commentf("%q", j) + c.Check(json.Unmarshal(j, &u1), check.ErrorMatches, e, comm) + } + } + + table = map[string][][]byte{ + // these are from our lib + `missing string delimiters.*`: {{}, {'"'}}, + `unexpected control character at 0 in "\\.+"`: cc0, + `unknown escape '.' at 1 of "\\."`: badesc, + `badly formed \\u escape.*`: { + []byte(`"\u02"`), []byte(`"\u02zz"`), []byte(`"a\u02xxz"`), + []byte(`"\uD83Da"`), []byte(`"\uD83Da\u20"`), []byte(`"\uD83Da\u20zzz"`), + }, + `unexpected unescaped quote at 0 in """`: {[]byte(`"""`)}, + `unexpected end of string \(trailing backslash\).*`: {[]byte(`"\"`)}, + } + + for e, js := range table { + for _, j := range js { + comm := check.Commentf("%q", j) + c.Check(u1.UnmarshalJSON(j), check.ErrorMatches, e, comm) + } + } +} + +func (escapeSuite) TestParagraph(c *check.C) { + var u safejson.Paragraph + for j1, v1 := range table { + for j2, v2 := range table { + if j2 == "null" && j1 != "null" { + continue + } + + var j, s string + + if j1 == "null" { + j = j2 + s = v2 + } else { + j = j1[:len(j1)-1] + "\\n" + j2[1:] + s = v1 + "\n" + v2 + } + + comm := check.Commentf(j) + c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + } + } + +} diff --git a/internals/osutil/io_test.go b/internals/osutil/io_test.go index a57d88a1..61c6e69c 100644 --- a/internals/osutil/io_test.go +++ b/internals/osutil/io_test.go @@ -245,9 +245,9 @@ func (ts *AtomicWriteTestSuite) TestAtomicFileCancel(c *C) { aw, err := osutil.NewAtomicFile(p, 0644, 0, osutil.NoChown, osutil.NoChown) c.Assert(err, IsNil) fn := aw.File.Name() - c.Check(osutil.CanStat(fn), Equals, true) + c.Check(osutil.FileExists(fn), Equals, true) c.Check(aw.Cancel(), IsNil) - c.Check(osutil.CanStat(fn), Equals, false) + c.Check(osutil.FileExists(fn), Equals, false) } // SafeIoAtomicWriteTestSuite runs all AtomicWrite with safe diff --git a/internals/osutil/squashfs/fstype.go b/internals/osutil/squashfs/fstype.go index 07a8b8d5..f03fd213 100644 --- a/internals/osutil/squashfs/fstype.go +++ b/internals/osutil/squashfs/fstype.go @@ -25,7 +25,7 @@ import ( var useFuse = useFuseImpl func useFuseImpl() bool { - if !osutil.CanStat("/dev/fuse") { + if !osutil.FileExists("/dev/fuse") { return false } diff --git a/internals/osutil/stat.go b/internals/osutil/stat.go index c806dcad..8387e4c9 100644 --- a/internals/osutil/stat.go +++ b/internals/osutil/stat.go @@ -25,9 +25,9 @@ import ( "syscall" ) -// CanStat returns true if stat succeeds on the given path. +// FileExists returns true if stat succeeds on the given path. // It may return false on permission issues. -func CanStat(path string) bool { +func FileExists(path string) bool { _, err := os.Stat(path) return err == nil } diff --git a/internals/osutil/stat_test.go b/internals/osutil/stat_test.go index a1235ced..cdf11672 100644 --- a/internals/osutil/stat_test.go +++ b/internals/osutil/stat_test.go @@ -34,8 +34,8 @@ func (ts *StatTestSuite) TestCanStat(c *C) { err := ioutil.WriteFile(fname, []byte(fname), 0644) c.Assert(err, IsNil) - c.Assert(CanStat(fname), Equals, true) - c.Assert(CanStat("/i-do-not-exist"), Equals, false) + c.Assert(FileExists(fname), Equals, true) + c.Assert(FileExists("/i-do-not-exist"), Equals, false) } func (ts *StatTestSuite) TestCanStatOddPerms(c *C) { @@ -43,7 +43,7 @@ func (ts *StatTestSuite) TestCanStatOddPerms(c *C) { err := ioutil.WriteFile(fname, []byte(fname), 0100) c.Assert(err, IsNil) - c.Assert(CanStat(fname), Equals, true) + c.Assert(FileExists(fname), Equals, true) } func (ts *StatTestSuite) TestIsDir(c *C) { diff --git a/internals/overlord/export_test.go b/internals/overlord/export_test.go index 22529702..8fe8f187 100644 --- a/internals/overlord/export_test.go +++ b/internals/overlord/export_test.go @@ -40,6 +40,14 @@ func FakePruneInterval(prunei, prunew, abortw time.Duration) (restore func()) { } } +func FakePruneTicker(f func(t *time.Ticker) <-chan time.Time) (restore func()) { + old := pruneTickerC + pruneTickerC = f + return func() { + pruneTickerC = old + } +} + // FakeEnsureNext sets o.ensureNext for tests. func FakeEnsureNext(o *Overlord, t time.Time) { o.ensureNext = t diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go index 771a670d..bba805be 100644 --- a/internals/overlord/overlord.go +++ b/internals/overlord/overlord.go @@ -16,6 +16,7 @@ package overlord import ( + "errors" "fmt" "io" "os" @@ -49,6 +50,10 @@ var ( defaultCachedDownloads = 5 ) +var pruneTickerC = func(t *time.Ticker) <-chan time.Time { + return t.C +} + // Overlord is the central manager of the system, keeping track // of all available state managers and related helpers. type Overlord struct { @@ -63,8 +68,11 @@ type Overlord struct { ensureRun int32 pruneTicker *time.Ticker + startOfOperationTime time.Time + // managers inited bool + startedUp bool runner *state.TaskRunner serviceMgr *servstate.ServiceManager commandMgr *cmdstate.CommandManager @@ -136,6 +144,47 @@ func New(pebbleDir string, restartHandler restart.Handler, serviceOutput io.Writ return o, nil } +var timeNow = time.Now + +// StartOfOperationTime returns the time when pebble started operating, +// and sets it in the state when called for the first time. +// The StartOfOperationTime time is seed-time if available, +// or current time otherwise. +func (m *Overlord) StartOfOperationTime() (time.Time, error) { + var opTime time.Time + err := m.State().Get("start-of-operation-time", &opTime) + if err == nil { + return opTime, nil + } + if err != nil && !errors.Is(err, state.ErrNoState) { + return opTime, err + } + + opTime = timeNow() + m.State().Set("start-of-operation-time", opTime) + return opTime, nil +} + +// StartUp proceeds to run any expensive Overlord or managers initialization. +// After this is done once it is a noop. +func (o *Overlord) StartUp() error { + if o.startedUp { + return nil + } + o.startedUp = true + + var err error + st := o.State() + st.Lock() + o.startOfOperationTime, err = o.StartOfOperationTime() + st.Unlock() + if err != nil { + return fmt.Errorf("cannot get start of operation time: %s", err) + } + + return o.stateEng.StartUp() +} + func (o *Overlord) addManager(mgr StateManager) { o.stateEng.AddManager(mgr) } @@ -159,7 +208,7 @@ func loadState(statePath string, restartHandler restart.Handler, backend state.B } } - if !osutil.CanStat(statePath) { + if !osutil.FileExists(statePath) { // fail fast, mostly interesting for tests, this dir is set up by pebble stateDir := filepath.Dir(statePath) if !osutil.IsDir(stateDir) { @@ -255,6 +304,9 @@ func (o *Overlord) ensureBefore(d time.Duration) { // Loop runs a loop in a goroutine to ensure the current state regularly through StateEngine Ensure. func (o *Overlord) Loop() { o.ensureTimerSetup() + if o.loopTomb == nil { + o.loopTomb = new(tomb.Tomb) + } o.loopTomb.Go(func() error { for { // TODO: pass a proper context into Ensure @@ -263,14 +315,15 @@ func (o *Overlord) Loop() { // continue to the next Ensure() try for now o.stateEng.Ensure() o.ensureDidRun() + pruneC := pruneTickerC(o.pruneTicker) select { case <-o.loopTomb.Dying(): return nil case <-o.ensureTimer.C: - case <-o.pruneTicker.C: + case <-pruneC: st := o.State() st.Lock() - st.Prune(pruneWait, abortWait, pruneMaxChanges) + st.Prune(o.startOfOperationTime, pruneWait, abortWait, pruneMaxChanges) st.Unlock() } } @@ -295,6 +348,10 @@ func (o *Overlord) Stop() error { } func (o *Overlord) settle(timeout time.Duration, beforeCleanups func()) error { + if err := o.StartUp(); err != nil { + return err + } + func() { o.ensureLock.Lock() defer o.ensureLock.Unlock() diff --git a/internals/overlord/overlord_test.go b/internals/overlord/overlord_test.go index c6b9cc7e..33696609 100644 --- a/internals/overlord/overlord_test.go +++ b/internals/overlord/overlord_test.go @@ -46,6 +46,26 @@ type overlordSuite struct { var _ = Suite(&overlordSuite{}) +type ticker struct { + tickerChannel chan time.Time +} + +func (w *ticker) tick(n int) { + for i := 0; i < n; i++ { + w.tickerChannel <- time.Now() + } +} + +func fakePruneTicker() (w *ticker, restore func()) { + w = &ticker{ + tickerChannel: make(chan time.Time), + } + restore = overlord.FakePruneTicker(func(t *time.Ticker) <-chan time.Time { + return w.tickerChannel + }) + return w, restore +} + func (ovs *overlordSuite) SetUpTest(c *C) { ovs.dir = c.MkDir() ovs.statePath = filepath.Join(ovs.dir, ".pebble.state") @@ -161,6 +181,12 @@ type witnessManager struct { expectedEnsure int ensureCalled chan struct{} ensureCallback func(s *state.State) error + startedUp int +} + +func (wm *witnessManager) StartUp() error { + wm.startedUp++ + return nil } func (wm *witnessManager) Ensure() error { @@ -178,6 +204,9 @@ func (ovs *overlordSuite) TestTrivialRunAndStop(c *C) { o, err := overlord.New(ovs.dir, nil, nil) c.Assert(err, IsNil) + err = o.StartUp() + c.Assert(err, IsNil) + o.Loop() err = o.Stop() @@ -216,6 +245,10 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) { } o.AddManager(witness) + err := o.StartUp() + + c.Assert(err, IsNil) + o.Loop() defer o.Stop() @@ -227,8 +260,9 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) { } c.Check(time.Since(t0) >= 10*time.Millisecond, Equals, true) - err := o.Stop() + err = o.Stop() c.Assert(err, IsNil) + c.Check(witness.startedUp, Equals, 1) } func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) { @@ -248,7 +282,7 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) { ensureCallback: ensure, } o.AddManager(witness) - + c.Assert(o.StartUp(), IsNil) o.Loop() defer o.Stop() @@ -277,6 +311,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBefore(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -306,6 +342,8 @@ func (ovs *overlordSuite) TestEnsureBeforeSleepy(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -335,6 +373,8 @@ func (ovs *overlordSuite) TestEnsureBeforeLater(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -364,6 +404,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeOutsideEnsure(c *C) } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -420,6 +462,8 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() select { @@ -441,7 +485,7 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) { } func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { - restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour) + restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 5*time.Millisecond, 1*time.Hour) defer restoreIntv() o := overlord.Fake() @@ -459,20 +503,26 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { c.Check(st.Changes(), HasLen, 2) st.Unlock() + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + // start the loop that runs the prune ticker o.Loop() - // ensure the first change is pruned - time.Sleep(1500 * time.Millisecond) - st.Lock() - c.Check(st.Changes(), HasLen, 1) - st.Unlock() + // this needs to be more than pruneWait=5ms mocked above + time.Sleep(10 * time.Millisecond) + w.tick(2) - // ensure the second is also purged after it is ready st.Lock() + c.Check(st.Changes(), HasLen, 1) chg2.SetStatus(state.DoneStatus) st.Unlock() - time.Sleep(1500 * time.Millisecond) + + // this needs to be more than pruneWait=5ms mocked above + time.Sleep(10 * time.Millisecond) + // tick twice for extra Ensure + w.tick(2) + st.Lock() c.Check(st.Changes(), HasLen, 0) st.Unlock() @@ -482,6 +532,134 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { c.Assert(err, IsNil) } +func (ovs *overlordSuite) TestOverlordStartUpSetsStartOfOperation(c *C) { + restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour) + defer restoreIntv() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + st := o.State() + st.Lock() + defer st.Unlock() + + // validity check, not set + var opTime time.Time + c.Assert(st.Get("start-of-operation-time", &opTime), testutil.ErrorIs, state.ErrNoState) + st.Unlock() + + c.Assert(o.StartUp(), IsNil) + + st.Lock() + c.Assert(st.Get("start-of-operation-time", &opTime), IsNil) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation(c *C) { + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is 50min ago, this is less then abort limit + opTime := time.Now().Add(-50 * time.Minute) + st.Set("start-of-operation-time", opTime) + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.MockTime(spawnTime) + + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoingStatus) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneAbortsOld(c *C) { + // Ensure interval is not relevant for this test + restoreEnsureIntv := overlord.FakeEnsureInterval(10 * time.Hour) + defer restoreEnsureIntv() + + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to having unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is a year ago + opTime := time.Now().AddDate(-1, 0, 0) + st.Set("start-of-operation-time", opTime) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + st.Lock() + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.MockTime(spawnTime) + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + st.Unlock() + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + + // validity + op, err := o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(op.Equal(opTime), Equals, true) + + c.Assert(st.Changes(), HasLen, 1) + // change was aborted + c.Check(chg.Status(), Equals, state.HoldStatus) +} + func (ovs *overlordSuite) TestCheckpoint(c *C) { oldUmask := syscall.Umask(0) defer syscall.Umask(oldUmask) @@ -856,6 +1034,8 @@ func (ovs *overlordSuite) TestOverlordCanStandby(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + // can only standby after loop ran once c.Assert(o.CanStandby(), Equals, false) diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go index afc93186..58514a8a 100644 --- a/internals/overlord/patch/patch.go +++ b/internals/overlord/patch/patch.go @@ -20,6 +20,7 @@ package patch import ( + "errors" "fmt" "github.com/canonical/pebble/internals/logger" @@ -43,12 +44,12 @@ var patches = make(map[int][]PatchFunc) func Init(s *state.State) { s.Lock() defer s.Unlock() - if s.Get("patch-level", new(int)) != state.ErrNoState { + if err := s.Get("patch-level", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-level without actual patching") } s.Set("patch-level", Level) - if s.Get("patch-sublevel", new(int)) != state.ErrNoState { + if err := s.Get("patch-sublevel", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-sublevel without actual patching") } s.Set("patch-sublevel", Sublevel) @@ -76,12 +77,12 @@ func Apply(s *state.State) error { var stateLevel, stateSublevel int s.Lock() err := s.Get("patch-level", &stateLevel) - if err == nil || err == state.ErrNoState { + if err == nil || errors.Is(err, state.ErrNoState) { err = s.Get("patch-sublevel", &stateSublevel) } s.Unlock() - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } diff --git a/internals/overlord/restart/restart.go b/internals/overlord/restart/restart.go index 0c1a8840..6c11f62f 100644 --- a/internals/overlord/restart/restart.go +++ b/internals/overlord/restart/restart.go @@ -16,6 +16,8 @@ package restart import ( + "errors" + "github.com/canonical/pebble/internals/overlord/state" ) @@ -61,7 +63,7 @@ func Init(st *state.State, curBootID string, h Handler) error { } var fromBootID string err := st.Get("system-restart-from-boot-id", &fromBootID) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } st.Cache(restartStateKey{}, rs) diff --git a/internals/overlord/restart/restart_test.go b/internals/overlord/restart/restart_test.go index 8a8617ed..74fc0764 100644 --- a/internals/overlord/restart/restart_test.go +++ b/internals/overlord/restart/restart_test.go @@ -21,6 +21,7 @@ import ( "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestRestart(t *testing.T) { TestingT(t) } @@ -134,5 +135,5 @@ func (s *restartSuite) TestRequestRestartSystemAndVerifyReboot(c *C) { err = restart.Init(st, "boot-id-2", h2) c.Assert(err, IsNil) c.Check(h2.rebootAsExpected, Equals, true) - c.Check(st.Get("system-restart-from-boot-id", &fromBootID), Equals, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &fromBootID), testutil.ErrorIs, state.ErrNoState) } diff --git a/internals/overlord/state/change.go b/internals/overlord/state/change.go index bfb3b42d..22bfa123 100644 --- a/internals/overlord/state/change.go +++ b/internals/overlord/state/change.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2023 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -23,8 +23,11 @@ import ( "bytes" "encoding/json" "fmt" + "sort" "strings" "time" + + "github.com/canonical/pebble/internals/logger" ) // Status is used for status values for changes and tasks. @@ -37,7 +40,8 @@ const ( // to an aggregation of its tasks' statuses. See Change.Status for details. DefaultStatus Status = 0 - // HoldStatus means the task should not run, perhaps as a consequence of an error on another task. + // HoldStatus means the task should not run for the moment, perhaps as a + // consequence of an error on another task. HoldStatus Status = 1 // DoStatus means the change or task is ready to start. @@ -65,6 +69,12 @@ const ( // ErrorStatus means the change or task has errored out while running or being undone. ErrorStatus Status = 9 + // WaitStatus means the task was accomplished successfully but some + // external event needs to happen before work can progress further + // (e.g. on classic we require the user to reboot after a + // kernel snap update). + WaitStatus Status = 10 + nStatuses = iota ) @@ -88,6 +98,8 @@ func (s Status) String() string { return "Doing" case DoneStatus: return "Done" + case WaitStatus: + return "Wait" case AbortStatus: return "Abort" case UndoStatus: @@ -104,6 +116,18 @@ func (s Status) String() string { panic(fmt.Sprintf("internal error: unknown task status code: %d", s)) } +// taskWaitComputeStatus is used while computing the wait status of a +// change. It keeps track of whether a task is waiting or not waiting, or the +// computation for it is still in-progress to detect cyclic dependencies. +type taskWaitComputeStatus int + +const ( + taskWaitStatusNotComputed taskWaitComputeStatus = iota + taskWaitStatusComputing + taskWaitStatusNotWaiting + taskWaitStatusWaiting +) + // Change represents a tracked modification to the system state. // // The Change provides both the justification for individual tasks @@ -115,16 +139,16 @@ func (s Status) String() string { // while the individual Task values would track the running of // the hooks themselves. type Change struct { - state *State - id string - kind string - summary string - status Status - clean bool - data customData - taskIDs []string - lanes int - ready chan struct{} + state *State + id string + kind string + summary string + status Status + clean bool + data customData + taskIDs []string + ready chan struct{} + lastObservedStatus Status spawnTime time.Time readyTime time.Time @@ -157,7 +181,6 @@ type marshalledChange struct { Clean bool `json:"clean,omitempty"` Data map[string]*json.RawMessage `json:"data,omitempty"` TaskIDs []string `json:"task-ids,omitempty"` - Lanes int `json:"lanes,omitempty"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -178,7 +201,6 @@ func (c *Change) MarshalJSON() ([]byte, error) { Clean: c.clean, Data: c.data, TaskIDs: c.taskIDs, - Lanes: c.lanes, SpawnTime: c.spawnTime, ReadyTime: readyTime, @@ -206,7 +228,6 @@ func (c *Change) UnmarshalJSON(data []byte) error { } c.data = custData c.taskIDs = unmarshalled.TaskIDs - c.lanes = unmarshalled.Lanes c.ready = make(chan struct{}) c.spawnTime = unmarshalled.SpawnTime if unmarshalled.ReadyTime != nil { @@ -251,12 +272,19 @@ func (c *Change) Get(key string, value interface{}) error { return c.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (c *Change) Has(key string) bool { + c.state.reading() + return c.data.has(key) +} + var statusOrder = []Status{ AbortStatus, UndoingStatus, UndoStatus, DoingStatus, DoStatus, + WaitStatus, ErrorStatus, UndoneStatus, DoneStatus, @@ -269,32 +297,138 @@ func init() { } } +func (c *Change) isTaskWaiting(visited map[string]taskWaitComputeStatus, t *Task, deps []*Task) bool { + taskID := t.ID() + // Retrieve the compute status of the wait for the task, if not + // computed this defaults to 0 (taskWaitStatusNotComputed). + computeStatus := visited[taskID] + switch computeStatus { + case taskWaitStatusComputing: + // Cyclic dependency detected, return false to short-circuit. + logger.Noticef("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + // Make sure errors show up in "snap change " too + t.Logf("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + return false + case taskWaitStatusWaiting, taskWaitStatusNotWaiting: + return computeStatus == taskWaitStatusWaiting + } + visited[taskID] = taskWaitStatusComputing + + var isWaiting bool +depscheck: + for _, wt := range deps { + switch wt.Status() { + case WaitStatus: + isWaiting = true + // States that can be valid when waiting + // - Done, Undone, ErrorStatus, HoldStatus + case DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + // For 'Do' and 'Undo' we have to check whether the task is waiting + // for any dependencies. The logic is the same, but the set of tasks + // varies. + case DoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.WaitTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + case UndoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.HaltTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + default: + // When we determine the change can not be in a wait-state then + // break early. + isWaiting = false + break depscheck + } + } + if isWaiting { + visited[taskID] = taskWaitStatusWaiting + } else { + visited[taskID] = taskWaitStatusNotWaiting + } + return isWaiting +} + +// isChangeWaiting should only ever return true iff it determines all tasks in Do/Undo +// are blocked by tasks in either of three states: 'DoneStatus', 'UndoneStatus' or 'WaitStatus', +// if this fails, we default to the normal status ordering logic. +func (c *Change) isChangeWaiting() bool { + // Since we might visit tasks more than once, we store results to avoid recomputing them. + visited := make(map[string]taskWaitComputeStatus) + for _, t := range c.Tasks() { + switch t.Status() { + case WaitStatus, DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + case DoStatus: + if !c.isTaskWaiting(visited, t, t.WaitTasks()) { + return false + } + case UndoStatus: + if !c.isTaskWaiting(visited, t, t.HaltTasks()) { + return false + } + default: + return false + } + } + // If we end up here, then return true as we know we + // have at least one waiter in this change. + return true +} + // Status returns the current status of the change. // If the status was not explicitly set the result is derived from the status // of the individual tasks related to the change, according to the following // decision sequence: // +// - With all pending tasks blocked by other tasks in WaitStatus, return WaitStatus // - With at least one task in DoStatus, return DoStatus // - With at least one task in ErrorStatus, return ErrorStatus // - Otherwise, return DoneStatus func (c *Change) Status() Status { c.state.reading() - if c.status == DefaultStatus { - if len(c.taskIDs) == 0 { - return HoldStatus - } - statusStats := make([]int, nStatuses) - for _, tid := range c.taskIDs { - statusStats[c.state.tasks[tid].Status()]++ + if c.status != DefaultStatus { + return c.status + } + + if len(c.taskIDs) == 0 { + return HoldStatus + } + + statusStats := make([]int, nStatuses) + for _, tid := range c.taskIDs { + statusStats[c.state.tasks[tid].Status()]++ + } + + // If the change has any waiters, check for any runnable tasks + // or whether it's completely blocked by waiters. + if statusStats[WaitStatus] > 0 { + // Only if the change has all tasks blocked we return WaitStatus. + if c.isChangeWaiting() { + return WaitStatus } - for _, s := range statusOrder { - if statusStats[s] > 0 { - return s - } + } + + // Otherwise we return the current status with the highest priority. + for _, s := range statusOrder { + if statusStats[s] > 0 { + return s } - panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) } - return c.status + panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) +} + +func (c *Change) notifyStatusChange(new Status) { + if c.lastObservedStatus == new { + return + } + c.state.notifyChangeStatusChangedHandlers(c, c.lastObservedStatus, new) + c.lastObservedStatus = new } // SetStatus sets the change status, overriding the default behavior (see Status method). @@ -304,6 +438,7 @@ func (c *Change) SetStatus(s Status) { if s.Ready() { c.markReady() } + c.notifyStatusChange(c.Status()) } func (c *Change) markReady() { @@ -322,15 +457,10 @@ func (c *Change) Ready() <-chan struct{} { return c.ready } -// taskStatusChanged is called by tasks when their status is changed, -// to give the opportunity for the change to close its ready channel. -func (c *Change) taskStatusChanged(t *Task, old, new Status) { - if old.Ready() == new.Ready() { - return - } +func (c *Change) detectChangeReady(excludeTask *Task) { for _, tid := range c.taskIDs { task := c.state.tasks[tid] - if task != t && !task.status.Ready() { + if task != excludeTask && !task.status.Ready() { return } } @@ -343,6 +473,21 @@ func (c *Change) taskStatusChanged(t *Task, old, new Status) { c.markReady() } +// taskStatusChanged is called by tasks when their status is changed, +// to give the opportunity for the change to close its ready channel, and +// notify observers of Change changes. +func (c *Change) taskStatusChanged(t *Task, old, new Status) { + cs := c.Status() + // If the task changes from ready => unready or unready => ready, + // update the ready status for the change. + if old.Ready() == new.Ready() { + c.notifyStatusChange(cs) + return + } + c.detectChangeReady(t) + c.notifyStatusChange(cs) +} + // IsClean returns whether all tasks in the change have been cleaned. See SetClean. func (c *Change) IsClean() bool { c.state.reading() @@ -519,6 +664,44 @@ func (c *Change) AbortLanes(lanes []int) { c.abortLanes(lanes, make(map[int]bool), make(map[string]bool)) } +// AbortUnreadyLanes aborts the tasks from lanes that aren't fully ready, where +// a ready lane is one in which all tasks are ready. +func (c *Change) AbortUnreadyLanes() { + c.state.writing() + c.abortUnreadyLanes() +} + +func (c *Change) abortUnreadyLanes() { + lanesWithLiveTasks := map[int]bool{} + + for _, tid := range c.taskIDs { + t := c.state.tasks[tid] + if !t.Status().Ready() { + for _, tlane := range t.Lanes() { + lanesWithLiveTasks[tlane] = true + } + } + } + + abortLanes := []int{} + for lane := range lanesWithLiveTasks { + abortLanes = append(abortLanes, lane) + } + c.abortLanes(abortLanes, make(map[int]bool), make(map[string]bool)) +} + +// taskEffectiveStatus returns the 'effective' status. This means it accounts +// for tasks being in WaitStatus, and instead of returning the WaitStatus we +// return the actual status. (The status after the wait). +func taskEffectiveStatus(t *Task) Status { + status := t.Status() + if status == WaitStatus { + // If the task is waiting, then use the effective status instead. + status = t.WaitedStatus() + } + return status +} + func (c *Change) abortLanes(lanes []int, abortedLanes map[int]bool, seenTasks map[string]bool) { var hasLive = make(map[int]bool) var hasDead = make(map[int]bool) @@ -528,7 +711,7 @@ NextChangeTask: t := c.state.tasks[tid] var live bool - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus, DoingStatus, DoneStatus: live = true } @@ -579,7 +762,7 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks continue } seenTasks[t.id] = true - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus: // Still pending so don't even start. t.SetStatus(HoldStatus) @@ -607,3 +790,87 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks c.abortLanes(lanes, abortedLanes, seenTasks) } } + +type TaskDependencyCycleError struct { + IDs []string + msg string +} + +func (e *TaskDependencyCycleError) Error() string { return e.msg } + +func (e *TaskDependencyCycleError) Is(err error) bool { + _, ok := err.(*TaskDependencyCycleError) + return ok +} + +// CheckTaskDependencies checks the tasks in the change for cyclic dependencies +// and returns an error in such case. +func (c *Change) CheckTaskDependencies() error { + tasks := c.Tasks() + // count how many tasks any given non-independent task waits for + predecessors := make(map[string]int, len(tasks)) + + taskByID := map[string]*Task{} + for _, t := range tasks { + taskByID[t.id] = t + if l := len(t.waitTasks); l > 0 { + // only add an entry if the task is not independent + predecessors[t.id] = l + } + } + + // Kahn topological sort: make our way starting with tasks that are + // independent (their predecessors count is 0), then visit their direct + // successors (halt tasks), and for each reduce their predecessors + // count; once the count drops to 0, all direct dependencies of a given + // task have been accounted for and the task becomes independent. + + // queue of tasks to check + queue := make([]string, 0, len(tasks)) + // identify all independent tasks + for _, t := range tasks { + if predecessors[t.id] == 0 { + queue = append(queue, t.id) + } + } + + for len(queue) > 0 { + // take the first independent task + id := queue[0] + queue = queue[1:] + // reduce the incoming edge of its successors + for _, successor := range taskByID[id].haltTasks { + predecessors[successor]-- + if predecessors[successor] == 0 { + // a task that was a successor has become + // independent + delete(predecessors, successor) + queue = append(queue, successor) + } + } + } + + if len(predecessors) != 0 { + // tasks that are left cannot have their dependencies satisfied + var unsatisfiedTasks []string + for id := range predecessors { + unsatisfiedTasks = append(unsatisfiedTasks, id) + } + sort.Strings(unsatisfiedTasks) + msg := strings.Builder{} + msg.WriteString("dependency cycle involving tasks [") + for i, id := range unsatisfiedTasks { + t := taskByID[id] + msg.WriteString(fmt.Sprintf("%v:%v", t.id, t.kind)) + if i < len(unsatisfiedTasks)-1 { + msg.WriteRune(' ') + } + } + msg.WriteRune(']') + return &TaskDependencyCycleError{ + IDs: unsatisfiedTasks, + msg: msg.String(), + } + } + return nil +} diff --git a/internals/overlord/state/change_test.go b/internals/overlord/state/change_test.go index 52b77bfc..e25c4f55 100644 --- a/internals/overlord/state/change_test.go +++ b/internals/overlord/state/change_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -20,15 +20,16 @@ package state_test import ( + "errors" "fmt" "sort" "strconv" "strings" "time" - "github.com/canonical/pebble/internals/overlord/state" - . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/overlord/state" ) type changeSuite struct{} @@ -68,7 +69,7 @@ func (cs *changeSuite) TestReadyTime(c *C) { } func (cs *changeSuite) TestStatusString(c *C) { - for s := state.Status(0); s < state.ErrorStatus+1; s++ { + for s := state.Status(0); s < state.WaitStatus+1; s++ { c.Assert(s.String(), Matches, ".+") } } @@ -88,6 +89,21 @@ func (cs *changeSuite) TestGetSet(c *C) { c.Check(v, Equals, 1) } +func (cs *changeSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("install", "...") + c.Check(chg.Has("a"), Equals, false) + + chg.Set("a", 1) + c.Check(chg.Has("a"), Equals, true) + + chg.Set("a", nil) + c.Check(chg.Has("a"), Equals, false) +} + // TODO Better testing of full change roundtripping via JSON. func (cs *changeSuite) TestNewTaskAddTaskAndTasks(c *C) { @@ -227,9 +243,13 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { tasks := make(map[state.Status]*state.Task) - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } chg.AddTask(t) tasks[s] = t } @@ -240,6 +260,7 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { state.UndoStatus, state.DoingStatus, state.DoStatus, + state.WaitStatus, state.ErrorStatus, state.UndoneStatus, state.DoneStatus, @@ -252,7 +273,11 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { if s == s2 { break } - tasks[s2].SetStatus(s) + if s == state.WaitStatus { + tasks[s2].SetToWait(state.DoneStatus) + } else { + tasks[s2].SetStatus(s) + } } c.Assert(chg.Status(), Equals, s) } @@ -432,9 +457,13 @@ func (cs *changeSuite) TestAbort(c *C) { chg := st.NewChange("install", "...") - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } t.Set("old-status", s) chg.AddTask(t) } @@ -451,7 +480,7 @@ func (cs *changeSuite) TestAbort(c *C) { switch s { case state.DoStatus: c.Assert(t.Status(), Equals, state.HoldStatus) - case state.DoneStatus: + case state.DoneStatus, state.WaitStatus: c.Assert(t.Status(), Equals, state.UndoStatus) case state.DoingStatus: c.Assert(t.Status(), Equals, state.AbortStatus) @@ -526,11 +555,13 @@ func (cs *changeSuite) TestAbortKⁿ(c *C) { // Task wait order: // -// => t21 => t22 -// / \ -// t11 => t12 => t41 => t42 -// \ / -// => t31 => t32 +// => t21 => t22 +// / \ +// +// t11 => t12 => t41 => t42 +// +// \ / +// => t31 => t32 // // setup and result lines are :[:,...] // @@ -577,6 +608,10 @@ var abortLanesTests = []struct { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{2}, result: "t21:hold t22:hold t41:hold t42:hold *:do", + }, { + setup: "t11:done:1 t12:wait:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + abort: []int{2}, + result: "t21:hold t22:hold t41:hold t42:hold t11:done t12:wait *:do", }, { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{3}, @@ -660,7 +695,7 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Logf("Testing setup: %s", test.setup) statuses := make(map[string]state.Status) - for s := state.DefaultStatus; s <= state.ErrorStatus; s++ { + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { statuses[strings.ToLower(s.String())] = s } @@ -680,7 +715,11 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { } seen[parts[0]] = true task := tasks[parts[0]] - task.SetStatus(statuses[parts[1]]) + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } if len(parts) > 2 { lanes := strings.Split(parts[2], ",") for _, lane := range lanes { @@ -723,3 +762,691 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) } } + +// setup and result lines are :[:,...] +// order is -> (implies task2 waits for task 1) +// "*" as task name means "all remaining". +var abortUnreadyLanesTests = []struct { + setup string + order string + result string +}{ + + // Some basics. + { + setup: "*:do", + result: "*:hold", + }, { + setup: "*:wait", + result: "*:undo", + }, { + setup: "*:done", + result: "*:done", + }, { + setup: "*:error", + result: "*:error", + }, + + // t11 (1) => t12 (1) => t21 (1) => t22 (1) + // t31 (2) => t32 (2) => t41 (2) => t42 (2) + { + setup: "t11:do:1 t12:do:1 t21:do:1 t22:do:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:hold t32:hold t41:hold t42:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:done:2 t32:done:2 t41:done:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:undo t32:undo t41:undo t42:hold", + }, + // => t21 (2) => t22 (2) + // / \ + // t11 (2,3) => t12 (2,3) => t41 (4) => t42 (4) + // \ / + // => t31 (3) => t32 (3) + { + setup: "t11:do:2,3 t12:do:2,3 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:abort t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:wait:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:undo t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:doing:2 t22:do:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "t21:abort t22:hold t31:abort t32:hold t41:hold t42:hold *:undo", + }, + + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "*:hold", + }, { + setup: "t11:do:1 t12:do:1 t21:doing:2 t22:do:2 t31:done:3 t32:doing:3 t41:undone:4 t42:error:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "t11:hold t12:hold t21:abort t22:hold t31:undo t32:abort t41:undone t42:error", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:abort *:hold", + }, { + // + setup: "t11:done:1 t12:done:1 t21:done:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:undo *:hold", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + result: "t11:done t12:done *:hold", + }, +} + +func (ts *taskRunnerSuite) TestAbortUnreadyLanes(c *C) { + + names := strings.Fields("t11 t12 t21 t22 t31 t32 t41 t42") + + for i, test := range abortUnreadyLanesTests { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + statuses := make(map[string]state.Status) + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { + statuses[strings.ToLower(s.String())] = s + } + + items := strings.Fields(test.setup) + seen := make(map[string]bool) + for i := 0; i < len(items); i++ { + item := items[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(items)-1, Commentf("*: can only be used as the last entry")) + for _, name := range names { + if !seen[name] { + parts[0] = name + items = append(items, strings.Join(parts, ":")) + } + } + continue + } + seen[parts[0]] = true + task := tasks[parts[0]] + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } + if len(parts) > 2 { + lanes := strings.Split(parts[2], ",") + for _, lane := range lanes { + n, err := strconv.Atoi(lane) + c.Assert(err, IsNil) + task.JoinLane(n) + } + } + } + + c.Logf("Aborting") + + chg.AbortUnreadyLanes() + + c.Logf("Expected result: %s", test.result) + + seen = make(map[string]bool) + var expected = strings.Fields(test.result) + var obtained []string + for i := 0; i < len(expected); i++ { + item := expected[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(expected)-1, Commentf("*: can only be used as the last entry")) + var expanded []string + for _, name := range names { + if !seen[name] { + parts[0] = name + expanded = append(expanded, strings.Join(parts, ":")) + } + } + expected = append(expected[:i], append(expanded, expected[i+1:]...)...) + i-- + continue + } + name := parts[0] + seen[parts[0]] = true + obtained = append(obtained, name+":"+strings.ToLower(tasks[name].Status().String())) + } + + c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) + } +} + +// setup is a list of tasks " ", order is -> +// (implies task2 waits for task 1) +var cyclicDependencyTests = []struct { + setup string + order string + err string + errIDs []string +}{ + + // Some basics. + { + setup: "t1", + }, { + setup: "", + }, { + // independent tasks + setup: "t1 t2 t3", + }, { + // some independent and some ordered tasks + setup: "t1 t2 t3 t4", + order: "t2->t3", + }, + // some independent, dependencies as if added by WaitAll() + // t1 => t2 + // t1,t2 => t3 + // t1,t2,t3 => t4 + { + setup: "t1 t2 t3 t4", + order: "t1->t2 t1->t3 t2->t3 t1->t4 t2->t4 t3->t4", + }, { + // simple loop + setup: "t1 t2", + order: "t1->t2 t2->t1", + err: `dependency cycle involving tasks \[1:t1 2:t2\]`, + errIDs: []string{"1", "2"}, + }, + + // t1 => t2 => t3 => t4 + // t5 => t6 => t7 => t8 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t5->t6 t6->t7 t7->t8", + }, + // => t21 => t22 + // / \ + // t11 => t12 => t41 => t42 + // \ / + // => t31 => t32 + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + }, + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + err: `dependency cycle involving tasks \[3:t21 4:t22 5:t31 6:t32 7:t41 8:t42\]`, + errIDs: []string{"3", "4", "5", "6", "7", "8"}, + }, + // t1 => t2 => t3 => t4 --> t6 + // t5 => t6 => t7 => t8 --> t2 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t4->t6 t5->t6 t6->t7 t7->t8 t8->t2", + err: `dependency cycle involving tasks \[2:t2 3:t3 4:t4 6:t6 7:t7 8:t8\]`, + errIDs: []string{"2", "3", "4", "6", "7", "8"}, + }, +} + +func (ts *taskRunnerSuite) TestCheckTaskDependencies(c *C) { + + for i, test := range cyclicDependencyTests { + names := strings.Fields(test.setup) + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask(name, name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + err := chg.CheckTaskDependencies() + + if test.err != "" { + c.Assert(err, ErrorMatches, test.err) + c.Assert(errors.Is(err, &state.TaskDependencyCycleError{}), Equals, true) + errTasksDepCycle := err.(*state.TaskDependencyCycleError) + c.Assert(errTasksDepCycle.IDs, DeepEquals, test.errIDs) + } else { + c.Assert(err, IsNil) + } + } +} + +func (cs *changeSuite) TestIsWaitingStatusOrderWithWaits(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t1.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Set the wait-task into WaitStatus, to ensure we trigger the isWaiting + // logic and that it doesn't return WaitStatus for statuses which are in + // higher order + t4.SetToWait(state.DoneStatus) + + // Test the following sequences: + // task1 (do) => task2 (done) => task3 (doing) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (done) => task2 (done) => task3 (undoing) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoingStatus) + c.Check(chg.Status(), Equals, state.UndoingStatus) + + // task1 (done) => task2 (done) => task3 (abort) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.AbortStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingSingle(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + + chg.AddTask(t1) + c.Check(chg.Status(), Equals, state.DoStatus) + + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) no reboot + t1.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (wait) => task2 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingCircularDependency(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + + // Setup circular dependency between t1,t2 and t3, they should + // still act normally. + t2.WaitFor(t1) + t3.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // To trigger the cyclic dependency check, we must trigger the isWaiting logic + // and we do this by putting t4 into WaitStatus. + t4.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (wait) => task2 (do) => task3 (do) means need a reboot + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Put t4 into wait-status to trigger the isWaiting logic each time + // for the change. + t4.SetToWait(state.DoneStatus) + + // task1 (do) + task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (done) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // For the next two cases we are testing that a task with dependencies + // which have completed, but in a non-successful way is handled correctly. + // task1 (error) + task2 (wait) => task3 (do) means need reboot + // to finalize task2 + t1.SetStatus(state.ErrorStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (error) => task3 (do) means need reboot + // to finalize task1 + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) => task3 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (done) => task3 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (abort) => task3 (do) + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.AbortStatus) + t3.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // we use <=| to denote the reverse dependence relationship + // followed by undo logic + + // task1 (undo) <=| task2 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (wait) means need a reboot + t1.SetStatus(state.UndoStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) <=| task2 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("task4", "...") + t5 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + t4.WaitFor(t1) + t4.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + chg.AddTask(t5) + + // Put t5 into wait-status to trigger the isWaiting logic each time + // for the change. + t5.SetToWait(state.DoneStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (wait) + task4 (error) means + // need reboot to continue undoing 1 and 2 + t3.SetStatus(state.ErrorStatus) + t4.SetToWait(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (undo) + task2 (undo) => task3 (error) + task4 (wait) means + // need reboot to continue undoing 1 and 2 + t3.SetToWait(state.UndoneStatus) + t4.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undo) no reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (wait) + task2 (done) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} diff --git a/internals/overlord/state/copy.go b/internals/overlord/state/copy.go new file mode 100644 index 00000000..f87c2e85 --- /dev/null +++ b/internals/overlord/state/copy.go @@ -0,0 +1,141 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2020 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package state + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/canonical/pebble/internals/jsonutil" + "github.com/canonical/pebble/internals/osutil" +) + +type checkpointOnlyBackend struct { + path string +} + +func (b *checkpointOnlyBackend) Checkpoint(data []byte) error { + if err := os.MkdirAll(filepath.Dir(b.path), 0755); err != nil { + return err + } + return osutil.AtomicWriteFile(b.path, data, 0600, 0) +} + +func (b *checkpointOnlyBackend) EnsureBefore(d time.Duration) { + panic("cannot use EnsureBefore in checkpointOnlyBackend") +} + +// copyData will copy the given subkeys specifier from srcData to dstData. +// +// The subkeys is constructed from a dotted path like "user.auth". This copy +// helper is recursive and the pos parameter tells the function the current +// position of the copy. +func copyData(subkeys []string, pos int, srcData map[string]*json.RawMessage, dstData map[string]interface{}) error { + if pos < 0 || pos > len(subkeys) { + return fmt.Errorf("internal error: copyData used with an out-of-bounds position: %v not in [0:%v]", pos, len(subkeys)) + } + raw, ok := srcData[subkeys[pos]] + if !ok { + return ErrNoState + } + + if pos+1 == len(subkeys) { + dstData[subkeys[pos]] = raw + return nil + } + + var srcDatam map[string]*json.RawMessage + if err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &srcDatam); err != nil { + return fmt.Errorf("cannot unmarshal state entry %q with value %q as a map while trying to copy over %q", strings.Join(subkeys[:pos+1], "."), *raw, strings.Join(subkeys, ".")) + } + + // no subkey entry -> create one + if _, ok := dstData[subkeys[pos]]; !ok { + dstData[subkeys[pos]] = make(map[string]interface{}) + } + // and use existing data + var dstDatam map[string]interface{} + switch dstDataEntry := dstData[subkeys[pos]].(type) { + case map[string]interface{}: + dstDatam = dstDataEntry + case *json.RawMessage: + dstDatam = make(map[string]interface{}) + if err := jsonutil.DecodeWithNumber(bytes.NewReader(*dstDataEntry), &dstDatam); err != nil { + return fmt.Errorf("internal error: cannot decode subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstDataEntry) + } + default: + return fmt.Errorf("internal error: cannot create subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstData[subkeys[pos]]) + } + + return copyData(subkeys, pos+1, srcDatam, dstDatam) +} + +// CopyState takes a state from the srcStatePath and copies all +// dataEntries to the dstPath. Note that srcStatePath should never +// point to a state that is in use. +func CopyState(srcStatePath, dstStatePath string, dataEntries []string) error { + if osutil.FileExists(dstStatePath) { + // XXX: TOCTOU - look into moving this check into + // checkpointOnlyBackend. The issue is right now State + // will simply panic if Commit() returns an error + return fmt.Errorf("cannot copy state: %q already exists", dstStatePath) + } + if len(dataEntries) == 0 { + return fmt.Errorf("cannot copy state: must provide at least one data entry to copy") + } + + f, err := os.Open(srcStatePath) + if err != nil { + return fmt.Errorf("cannot open state: %s", err) + } + defer f.Close() + + // No need to lock/unlock the state here, srcState should not be + // in use at all. + srcState, err := ReadState(nil, f) + if err != nil { + return err + } + + // copy relevant data + dstData := make(map[string]interface{}) + for _, dataEntry := range dataEntries { + subkeys := strings.Split(dataEntry, ".") + if err := copyData(subkeys, 0, srcState.data, dstData); err != nil && !errors.Is(err, ErrNoState) { + return err + } + } + + // write it out + dstState := New(&checkpointOnlyBackend{path: dstStatePath}) + dstState.Lock() + defer dstState.Unlock() + for k, v := range dstData { + dstState.Set(k, v) + } + + return nil +} diff --git a/internals/overlord/state/copy_test.go b/internals/overlord/state/copy_test.go new file mode 100644 index 00000000..e7130e88 --- /dev/null +++ b/internals/overlord/state/copy_test.go @@ -0,0 +1,146 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016-2020 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package state_test + +import ( + "io/ioutil" + "path/filepath" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/overlord/state" +) + +func (ss *stateSuite) TestCopyStateAlreadyExists(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err := ioutil.WriteFile(dstStateFile, nil, 0644) + c.Assert(err, IsNil) + + err = state.CopyState(srcStateFile, dstStateFile, []string{"some-data"}) + c.Assert(err, ErrorMatches, `cannot copy state: "/.*/dst-state.json" already exists`) +} +func (ss *stateSuite) TestCopyStateNoDataEntriesToCopy(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + + err := state.CopyState(srcStateFile, dstStateFile, nil) + c.Assert(err, ErrorMatches, `cannot copy state: must provide at least one data entry to copy`) +} + +var srcStateContent = []byte(` +{ + "data": { + "api-download-tokens-secret": "123", + "api-download-tokens-secret-time": "2020-02-21T10:32:37.916147296Z", + "auth": { + "last-id": 1, + "users": [ + { + "id": 1, + "email": "some@user.com", + "macaroon": "1234", + "store-macaroon": "5678", + "store-discharges": [ + "9012345" + ] + } + ], + "device": { + "brand": "generic", + "model": "generic-classic", + "serial": "xxxxx-yyyyy-", + "key-id": "xxxxxx", + "session-macaroon": "xxxx" + }, + "macaroon-key": "xxxx=" + }, + "config": { + } + } +} +`) + +const stateSuffix = `,"changes":{},"tasks":{},"last-change-id":0,"last-task-id":0,"last-lane-id":0}` + +func (ss *stateSuite) TestCopyStateIntegration(c *C) { + // create a mock srcState + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent, 0644) + c.Assert(err, IsNil) + + // copy + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"auth.users", "no-existing-does-not-error", "auth.last-id"}) + c.Assert(err, IsNil) + + // and check that the right bits got copied + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"auth":{"last-id":1,"users":[{"id":1,"email":"some@user.com","macaroon":"1234","store-macaroon":"5678","store-discharges":["9012345"]}]}}`+stateSuffix) +} + +var srcStateContent1 = []byte(`{ + "data": { + "A": {"B": [{"C": 1}, {"D": 2}]}, + "E": {"F": 2, "G": 3}, + "H": 4, + "I": null + } +}`) + +func (ss *stateSuite) TestCopyState(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"A.B", "no-existing-does-not-error", "E.F", "E", "I", "E.non-existing"}) + c.Assert(err, IsNil) + + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"A":{"B":[{"C":1},{"D":2}]},"E":{"F":2,"G":3},"I":null}`+stateSuffix) +} + +func (ss *stateSuite) TestCopyStateUnmarshalNotMap(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"E.F.subkey-not-in-a-map"}) + c.Assert(err, ErrorMatches, `cannot unmarshal state entry "E.F" with value "2" as a map while trying to copy over "E.F.subkey-not-in-a-map"`) +} + +func (ss *stateSuite) TestCopyStateDuplicatesInDataEntriesAreFine(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"E", "E"}) + c.Assert(err, IsNil) + + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"E":{"F":2,"G":3}}`+stateSuffix) +} diff --git a/internals/overlord/state/export_test.go b/internals/overlord/state/export_test.go index adbcc5d5..3bf6f8ef 100644 --- a/internals/overlord/state/export_test.go +++ b/internals/overlord/state/export_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -23,8 +23,8 @@ import ( "time" ) -// FakeCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime. -func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) { +// MockCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime. +func MockCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) { oldInterval := unlockCheckpointRetryInterval oldMaxTime := unlockCheckpointRetryMaxTime unlockCheckpointRetryInterval = retryInterval @@ -35,12 +35,12 @@ func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restor } } -func FakeChangeTimes(chg *Change, spawnTime, readyTime time.Time) { +func MockChangeTimes(chg *Change, spawnTime, readyTime time.Time) { chg.spawnTime = spawnTime chg.readyTime = readyTime } -func FakeTaskTimes(t *Task, spawnTime, readyTime time.Time) { +func MockTaskTimes(t *Task, spawnTime, readyTime time.Time) { t.spawnTime = spawnTime t.readyTime = readyTime } diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index 660a8a64..86733024 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -46,7 +46,7 @@ type customData map[string]*json.RawMessage func (data customData) get(key string, value interface{}) error { entryJSON := data[key] if entryJSON == nil { - return ErrNoState + return &NoStateError{Key: key} } err := json.Unmarshal(*entryJSON, value) if err != nil { @@ -87,6 +87,9 @@ type State struct { lastTaskId int lastChangeId int lastLaneId int + // lastHandlerId is not serialized, it's only used during runtime + // for registering runtime callbacks + lastHandlerId int backend Backend data customData @@ -97,18 +100,27 @@ type State struct { modified bool cache map[interface{}]interface{} + + pendingChangeByAttr map[string]func(*Change) bool + + // task/changes observing + taskHandlers map[int]func(t *Task, old, new Status) + changeHandlers map[int]func(chg *Change, old, new Status) } // New returns a new empty state. func New(backend Backend) *State { return &State{ - backend: backend, - data: make(customData), - changes: make(map[string]*Change), - tasks: make(map[string]*Task), - warnings: make(map[string]*Warning), - modified: true, - cache: make(map[interface{}]interface{}), + backend: backend, + data: make(customData), + changes: make(map[string]*Change), + tasks: make(map[string]*Task), + warnings: make(map[string]*Warning), + modified: true, + cache: make(map[interface{}]interface{}), + pendingChangeByAttr: make(map[string]func(*Change) bool), + taskHandlers: make(map[int]func(t *Task, old Status, new Status)), + changeHandlers: make(map[int]func(chg *Change, old Status, new Status)), } } @@ -241,6 +253,28 @@ func (s *State) EnsureBefore(d time.Duration) { // ErrNoState represents the case of no state entry for a given key. var ErrNoState = errors.New("no state entry for key") +// NoStateError represents the case where no state could be found for a given key. +type NoStateError struct { + // Key is the key for which no state could be found. + Key string +} + +func (e *NoStateError) Error() string { + var keyMsg string + if e.Key != "" { + keyMsg = fmt.Sprintf(" %q", e.Key) + } + + return fmt.Sprintf("no state entry for key%s", keyMsg) +} + +// Is returns true if the error is of type *NoStateError or equal to ErrNoState. +// NoStateError's key isn't compared between errors. +func (e *NoStateError) Is(err error) bool { + _, ok := err.(*NoStateError) + return ok || errors.Is(err, ErrNoState) +} + // Get unmarshals the stored value associated with the provided key // into the value parameter. // It returns ErrNoState if there is no entry for key. @@ -249,6 +283,12 @@ func (s *State) Get(key string, value interface{}) error { return s.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (s *State) Has(key string) bool { + s.reading() + return s.data.has(key) +} + // Set associates value with key for future consulting by managers. // The provided value must properly marshal and unmarshal with encoding/json. func (s *State) Set(key string, value interface{}) { @@ -357,15 +397,25 @@ func (s *State) tasksIn(tids []string) []*Task { return res } +// RegisterPendingChangeByAttr registers predicates that will be invoked by +// Prune on changes with the specified attribute set to check whether even if +// they meet the time criteria they must not be aborted yet. +func (s *State) RegisterPendingChangeByAttr(attr string, f func(*Change) bool) { + s.pendingChangeByAttr[attr] = f +} + // Prune does several cleanup tasks to the in-memory state: // // - it removes changes that became ready for more than pruneWait and aborts -// tasks spawned for more than abortWait. +// tasks spawned for more than abortWait unless prevented by predicates +// registered with RegisterPendingChangeByAttr. +// // - it removes tasks unlinked to changes after pruneWait. When there are more // changes than the limit set via "maxReadyChanges" those changes in ready // state will also removed even if they are below the pruneWait duration. +// // - it removes expired warnings. -func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { +func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) { now := time.Now() pruneLimit := now.Add(-pruneWait) abortLimit := now.Add(-abortWait) @@ -392,15 +442,24 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +NextChange: for _, chg := range changes { - spawnTime := chg.SpawnTime() readyTime := chg.ReadyTime() + spawnTime := chg.SpawnTime() + if spawnTime.Before(startOfOperation) { + spawnTime = startOfOperation + } if readyTime.IsZero() { if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 { chg.Abort() delete(s.changes, chg.ID()) } else if spawnTime.Before(abortLimit) { - chg.Abort() + for attr, pending := range s.pendingChangeByAttr { + if chg.Has(attr) && pending(chg) { + continue NextChange + } + } + chg.AbortUnreadyLanes() } continue } @@ -424,6 +483,75 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +// GetMaybeTimings implements snapcore/snapd/timings.GetSaver +func (s *State) GetMaybeTimings(timings interface{}) error { + if err := s.Get("timings", timings); err != nil && !errors.Is(err, ErrNoState) { + return err + } + return nil +} + +// AddTaskStatusChangedHandler adds a callback function that will be invoked +// whenever tasks change status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddTaskStatusChangedHandler(f func(t *Task, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.taskHandlers[id] = f + return id +} + +func (s *State) RemoveTaskStatusChangedHandler(id int) { + s.reading() + delete(s.taskHandlers, id) +} + +func (s *State) notifyTaskStatusChangedHandlers(t *Task, old, new Status) { + s.reading() + for _, f := range s.taskHandlers { + f(t, old, new) + } +} + +// AddChangeStatusChangedHandler adds a callback function that will be invoked +// whenever a Change changes status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddChangeStatusChangedHandler(f func(chg *Change, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.changeHandlers[id] = f + return id +} + +func (s *State) RemoveChangeStatusChangedHandler(id int) { + s.reading() + delete(s.changeHandlers, id) +} + +func (s *State) notifyChangeStatusChangedHandlers(chg *Change, old, new Status) { + s.reading() + for _, f := range s.changeHandlers { + f(chg, old, new) + } +} + +// SaveTimings implements snapcore/snapd/timings.GetSaver +func (s *State) SaveTimings(timings interface{}) { + s.Set("timings", timings) +} + // ReadState returns the state deserialized from r. func ReadState(backend Backend, r io.Reader) (*State, error) { s := new(State) @@ -437,5 +565,6 @@ func ReadState(backend Backend, r io.Reader) (*State, error) { s.backend = backend s.modified = false s.cache = make(map[interface{}]interface{}) + s.pendingChangeByAttr = make(map[string]func(*Change) bool) return s, err } diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go index e1e8a707..2ab46026 100644 --- a/internals/overlord/state/state_test.go +++ b/internals/overlord/state/state_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -29,6 +29,7 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestState(t *testing.T) { TestingT(t) } @@ -76,6 +77,37 @@ func (ss *stateSuite) TestGetAndSet(c *C) { c.Check(&mSt2B, DeepEquals, mSt2) } +func (ss *stateSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + c.Check(st.Has("a"), Equals, false) + + st.Set("a", 1) + c.Check(st.Has("a"), Equals, true) + + st.Set("a", nil) + c.Check(st.Has("a"), Equals, false) +} + +func (ss *stateSuite) TestStrayTaskWithNoChange(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + _ = st.NewTask("bar", "...") + + // only the task with associate change is returned + c.Assert(st.Tasks(), HasLen, 1) + c.Assert(st.Tasks()[0].ID(), Equals, t1.ID()) + // but count includes all tasks + c.Assert(st.TaskCount(), Equals, 2) +} + func (ss *stateSuite) TestSetPanic(c *C) { st := state.New(nil) st.Lock() @@ -94,7 +126,7 @@ func (ss *stateSuite) TestGetNoState(c *C) { var mSt1B mgrState1 err := st.Get("mgr9", &mSt1B) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) } func (ss *stateSuite) TestSetToNilDeletes(c *C) { @@ -112,7 +144,7 @@ func (ss *stateSuite) TestSetToNilDeletes(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -126,7 +158,7 @@ func (ss *stateSuite) TestNullMeansNoState(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -227,7 +259,7 @@ func (ss *stateSuite) TestImplicitCheckpointAndRead(c *C) { } func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) defer restore() retries := 0 @@ -250,7 +282,7 @@ func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) { } func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond) defer restore() boom := errors.New("boom") @@ -272,7 +304,7 @@ func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) { } func (ss *stateSuite) TestImplicitCheckpointModifiedOnly(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) defer restore() b := &fakeStateBackend{} @@ -421,7 +453,7 @@ func (ss *stateSuite) TestNewTaskAndCheckpoint(c *C) { t1ID := t1.ID() t1.Set("a", 1) t1.SetStatus(state.DoneStatus) - t1.SetProgress("foo", 5, 10) + t1.SetProgress("snap", 5, 10) t1.JoinLane(42) t1.JoinLane(43) @@ -702,7 +734,7 @@ func (ss *stateSuite) TestMethodEntrance(c *C) { func() { st.Tasks() }, func() { st.Task("foo") }, func() { st.MarshalJSON() }, - func() { st.Prune(time.Hour, time.Hour, 100) }, + func() { st.Prune(time.Now(), time.Hour, time.Hour, 100) }, func() { st.TaskCount() }, func() { st.AllWarnings() }, func() { st.PendingWarnings() }, @@ -744,31 +776,32 @@ func (ss *stateSuite) TestPrune(c *C) { chg1 := st.NewChange("abort", "...") chg1.AddTask(t1) - state.FakeChangeTimes(chg1, now.Add(-abortWait), unset) + state.MockChangeTimes(chg1, now.Add(-abortWait), unset) chg2 := st.NewChange("prune", "...") chg2.AddTask(t2) c.Assert(chg2.Status(), Equals, state.DoStatus) - state.FakeChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait)) + state.MockChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait)) chg3 := st.NewChange("ready-but-recent", "...") chg3.AddTask(t3) - state.FakeChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2)) + state.MockChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2)) chg4 := st.NewChange("old-but-not-ready", "...") chg4.AddTask(t4) - state.FakeChangeTimes(chg4, now.Add(-pruneWait/2), unset) + state.MockChangeTimes(chg4, now.Add(-pruneWait/2), unset) // unlinked task t5 := st.NewTask("unliked", "...") c.Check(st.Task(t5.ID()), IsNil) - state.FakeTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait)) + state.MockTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait)) // two warnings, one expired st.AddWarning("hello", now, never, time.Nanosecond, state.DefaultRepeatAfter) st.Warnf("hello again") - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg1.ID()), Equals, chg1) c.Assert(st.Change(chg2.ID()), IsNil) @@ -793,6 +826,55 @@ func (ss *stateSuite) TestPrune(c *C) { c.Check(st.AllWarnings(), HasLen, 1) } +func (ss *stateSuite) TestRegisterPendingChangeByAttr(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + unset := time.Time{} + + t1 := st.NewTask("foo", "...") + t2 := st.NewTask("foo", "...") + t3 := st.NewTask("foo", "...") + t4 := st.NewTask("foo", "...") + + chg1 := st.NewChange("abort", "...") + chg1.AddTask(t1) + chg1.AddTask(t2) + state.MockChangeTimes(chg1, now.Add(-abortWait), unset) + + chg2 := st.NewChange("pending", "...") + chg2.AddTask(t3) + chg2.AddTask(t4) + state.MockChangeTimes(chg2, now.Add(-abortWait), unset) + chg2.Set("pending-flag", true) + t3.SetStatus(state.HoldStatus) + + st.RegisterPendingChangeByAttr("pending-flag", func(chg *state.Change) bool { + c.Check(chg.ID(), Equals, chg2.ID()) + return true + }) + + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) + + c.Assert(st.Change(chg1.ID()), Equals, chg1) + c.Assert(st.Change(chg2.ID()), Equals, chg2) + c.Assert(st.Task(t1.ID()), Equals, t1) + c.Assert(st.Task(t2.ID()), Equals, t2) + c.Assert(st.Task(t3.ID()), Equals, t3) + c.Assert(st.Task(t4.ID()), Equals, t4) + + c.Assert(t1.Status(), Equals, state.HoldStatus) + c.Assert(t2.Status(), Equals, state.HoldStatus) + c.Assert(t3.Status(), Equals, state.HoldStatus) + c.Assert(t4.Status(), Equals, state.DoStatus) +} + func (ss *stateSuite) TestPruneEmptyChange(c *C) { // Empty changes are a bit special because they start out on Hold // which is a Ready status, but the change itself is not considered Ready @@ -807,9 +889,10 @@ func (ss *stateSuite) TestPruneEmptyChange(c *C) { abortWait := 3 * time.Hour chg := st.NewChange("abort", "...") - state.FakeChangeTimes(chg, now.Add(-pruneWait), time.Time{}) + state.MockChangeTimes(chg, now.Add(-pruneWait), time.Time{}) - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg.ID()), IsNil) } @@ -831,7 +914,7 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) { t.SetStatus(state.DoneStatus) when := time.Duration(i) * time.Second - state.FakeChangeTimes(chg, now.Add(-when), now.Add(-when)) + state.MockChangeTimes(chg, now.Add(-when), now.Add(-when)) } c.Assert(st.Changes(), HasLen, 10) @@ -844,13 +927,14 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) { // test that nothing is done when we are within pruneWait and // maxReadyChanges + past := time.Now().AddDate(-1, 0, 0) maxReadyChanges := 100 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 15) // but with maxReadyChanges we remove the ready ones maxReadyChanges = 5 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 10) remaining := map[string]bool{} for _, chg := range st.Changes() { @@ -886,8 +970,9 @@ func (ss *stateSuite) TestPruneMaxChangesSomeNotReady(c *C) { c.Assert(st.Changes(), HasLen, 10) // nothing can be pruned + past := time.Now().AddDate(-1, 0, 0) maxChanges := 5 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 10) } @@ -905,10 +990,10 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { c.Assert(st.Changes(), HasLen, 10) // one extra change that just now entered ready state - chg := st.NewChange(fmt.Sprintf("chg99"), "so-ready") + chg := st.NewChange("chg99", "so-ready") t := st.NewTask("foo", "so-ready") when := 1 * time.Second - state.FakeChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when)) + state.MockChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when)) t.SetStatus(state.DoneStatus) chg.AddTask(t) @@ -916,11 +1001,45 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { // // this test we do not purge the freshly ready change maxChanges := 10 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 11) } -func (ss *stateSuite) TestReadStateInitsCache(c *C) { +func (ss *stateSuite) TestPruneHonorsStartOperationTime(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + + startTime := 2 * time.Hour + spawnTime := 10 * time.Hour + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + chg := st.NewChange("change", "...") + t := st.NewTask("foo", "") + chg.AddTask(t) + // change spawned 10h ago + state.MockChangeTimes(chg, now.Add(-spawnTime), time.Time{}) + + // start operation time is 2h ago, change is not aborted because + // it's less than abortWait limit. + opTime := now.Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoStatus) + + // start operation time is 9h ago, change is aborted. + startTime = 9 * time.Hour + opTime = time.Now().Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.HoldStatus) +} + +func (ss *stateSuite) TestReadStateInitsTransientMapFields(c *C) { st, err := state.ReadState(nil, bytes.NewBufferString("{}")) c.Assert(err, IsNil) st.Lock() @@ -928,4 +1047,195 @@ func (ss *stateSuite) TestReadStateInitsCache(c *C) { st.Cache("key", "value") c.Assert(st.Cached("key"), Equals, "value") + st.RegisterPendingChangeByAttr("attr", func(*state.Change) bool { return false }) +} + +func (ss *stateSuite) TestTimingsSupport(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var tims []int + + err := st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, IsNil) + + st.SaveTimings([]int{1, 2, 3}) + + err = st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, DeepEquals, []int{1, 2, 3}) +} + +func (ss *stateSuite) TestNoStateErrorIs(c *C) { + err := &state.NoStateError{Key: "foo"} + c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) + c.Assert(err, testutil.ErrorIs, &state.NoStateError{Key: "bar"}) + c.Assert(err, testutil.ErrorIs, state.ErrNoState) +} + +func (ss *stateSuite) TestNoStateErrorString(c *C) { + err := &state.NoStateError{} + c.Assert(err.Error(), Equals, `no state entry for key`) + err.Key = "foo" + c.Assert(err.Error(), Equals, `no state entry for key "foo"`) +} + +type taskAndStatus struct { + t *state.Task + old, new state.Status +} + +func (ss *stateSuite) TestTaskChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var taskObservedChanges []taskAndStatus + oId := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) { + taskObservedChanges = append(taskObservedChanges, taskAndStatus{ + t: t, + old: old, + new: new, + }) + }) + + t1 := st.NewTask("foo", "...") + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // task events when task don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to done. + t1.SetStatus(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveTaskStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoingStatus) + + c.Check(taskObservedChanges, DeepEquals, []taskAndStatus{ + { + t: t1, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + t: t1, + old: state.DoingStatus, + new: state.DoneStatus, + }, + }) +} + +type changeAndStatus struct { + chg *state.Change + old, new state.Status +} + +func (ss *stateSuite) TestChangeChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // change events when changes don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to waiting + t1.SetToWait(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveChangeStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoneStatus) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + }) +} + +func (ss *stateSuite) TestChangeSetStatusChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // We have a single task in Doing, now we manipulate the status + // of the change to ensure we are receiving correct events + chg.SetStatus(state.WaitStatus) + + // Change to a new status + chg.SetStatus(state.ErrorStatus) + + // Now return the status back to Default, which should result + // in the change reporting Doing + chg.SetStatus(state.DefaultStatus) + st.RemoveChangeStatusChangedHandler(oId) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + { + chg: chg, + old: state.WaitStatus, + new: state.ErrorStatus, + }, + { + chg: chg, + old: state.ErrorStatus, + new: state.DoingStatus, + }, + }) } diff --git a/internals/overlord/state/task.go b/internals/overlord/state/task.go index a3e2a486..fb5ec826 100644 --- a/internals/overlord/state/task.go +++ b/internals/overlord/state/task.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -38,19 +38,22 @@ type progress struct { // // See Change for more details. type Task struct { - state *State - id string - kind string - summary string - status Status - clean bool - progress *progress - data customData - waitTasks []string - haltTasks []string - lanes []int - log []string - change string + state *State + id string + kind string + summary string + status Status + // waitedStatus is the Status that should be used instead of + // WaitStatus once the wait is complete (i.e post reboot). + waitedStatus Status + clean bool + progress *progress + data customData + waitTasks []string + haltTasks []string + lanes []int + log []string + change string spawnTime time.Time readyTime time.Time @@ -77,18 +80,19 @@ func newTask(state *State, id, kind, summary string) *Task { } type marshalledTask struct { - ID string `json:"id"` - Kind string `json:"kind"` - Summary string `json:"summary"` - Status Status `json:"status"` - Clean bool `json:"clean,omitempty"` - Progress *progress `json:"progress,omitempty"` - Data map[string]*json.RawMessage `json:"data,omitempty"` - WaitTasks []string `json:"wait-tasks,omitempty"` - HaltTasks []string `json:"halt-tasks,omitempty"` - Lanes []int `json:"lanes,omitempty"` - Log []string `json:"log,omitempty"` - Change string `json:"change"` + ID string `json:"id"` + Kind string `json:"kind"` + Summary string `json:"summary"` + Status Status `json:"status"` + WaitedStatus Status `json:"waited-status"` + Clean bool `json:"clean,omitempty"` + Progress *progress `json:"progress,omitempty"` + Data map[string]*json.RawMessage `json:"data,omitempty"` + WaitTasks []string `json:"wait-tasks,omitempty"` + HaltTasks []string `json:"halt-tasks,omitempty"` + Lanes []int `json:"lanes,omitempty"` + Log []string `json:"log,omitempty"` + Change string `json:"change"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -111,18 +115,19 @@ func (t *Task) MarshalJSON() ([]byte, error) { atTime = &t.atTime } return json.Marshal(marshalledTask{ - ID: t.id, - Kind: t.kind, - Summary: t.summary, - Status: t.status, - Clean: t.clean, - Progress: t.progress, - Data: t.data, - WaitTasks: t.waitTasks, - HaltTasks: t.haltTasks, - Lanes: t.lanes, - Log: t.log, - Change: t.change, + ID: t.id, + Kind: t.kind, + Summary: t.summary, + Status: t.status, + WaitedStatus: t.waitedStatus, + Clean: t.clean, + Progress: t.progress, + Data: t.data, + WaitTasks: t.waitTasks, + HaltTasks: t.haltTasks, + Lanes: t.lanes, + Log: t.log, + Change: t.change, SpawnTime: t.spawnTime, ReadyTime: readyTime, @@ -148,6 +153,13 @@ func (t *Task) UnmarshalJSON(data []byte) error { t.kind = unmarshalled.Kind t.summary = unmarshalled.Summary t.status = unmarshalled.Status + t.waitedStatus = unmarshalled.WaitedStatus + if t.waitedStatus == DefaultStatus { + // For backwards-compatibility, default the waitStatus, which is + // the result status after a wait, to DoneStatus to keep any previous + // behaviour before any upgrade. + t.waitedStatus = DoneStatus + } t.clean = unmarshalled.Clean t.progress = unmarshalled.Progress custData := unmarshalled.Data @@ -188,6 +200,40 @@ func (t *Task) Summary() string { } // Status returns the current task status. +// +// Possible state transitions: +// +// /----aborting lane--Do +// | | +// V V +// Hold Doing-->Wait +// ^ / | \ +// | abort / V V +// no undo / Done Error +// | V | +// \----------Abort aborting lane +// / | | +// | finished or | +// running not running | +// V \------->| +// kill goroutine | +// | V +// / \ ----->Undo +// / no error / | +// | from goroutine | +// error | +// from goroutine | +// | V +// | Undoing-->Wait +// V | \ +// Error V V +// Undone Error +// +// Do -> Doing -> Done is the direct succcess scenario. +// +// Wait can transition to its waited status, +// usually Done|Undone or back to Doing. +// See Wait struct, SetToWait and WaitedStatus. func (t *Task) Status() Status { t.state.reading() if t.status == DefaultStatus { @@ -196,10 +242,10 @@ func (t *Task) Status() Status { return t.status } -// SetStatus sets the task status, overriding the default behavior (see Status method). -func (t *Task) SetStatus(new Status) { - t.state.writing() - old := t.status +func (t *Task) changeStatus(old, new Status) { + if old == new { + return + } t.status = new if !old.Ready() && new.Ready() { t.readyTime = timeNow() @@ -208,6 +254,55 @@ func (t *Task) SetStatus(new Status) { if chg != nil { chg.taskStatusChanged(t, old, new) } + t.state.notifyTaskStatusChangedHandlers(t, old, new) +} + +// SetStatus sets the task status, overriding the default behavior (see Status method). +func (t *Task) SetStatus(new Status) { + if new == WaitStatus { + panic("Task.SetStatus() called with WaitStatus, which is not allowed. Use SetToWait() instead") + } + + t.state.writing() + old := t.status + if new == DoneStatus && old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // DoneStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.changeStatus(old, new) +} + +// SetToWait puts the task into WaitStatus, and sets the status the task should be restored +// to after the SetToWait. +func (t *Task) SetToWait(resultStatus Status) { + switch resultStatus { + case DefaultStatus, WaitStatus: + panic("Task.SetToWait() cannot be invoked with either of DefaultStatus or WaitStatus") + } + + t.state.writing() + old := t.status + if old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // WaitStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.waitedStatus = resultStatus + t.changeStatus(old, WaitStatus) +} + +// WaitedStatus returns the status the Task should return to once the current WaitStatus +// has been resolved. +func (t *Task) WaitedStatus() Status { + t.state.reading() + return t.waitedStatus } // IsClean returns whether the task has been cleaned. See SetClean. @@ -320,7 +415,7 @@ const ( var timeNow = time.Now -func FakeTime(now time.Time) (restore func()) { +func MockTime(now time.Time) (restore func()) { timeNow = func() time.Time { return now } return func() { timeNow = time.Now } } @@ -332,7 +427,7 @@ func (t *Task) addLog(kind, format string, args []interface{}) { } tstr := timeNow().Format(time.RFC3339) - msg := fmt.Sprintf(tstr+" "+kind+" "+format, args...) + msg := tstr + " " + kind + " " + fmt.Sprintf(format, args...) t.log = append(t.log, msg) logger.Debugf(msg) } @@ -468,7 +563,7 @@ func (t *Task) At(when time.Time) { // TaskSetEdge designates tasks inside a TaskSet for outside reference. // // This is useful to give tasks inside TaskSets a special meaning. It -// is used to mark e.g. the last task in a task set. +// is used to mark e.g. the last task used for downloading a snap. type TaskSetEdge string // A TaskSet holds a set of tasks. @@ -484,10 +579,16 @@ func NewTaskSet(tasks ...*Task) *TaskSet { return &TaskSet{tasks, nil} } -// Edge returns the task marked with the given edge name. +// MaybeEdge returns the task marked with the given edge name or nil if no such +// task exists. +func (ts TaskSet) MaybeEdge(e TaskSetEdge) *Task { + return ts.edges[e] +} + +// Edge returns the task marked with the given edge name or an error. func (ts TaskSet) Edge(e TaskSetEdge) (*Task, error) { - t, ok := ts.edges[e] - if !ok { + t := ts.MaybeEdge(e) + if t == nil { return nil, fmt.Errorf("internal error: missing %q edge in task set", e) } return t, nil @@ -509,7 +610,7 @@ func (ts *TaskSet) WaitAll(anotherTs *TaskSet) { } } -// AddTask adds the the task to the task set. +// AddTask adds the task to the task set. func (ts *TaskSet) AddTask(task *Task) { for _, t := range ts.tasks { if t == task { @@ -522,6 +623,9 @@ func (ts *TaskSet) AddTask(task *Task) { // MarkEdge marks the given task as a specific edge. Any pre-existing // edge mark will be overridden. func (ts *TaskSet) MarkEdge(task *Task, edge TaskSetEdge) { + if task == nil { + panic(fmt.Sprintf("cannot set edge %q with nil task", edge)) + } if ts.edges == nil { ts.edges = make(map[TaskSetEdge]*Task) } diff --git a/internals/overlord/state/task_test.go b/internals/overlord/state/task_test.go index 9c073d80..d1e26e54 100644 --- a/internals/overlord/state/task_test.go +++ b/internals/overlord/state/task_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -127,7 +127,7 @@ func (ts *taskSuite) TestClear(c *C) { t.Clear("a") - c.Check(t.Get("a", &v), Equals, state.ErrNoState) + c.Check(t.Get("a", &v), testutil.ErrorIs, state.ErrNoState) } func (ts *taskSuite) TestStatusAndSetStatus(c *C) { @@ -144,6 +144,60 @@ func (ts *taskSuite) TestStatusAndSetStatus(c *C) { c.Check(t.Status(), Equals, state.DoneStatus) } +func (ts *taskSuite) TestSetDoneAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetStatus(state.DoneStatus) + c.Check(t.Status(), Equals, state.AbortStatus) +} + +func (ts *taskSuite) TestSetWaitAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetToWait(state.DoneStatus) // noop + c.Check(t.Status(), Equals, state.AbortStatus) + c.Check(t.WaitedStatus(), Equals, state.DefaultStatus) +} + +func (ts *taskSuite) TestSetWait(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetToWait(state.DoneStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.DoneStatus) + t.SetToWait(state.UndoStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.UndoStatus) +} + +func (ts *taskSuite) TestTaskMarshalsWaitStatus(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t1 := st.NewTask("download", "1...") + t1.SetToWait(state.UndoStatus) + + d, err := t1.MarshalJSON() + c.Assert(err, IsNil) + + needle := fmt.Sprintf(`"waited-status":%d`, t1.WaitedStatus()) + c.Assert(string(d), testutil.Contains, needle) +} + func (ts *taskSuite) TestIsCleanAndSetClean(c *C) { st := state.New(nil) st.Lock() @@ -174,9 +228,9 @@ func (ts *taskSuite) TestProgressAndSetProgress(c *C) { t := st.NewTask("download", "1...") - t.SetProgress("foo", 2, 99) + t.SetProgress("snap", 2, 99) label, cur, tot := t.Progress() - c.Check(label, Equals, "foo") + c.Check(label, Equals, "snap") c.Check(cur, Equals, 2) c.Check(tot, Equals, 99) @@ -305,7 +359,7 @@ func (ts *taskSuite) TestAt(c *C) { t := st.NewTask("download", "1...") now := time.Now() - restore := state.FakeTime(now) + restore := state.MockTime(now) defer restore() when := now.Add(10 * time.Second) t.At(when) @@ -561,6 +615,10 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { // edges are just typed strings edge1 := state.TaskSetEdge("on-edge") edge2 := state.TaskSetEdge("eddie") + edge3 := state.TaskSetEdge("not-found") + + // nil task causes panic + c.Check(func() { ts.MarkEdge(nil, edge1) }, PanicMatches, `cannot set edge "on-edge" with nil task`) // no edge marked yet t, err := ts.Edge(edge1) @@ -590,6 +648,12 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { t, err = ts.Edge(edge1) c.Assert(t, Equals, t3) c.Assert(err, IsNil) + + // it is possible to check if edge exists without failing + t = ts.MaybeEdge(edge1) + c.Assert(t, Equals, t3) + t = ts.MaybeEdge(edge3) + c.Assert(t, IsNil) } func (cs *taskSuite) TestTaskAddAllWithEdges(c *C) { diff --git a/internals/overlord/state/taskrunner.go b/internals/overlord/state/taskrunner.go index 5d9ff244..c6a2b835 100644 --- a/internals/overlord/state/taskrunner.go +++ b/internals/overlord/state/taskrunner.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -46,6 +46,23 @@ func (r *Retry) Error() string { return "task should be retried" } +// Wait is returned from a handler to signal that the task cannot +// proceed at the moment maybe because some manual action from the +// user required at this point or because of errors. The task +// will be set to WaitStatus, and it's wait complete status will be +// set to WaitedStatus. +type Wait struct { + Reason string + // If not explicitly set, then WaitedStatus will default to + // DoneStatus, meaning that the task will be set to DoneStatus + // after the wait has resolved. + WaitedStatus Status +} + +func (r *Wait) Error() string { + return "task set to wait, manual action required" +} + type blockedFunc func(t *Task, running []*Task) bool // TaskRunner controls the running of goroutines to execute known task kinds. @@ -62,6 +79,9 @@ type TaskRunner struct { blocked []blockedFunc someBlocked bool + // optional callback executed on task errors + taskErrorCallback func(err error) + // go-routines lifecycle tombs map[string]*tomb.Tomb } @@ -85,6 +105,11 @@ func NewTaskRunner(s *State) *TaskRunner { } } +// OnTaskError sets an error callback executed when any task errors out. +func (r *TaskRunner) OnTaskError(f func(err error)) { + r.taskErrorCallback = f +} + // AddHandler registers the functions to concurrently call for doing and // undoing tasks of the given kind. The undo handler may be nil. func (r *TaskRunner) AddHandler(kind string, do, undo HandlerFunc) { @@ -214,7 +239,7 @@ func (r *TaskRunner) run(t *Task) { switch err.(type) { case nil: // we are ok - case *Retry: + case *Retry, *Wait: // preserve default: if r.stopped { @@ -227,13 +252,24 @@ func (r *TaskRunner) run(t *Task) { switch x := err.(type) { case *Retry: // Handler asked to be called again later. - // TODO Allow postponing retries past the next Ensure. if t.Status() == AbortStatus { // Would work without it but might take two ensures. r.tryUndo(t) } else if x.After != 0 { t.At(timeNow().Add(x.After)) } + case *Wait: + if t.Status() == AbortStatus { + // Would work without it but might take two ensures. + r.tryUndo(t) + } else { + // Default to DoneStatus if no status is set in Wait + waitedStatus := x.WaitedStatus + if waitedStatus == DefaultStatus { + waitedStatus = DoneStatus + } + t.SetToWait(waitedStatus) + } case nil: var next []*Task switch t.Status() { @@ -259,6 +295,11 @@ func (r *TaskRunner) run(t *Task) { r.abortLanes(t.Change(), t.Lanes()) t.SetStatus(ErrorStatus) t.Errorf("%s", err) + // ensure the error is available in the global log too + logger.Noticef("[change %s %q task] failed: %v", t.Change().ID(), t.Summary(), err) + if r.taskErrorCallback != nil { + r.taskErrorCallback(err) + } } return nil @@ -388,6 +429,10 @@ ConsiderTasks: } continue } + if status == WaitStatus { + // nothing more to run + continue + } if mustWait(t) { // Dependencies still unhandled. diff --git a/internals/overlord/state/taskrunner_test.go b/internals/overlord/state/taskrunner_test.go index 38c5685a..2b9de82a 100644 --- a/internals/overlord/state/taskrunner_test.go +++ b/internals/overlord/state/taskrunner_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -31,11 +31,14 @@ import ( . "gopkg.in/check.v1" "gopkg.in/tomb.v2" - "github.com/canonical/pebble/internals/overlord/restart" + "github.com/canonical/pebble/internals/logger" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) -type taskRunnerSuite struct{} +type taskRunnerSuite struct { + testutil.BaseTest +} var _ = Suite(&taskRunnerSuite{}) @@ -58,8 +61,6 @@ func (b *stateBackend) EnsureBefore(d time.Duration) { } } -func (b *stateBackend) RequestRestart(t restart.RestartType) {} - func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) { for i := 0; i < 20; i++ { sb.ensureBefore = time.Hour @@ -142,6 +143,10 @@ var sequenceTests = []struct{ setup, result string }{{ result: "t31:undo t32:do t32:do-error t21:undo", }} +func (ts *taskRunnerSuite) SetUpTest(c *C) { + ts.BaseTest.SetUpTest(c) +} + func (ts *taskRunnerSuite) TestSequenceTests(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -185,11 +190,12 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { r.AddHandler("do", fn("do"), nil) r.AddHandler("do-undo", fn("do"), fn("undo")) + past := time.Now().AddDate(-1, 0, 0) for _, test := range sequenceTests { st.Lock() // Delete previous changes. - st.Prune(1, 1, 1) + st.Prune(past, 1, 1, 1) chg := st.NewChange("install", "...") tasks := make(map[string]*state.Task) @@ -342,6 +348,206 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { } } +func (ts *taskRunnerSuite) TestAbortAcrossLanesDescendantTask(c *C) { + + // () + // t11(1) -> t12(1) => t15(1) + // \ / + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / \ + // t21(2) -> t22(2) => t25(2) + // + names := strings.Fields("t11 t12 t13 t14 t15 t21 t22 t23 t24 t25") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + tasks["t15"].WaitFor(tasks["t14"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t15", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t25", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + tasks["t25"].WaitFor(tasks["t24"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t15"].WaitFor(tasks["t24"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t15" { + ch <- "t15:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:do", + }) + c.Assert(sequence[8:10], testutil.DeepUnsortedMatches, []string{ + "t25:do", + "t15:error", + }) + c.Assert(sequence[10:11], testutil.DeepUnsortedMatches, []string{ + "t25:undo", + }) + c.Assert(sequence[11:15], DeepEquals, []string{ + "t24:undo", "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[15:19], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + +func (ts *taskRunnerSuite) TestAbortAcrossLanesStriclyOrderedTasks(c *C) { + + // () + // t11(1) -> t12(1) + // \ + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / + // t21(2) -> t22(2) + // + names := strings.Fields("t11 t12 t13 t14 t21 t22 t23 t24") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t24" { + ch <- "t24:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:error", + }) + c.Assert(sequence[8:11], DeepEquals, []string{ + "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[11:], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + func (ts *taskRunnerSuite) TestExternalAbort(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -372,6 +578,101 @@ func (ts *taskRunnerSuite) TestExternalAbort(c *C) { ensureChange(c, r, sb, chg) } +func (ts *taskRunnerSuite) TestUndoSingleLane(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("noop", func(t *state.Task, tb *tomb.Tomb) error { + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("noop-slow", func(t *state.Task, tb *tomb.Tomb) error { + time.Sleep(10 * time.Millisecond) + t.State().Lock() + defer t.State().Unlock() + // critical + t.SetStatus(state.DoneStatus) + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("fail", func(t *state.Task, tb *tomb.Tomb) error { + return fmt.Errorf("fail") + }, nil) + + st.Lock() + + lane := st.NewLane() + chg := st.NewChange("install", "...") + + // first taskset + var prev *state.Task + for i := 0; i < 10; i++ { + t := st.NewTask("noop-slow", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + + prev = t + } + + // second taskset with a failing task that triggers undo of the change + prev = nil + for i := 0; i < 10; i++ { + t := st.NewTask("noop", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + prev = t + } + + // error trigger + t := st.NewTask("fail", "...") + t.WaitFor(prev) + chg.AddTask(t) + t.JoinLane(lane) + + st.Unlock() + + var done bool + for !done { + c.Assert(r.Ensure(), Equals, nil) + st.Lock() + done = chg.IsReady() && chg.IsClean() + st.Unlock() + } + + st.Lock() + defer st.Unlock() + + // make sure all tasks are either undone or on hold (except for "fail" task which + // is in error). + for _, t := range st.Tasks() { + switch t.Kind() { + case "fail": + c.Assert(t.Status(), Equals, state.ErrorStatus) + case "noop", "noop-slow": + if t.Status() != state.UndoneStatus && t.Status() != state.HoldStatus { + for _, tsk := range st.Tasks() { + fmt.Printf("%s -> %s\n", tsk.Kind(), tsk.Status()) + } + c.Fatalf("unexpected status: %s", t.Status()) + } + default: + c.Fatalf("unexpected kind: %s", t.Kind()) + } + } +} + func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -508,6 +809,55 @@ func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) { c.Check(t.AtTime().IsZero(), Equals, false) } +func (ts *taskRunnerSuite) testTaskReturningWait(c *C, waitedStatus, expectedStatus state.Status) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("ask-for-wait", func(t *state.Task, tb *tomb.Tomb) error { + // ask for wait + return &state.Wait{WaitedStatus: waitedStatus} + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t := st.NewTask("ask-for-wait", "...") + chg.AddTask(t) + st.Unlock() + + r.Ensure() + // wait for handler to finish + r.Wait() + + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, expectedStatus) + c.Check(chg.Status().Ready(), Equals, false) + + st.Unlock() + defer st.Lock() + // does nothing + r.Ensure() + + // state is unchanged + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(chg.Status().Ready(), Equals, false) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitNormal(c *C) { + ts.testTaskReturningWait(c, state.UndoneStatus, state.UndoneStatus) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitDefaultStatus(c *C) { + // If no state was set (DefaultStatus), then it should default to + // DoneStatus instead. + ts.testTaskReturningWait(c, state.DefaultStatus, state.DoneStatus) +} + func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { ensureBeforeTick := make(chan bool, 1) sb := &stateBackend{ @@ -536,7 +886,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { st.Unlock() tock := time.Now() - restore := state.FakeTime(tock) + restore := state.MockTime(tock) defer restore() r.Ensure() // will run and be rescheduled in a minute select { @@ -554,7 +904,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { schedule := t.AtTime() c.Check(schedule.IsZero(), Equals, false) - state.FakeTime(tock.Add(5 * time.Second)) + state.MockTime(tock.Add(5 * time.Second)) sb.ensureBefore = time.Hour st.Unlock() r.Ensure() // too soon @@ -565,7 +915,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { c.Check(sb.ensureBefore, Equals, 55*time.Second) c.Check(t.AtTime().Equal(schedule), Equals, true) - state.FakeTime(schedule) + state.MockTime(schedule) sb.ensureBefore = time.Hour st.Unlock() r.Ensure() // time to run again @@ -823,7 +1173,7 @@ func (ts *taskRunnerSuite) TestUndoSequence(c *C) { terr.WaitFor(prev) chg.AddTask(terr) - c.Check(chg.Tasks(), HasLen, 9) // sanity check + c.Check(chg.Tasks(), HasLen, 9) // validity check st.Unlock() @@ -910,3 +1260,71 @@ func (ts *taskRunnerSuite) TestCleanup(c *C) { c.Assert(chgIsClean(), Equals, true) c.Assert(called, Equals, 2) } + +func (ts *taskRunnerSuite) TestErrorCallbackCalledOnError(c *C) { + logbuf, restore := logger.MockLogger("PREFIX: ") + defer restore() + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return fmt.Errorf("handler error for %q", t.Kind()) + }, nil) + + st.Lock() + chg := st.NewChange("install", "change summary") + t1 := st.NewTask("foo", "task summary") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.ErrorStatus) + c.Check(strings.Join(t1.Log(), ""), Matches, `.*handler error for "foo"`) + c.Check(called, Equals, true) + + c.Check(logbuf.String(), Matches, `(?m).*: \[change 1 "task summary" task\] failed: handler error for "foo".*`) +} + +func (ts *taskRunnerSuite) TestErrorCallbackNotCalled(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return nil + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.DoneStatus) + c.Check(called, Equals, false) +} diff --git a/internals/overlord/state/warning.go b/internals/overlord/state/warning.go index 69dac7c8..6211bf37 100644 --- a/internals/overlord/state/warning.go +++ b/internals/overlord/state/warning.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2018 Canonical Ltd + * Copyright (C) 2018 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as diff --git a/internals/overlord/state/warning_test.go b/internals/overlord/state/warning_test.go index bf30f69e..26479694 100644 --- a/internals/overlord/state/warning_test.go +++ b/internals/overlord/state/warning_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2018 Canonical Ltd + * Copyright (C) 2018 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -99,7 +99,7 @@ func (stateSuite) TestUnmarshalErrors(c *check.C) { } for _, t := range []T1{ - // sanity check + // validity check {`{"message": "x", "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, nil}, // remove one field at a time: {`{ "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, state.ErrNoWarningMessage}, diff --git a/internals/overlord/stateengine.go b/internals/overlord/stateengine.go index 25710333..b468aea5 100644 --- a/internals/overlord/stateengine.go +++ b/internals/overlord/stateengine.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package overlord @@ -30,10 +35,18 @@ type StateManager interface { Ensure() error } +// StateStarterUp is optionally implemented by StateManager that have expensive +// initialization to perform before the main Overlord loop. +type StateStarterUp interface { + // StartUp asks manager to perform any expensive initialization. + StartUp() error +} + // StateWaiter is optionally implemented by StateManagers that have running // activities that can be waited. type StateWaiter interface { - // Wait asks manager to wait for all running activities to finish. + // Wait asks manager to wait for all running activities to + // finish. Wait() } @@ -53,8 +66,9 @@ type StateStopper interface { // cope with Ensure calls in any order, coordinating among themselves // solely via the state. type StateEngine struct { - state *state.State - stopped bool + state *state.State + stopped bool + startedUp bool // managers in use mgrLock sync.Mutex managers []StateManager @@ -72,6 +86,37 @@ func (se *StateEngine) State() *state.State { return se.state } +type startupError struct { + errs []error +} + +func (e *startupError) Error() string { + return fmt.Sprintf("state startup errors: %v", e.errs) +} + +// StartUp asks all managers to perform any expensive initialization. It is a noop after the first invocation. +func (se *StateEngine) StartUp() error { + se.mgrLock.Lock() + defer se.mgrLock.Unlock() + if se.startedUp { + return nil + } + se.startedUp = true + var errs []error + for _, m := range se.managers { + if starterUp, ok := m.(StateStarterUp); ok { + err := starterUp.StartUp() + if err != nil { + errs = append(errs, err) + } + } + } + if len(errs) != 0 { + return &startupError{errs} + } + return nil +} + type ensureError struct { errs []error } @@ -91,6 +136,9 @@ func (e *ensureError) Error() string { func (se *StateEngine) Ensure() error { se.mgrLock.Lock() defer se.mgrLock.Unlock() + if !se.startedUp { + return fmt.Errorf("state engine skipped startup") + } if se.stopped { return fmt.Errorf("state engine already stopped") } diff --git a/internals/overlord/stateengine_test.go b/internals/overlord/stateengine_test.go index f12ed98b..a427e661 100644 --- a/internals/overlord/stateengine_test.go +++ b/internals/overlord/stateengine_test.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package overlord_test @@ -35,9 +40,14 @@ func (ses *stateEngineSuite) TestNewAndState(c *C) { } type fakeManager struct { - name string - calls *[]string - ensureError, stopError error + name string + calls *[]string + ensureError, startupError error +} + +func (fm *fakeManager) StartUp() error { + *fm.calls = append(*fm.calls, "startup:"+fm.name) + return fm.startupError } func (fm *fakeManager) Ensure() error { @@ -55,6 +65,48 @@ func (fm *fakeManager) Wait() { var _ overlord.StateManager = (*fakeManager)(nil) +func (ses *stateEngineSuite) TestStartUp(c *C) { + s := state.New(nil) + se := overlord.NewStateEngine(s) + + calls := []string{} + + mgr1 := &fakeManager{name: "mgr1", calls: &calls} + mgr2 := &fakeManager{name: "mgr2", calls: &calls} + + se.AddManager(mgr1) + se.AddManager(mgr2) + + err := se.StartUp() + c.Assert(err, IsNil) + c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"}) + + // noop + err = se.StartUp() + c.Assert(err, IsNil) + c.Check(calls, HasLen, 2) +} + +func (ses *stateEngineSuite) TestStartUpError(c *C) { + s := state.New(nil) + se := overlord.NewStateEngine(s) + + calls := []string{} + + err1 := errors.New("boom1") + err2 := errors.New("boom2") + + mgr1 := &fakeManager{name: "mgr1", calls: &calls, startupError: err1} + mgr2 := &fakeManager{name: "mgr2", calls: &calls, startupError: err2} + + se.AddManager(mgr1) + se.AddManager(mgr2) + + err := se.StartUp() + c.Check(err.Error(), DeepEquals, "state startup errors: [boom1 boom2]") + c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"}) +} + func (ses *stateEngineSuite) TestEnsure(c *C) { s := state.New(nil) se := overlord.NewStateEngine(s) @@ -68,6 +120,11 @@ func (ses *stateEngineSuite) TestEnsure(c *C) { se.AddManager(mgr2) err := se.Ensure() + c.Check(err, ErrorMatches, "state engine skipped startup") + c.Assert(se.StartUp(), IsNil) + calls = []string{} + + err = se.Ensure() c.Assert(err, IsNil) c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"}) @@ -91,6 +148,9 @@ func (ses *stateEngineSuite) TestEnsureError(c *C) { se.AddManager(mgr1) se.AddManager(mgr2) + c.Assert(se.StartUp(), IsNil) + calls = []string{} + err := se.Ensure() c.Check(err.Error(), DeepEquals, "state ensure errors: [boom1 boom2]") c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"}) @@ -108,6 +168,9 @@ func (ses *stateEngineSuite) TestStop(c *C) { se.AddManager(mgr1) se.AddManager(mgr2) + c.Assert(se.StartUp(), IsNil) + calls = []string{} + se.Stop() c.Check(calls, DeepEquals, []string{"stop:mgr1", "stop:mgr2"}) se.Stop() diff --git a/internals/systemd/sdnotify.go b/internals/systemd/sdnotify.go index 5f1279cf..ed16aceb 100644 --- a/internals/systemd/sdnotify.go +++ b/internals/systemd/sdnotify.go @@ -27,7 +27,7 @@ var osGetenv = os.Getenv func SocketAvailable() bool { notifySocket := osGetenv("NOTIFY_SOCKET") - return notifySocket != "" && osutil.CanStat(notifySocket) + return notifySocket != "" && osutil.FileExists(notifySocket) } // SdNotify sends the given state string notification to systemd. diff --git a/internals/systemd/systemd.go b/internals/systemd/systemd.go index c5e11114..0811fbee 100644 --- a/internals/systemd/systemd.go +++ b/internals/systemd/systemd.go @@ -713,7 +713,7 @@ func (s *systemd) RemoveMountUnitFile(mountedDir string) error { } unit := MountUnitPath(unitNamePath) - if !osutil.CanStat(unit) { + if !osutil.FileExists(unit) { return nil } diff --git a/internals/systemd/systemd_test.go b/internals/systemd/systemd_test.go index f32c363d..20967271 100644 --- a/internals/systemd/systemd_test.go +++ b/internals/systemd/systemd_test.go @@ -567,7 +567,7 @@ WantedBy=multi-user.target } func (s *SystemdTestSuite) TestFuseInContainer(c *C) { - if !osutil.CanStat("/dev/fuse") { + if !osutil.FileExists("/dev/fuse") { c.Skip("No /dev/fuse on the system") } @@ -713,7 +713,7 @@ func (s *SystemdTestSuite) TestRemoveMountUnit(c *C) { c.Assert(err, IsNil) // the file is gone - c.Check(osutil.CanStat(mountUnit), Equals, false) + c.Check(osutil.FileExists(mountUnit), Equals, false) // and the unit is disabled and the daemon reloaded c.Check(s.argses, DeepEquals, [][]string{ {"--root", s.rootDir, "disable", "snap-foo-42.mount"}, diff --git a/internals/testutil/containschecker.go b/internals/testutil/containschecker.go index 82ff2ee5..644ef158 100644 --- a/internals/testutil/containschecker.go +++ b/internals/testutil/containschecker.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2015-2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package testutil @@ -81,7 +86,7 @@ func (c *containsChecker) Check(params []interface{}, names []string) (result bo var container interface{} = params[0] var elem interface{} = params[1] if commonEquals(container, elem, &result, &error) { - return + return result, error } // Do the actual test using == switch containerV := reflect.ValueOf(container); containerV.Kind() { @@ -111,7 +116,7 @@ type deepContainsChecker struct { } // DeepContains is a Checker that looks for a elem in a container using -// DeepEqual. The elem can be any object. The container can be an array, slice +// DeepEqual. The elem can be any object. The container can be an array, slice // or string. var DeepContains check.Checker = &deepContainsChecker{ &check.CheckerInfo{Name: "DeepContains", Params: []string{"container", "elem"}}, @@ -121,7 +126,7 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul var container interface{} = params[0] var elem interface{} = params[1] if commonEquals(container, elem, &result, &error) { - return + return result, error } // Do the actual test using reflect.DeepEqual switch containerV := reflect.ValueOf(container); containerV.Kind() { @@ -145,3 +150,126 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul return false, fmt.Sprintf("%T is not a supported container", container) } } + +type deepUnsortedMatchChecker struct { + *check.CheckerInfo +} + +// DeepUnsortedMatches checks if two containers contain the same elements in +// the same number (but possibly different order) using DeepEqual. The container +// can be an array, a slice or a map. +var DeepUnsortedMatches check.Checker = &deepUnsortedMatchChecker{ + &check.CheckerInfo{Name: "DeepUnsortedMatches", Params: []string{"container1", "container2"}}, +} + +func (c *deepUnsortedMatchChecker) Check(params []interface{}, _ []string) (bool, string) { + container1 := reflect.ValueOf(params[0]) + container2 := reflect.ValueOf(params[1]) + + // if both args are nil, return true + if container1.Kind() == reflect.Invalid && container2.Kind() == reflect.Invalid { + return true, "" + } + + if container1.Kind() == reflect.Invalid || container2.Kind() == reflect.Invalid { + return false, "only one container was nil" + } + + if container1.Kind() != container2.Kind() { + return false, fmt.Sprintf("containers are of different types: %s != %s", container1.Kind(), container2.Kind()) + } + + switch container1.Kind() { + case reflect.Array, reflect.Slice: + return deepSequenceMatch(container1, container2) + case reflect.Map: + return deepMapMatch(container1, container2) + default: + return false, fmt.Sprintf("'%s' is not a supported type: must be slice, array or map", container1.Kind().String()) + } +} + +func deepMapMatch(container1, container2 reflect.Value) (bool, string) { + if valid, output := validateContainerTypesAndLengths(container1, container2); !valid { + return false, output + } + + switch container1.Type().Elem().Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + // only run the unsorted match if the map values are containers + default: + if !reflect.DeepEqual(container1.Interface(), container2.Interface()) { + return false, "maps don't match" + } + return true, "" + } + + for _, key := range container1.MapKeys() { + el1 := container1.MapIndex(key) + el2 := container2.MapIndex(key) + + absent := el2 == reflect.Value{} + if absent { + return false, fmt.Sprintf("key %q from one map is absent from the other map", key) + } + + var ok bool + var msg string + switch el1.Kind() { + case reflect.Array, reflect.Slice: + ok, msg = deepSequenceMatch(el1, el2) + case reflect.Map: + ok, msg = deepMapMatch(el1, el2) + } + + if !ok { + return false, msg + } + } + + return true, "" +} + +func deepSequenceMatch(container1, container2 reflect.Value) (bool, string) { + if valid, output := validateContainerTypesAndLengths(container1, container2); !valid { + return false, output + } + + matched := make([]bool, container1.Len()) +out: + for i := 0; i < container1.Len(); i++ { + el1 := container1.Index(i).Interface() + + for e := 0; e < container2.Len(); e++ { + el2 := container2.Index(e).Interface() + + if !matched[e] && reflect.DeepEqual(el1, el2) { + // mark already matched elements, so that duplicate elements in + // one container can't be matched to the same element in the other. + matched[e] = true + continue out + } + } + + return false, fmt.Sprintf("element [%d]=%s was unmatched in the second container", i, el1) + } + + return true, "" +} + +func validateContainerTypesAndLengths(container1, container2 reflect.Value) (bool, string) { + if container1.Len() != container2.Len() { + return false, fmt.Sprintf("containers have different lengths: %d != %d", container1.Len(), container2.Len()) + } else if container1.Type().Elem() != container2.Type().Elem() { + return false, fmt.Sprintf("containers have different element types: %s != %s", container1.Type().Elem(), container2.Type().Elem()) + } + + if container1.Kind() == reflect.Map && container2.Kind() == reflect.Map { + keyType1, keyType2 := container1.Type().Key(), container2.Type().Key() + if keyType1 != keyType2 { + return false, fmt.Sprintf("maps have different key types: %s != %s", keyType1, keyType2) + } + } + + return true, "" +} diff --git a/internals/testutil/containschecker_test.go b/internals/testutil/containschecker_test.go index 86c5c0cc..8a94c897 100644 --- a/internals/testutil/containschecker_test.go +++ b/internals/testutil/containschecker_test.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2015-2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package testutil_test @@ -216,3 +221,152 @@ func (*containsCheckerSuite) TestDeepContainsUncomparableType(c *check.C) { testCheck(c, DeepContains, true, "", containerSlice, elem) testCheck(c, DeepContains, true, "", containerMap, elem) } + +type example struct { + a string + b map[string]int +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceSuccess(c *check.C) { + slice1 := []example{ + {a: "one", b: map[string]int{"a": 1}}, + {a: "two", b: map[string]int{"b": 2}}, + } + slice2 := []example{ + {a: "two", b: map[string]int{"b": 2}}, + {a: "one", b: map[string]int{"a": 1}}, + } + + c.Check(slice1, DeepUnsortedMatches, slice2) + c.Check(slice2, DeepUnsortedMatches, slice1) + c.Check([]string{"a", "a"}, DeepUnsortedMatches, []string{"a", "a"}) + c.Check([]string{"a", "b", "a"}, DeepUnsortedMatches, []string{"b", "a", "a"}) + slice := [1]int{1} + c.Check(slice, DeepUnsortedMatches, slice) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceFailure(c *check.C) { + slice1 := []string{"a", "a", "b"} + slice2 := []string{"b", "a", "c"} + + testCheck(c, DeepUnsortedMatches, false, "element [1]=a was unmatched in the second container", slice1, slice2) + testCheck(c, DeepUnsortedMatches, false, "element [2]=c was unmatched in the second container", slice2, slice1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapSuccess(c *check.C) { + map1 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + "c": {a: "c", b: map[string]int{"c": 3, "d": 4}}, + } + map2 := map[string]example{ + "c": {a: "c", b: map[string]int{"c": 3, "d": 4}}, + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + } + + c.Check(map1, DeepUnsortedMatches, map2) + c.Check(map2, DeepUnsortedMatches, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapStructFail(c *check.C) { + map1 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 2, "b": 1}}, + } + map2 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + } + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedKeyFailure(c *check.C) { + map1 := map[string]int{"a": 1, "c": 2} + map2 := map[string]int{"a": 1, "b": 2} + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedValueFailure(c *check.C) { + map1 := map[string]int{"a": 1, "b": 2} + map2 := map[string]int{"a": 1, "b": 3} + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentTypeFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers are of different types: slice != array", []int{}, [1]int{}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentElementType(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers have different element types: int != string", []int{1}, []string{"a"}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentLengthFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 1 != 2", []int{1}, []int{1, 1}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNilArgFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "only one container was nil", nil, []int{1}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesBothNilArgSuccess(c *check.C) { + c.Check(nil, DeepUnsortedMatches, nil) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNonContainerValues(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "'string' is not a supported type: must be slice, array or map", "a", "a") + testCheck(c, DeepUnsortedMatches, false, "'int' is not a supported type: must be slice, array or map", 1, 2) + testCheck(c, DeepUnsortedMatches, false, "'bool' is not a supported type: must be slice, array or map", true, false) + testCheck(c, DeepUnsortedMatches, false, "'ptr' is not a supported type: must be slice, array or map", &[]string{"a", "b"}, &[]string{"a", "b"}) + testCheck(c, DeepUnsortedMatches, false, "'func' is not a supported type: must be slice, array or map", func() {}, func() {}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsOfSlices(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}} + map2 := map[string][]string{"a": {"bar", "foo"}, "b": {"bar", "foo"}} + + c.Check(map1, DeepUnsortedMatches, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentKeyTypes(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[int][]string{1: {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "maps have different key types: string != int", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentValueTypes(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[string][2]string{"a": {"foo", "bar"}} + + testCheck(c, DeepUnsortedMatches, false, "containers have different element types: []string != [2]string", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentLengths(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}} + map2 := map[string][]string{"a": {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 2 != 1", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsMissingKey(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[string][]string{"b": {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "key \"a\" from one map is absent from the other map", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNestedMaps(c *check.C) { + map1 := map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}} + map2 := map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}} + c.Check(map1, DeepUnsortedMatches, map2) + + map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}} + map2 = map[string]map[string][]string{"a": {"c": []string{"bar", "foo"}}} + testCheck(c, DeepUnsortedMatches, false, "key \"b\" from one map is absent from the other map", map1, map2) + + map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}, "c": {"b": []string{"foo"}}} + map2 = map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}, "c": {"b": []string{"bar"}}} + testCheck(c, DeepUnsortedMatches, false, "element [0]=foo was unmatched in the second container", map1, map2) +} diff --git a/internals/testutil/errorischecker.go b/internals/testutil/errorischecker.go new file mode 100644 index 00000000..e22030c6 --- /dev/null +++ b/internals/testutil/errorischecker.go @@ -0,0 +1,53 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2022 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package testutil + +import ( + "errors" + + "gopkg.in/check.v1" +) + +// ErrorIs calls errors.Is with the provided arguments. +var ErrorIs = &errorIsChecker{ + &check.CheckerInfo{Name: "ErrorIs", Params: []string{"error", "target"}}, +} + +type errorIsChecker struct { + *check.CheckerInfo +} + +func (*errorIsChecker) Check(params []interface{}, names []string) (result bool, errMsg string) { + if params[0] == nil { + return params[1] == nil, "" + } + + err, ok := params[0].(error) + if !ok { + return false, "first argument must be an error" + } + + target, ok := params[1].(error) + if !ok { + return false, "second argument must be an error" + } + + return errors.Is(err, target), "" +} diff --git a/internals/testutil/errorischecker_test.go b/internals/testutil/errorischecker_test.go new file mode 100644 index 00000000..5f2fc28c --- /dev/null +++ b/internals/testutil/errorischecker_test.go @@ -0,0 +1,72 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2022 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package testutil_test + +import ( + "errors" + + . "github.com/canonical/pebble/internals/testutil" + + . "gopkg.in/check.v1" +) + +type errorIsCheckerSuite struct{} + +var _ = Suite(&errorIsCheckerSuite{}) + +type baseError struct{} + +func (baseError) Error() string { return "" } + +func (baseError) Is(err error) bool { + _, ok := err.(baseError) + return ok +} + +type wrapperError struct { + err error +} + +func (*wrapperError) Error() string { return "" } + +func (e *wrapperError) Unwrap() error { return e.err } + +func (*errorIsCheckerSuite) TestErrorIsCheckSucceeds(c *C) { + testInfo(c, ErrorIs, "ErrorIs", []string{"error", "target"}) + + c.Assert(baseError{}, ErrorIs, baseError{}) + err := &wrapperError{err: baseError{}} + c.Assert(err, ErrorIs, baseError{}) +} + +func (*errorIsCheckerSuite) TestErrorIsCheckFails(c *C) { + c.Assert(nil, Not(ErrorIs), baseError{}) + c.Assert(errors.New(""), Not(ErrorIs), baseError{}) +} + +func (*errorIsCheckerSuite) TestErrorIsWithInvalidArguments(c *C) { + res, errMsg := ErrorIs.Check([]interface{}{"", errors.New("")}, []string{"error", "target"}) + c.Assert(res, Equals, false) + c.Assert(errMsg, Equals, "first argument must be an error") + + res, errMsg = ErrorIs.Check([]interface{}{errors.New(""), ""}, []string{"error", "target"}) + c.Assert(res, Equals, false) + c.Assert(errMsg, Equals, "second argument must be an error") +}