diff --git a/internals/cli/cmd_run.go b/internals/cli/cmd_run.go index 2c1f2083..62d513a5 100644 --- a/internals/cli/cmd_run.go +++ b/internals/cli/cmd_run.go @@ -182,7 +182,10 @@ func runDaemon(rcmd *cmdRun, ch chan os.Signal, ready chan<- func()) error { } d.Version = cmd.Version - d.Start() + err = d.Start() + if err != nil { + return err + } watchdog, err := runWatchdog(d) if err != nil { diff --git a/internals/daemon/api_changes_test.go b/internals/daemon/api_changes_test.go index 0f5f6228..f52788ff 100644 --- a/internals/daemon/api_changes_test.go +++ b/internals/daemon/api_changes_test.go @@ -46,7 +46,7 @@ func setupChanges(st *state.State) []string { } func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -75,7 +75,7 @@ func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) { } func (s *apiSuite) TestStateChangesInProgress(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -104,7 +104,7 @@ func (s *apiSuite) TestStateChangesInProgress(c *check.C) { } func (s *apiSuite) TestStateChangesAll(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -133,7 +133,7 @@ func (s *apiSuite) TestStateChangesAll(c *check.C) { } func (s *apiSuite) TestStateChangesReady(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -161,7 +161,7 @@ func (s *apiSuite) TestStateChangesReady(c *check.C) { } func (s *apiSuite) TestStateChangesForServiceName(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -192,7 +192,7 @@ func (s *apiSuite) TestStateChangesForServiceName(c *check.C) { } func (s *apiSuite) TestStateChange(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup @@ -261,7 +261,7 @@ func (s *apiSuite) TestStateChange(c *check.C) { } func (s *apiSuite) TestStateChangeAbort(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() soon := 0 @@ -334,7 +334,7 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { } func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { - restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) + restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC)) defer restore() // Setup diff --git a/internals/daemon/api_files_test.go b/internals/daemon/api_files_test.go index ab41554d..26157f2b 100644 --- a/internals/daemon/api_files_test.go +++ b/internals/daemon/api_files_test.go @@ -625,7 +625,7 @@ func (s *filesSuite) TestRemoveSingle(c *C) { c.Check(r.Result, HasLen, 1) checkFileResult(c, r.Result[0], tmpDir+"/file", "", "") - c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false) + c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false) } func (s *filesSuite) TestRemoveMultiple(c *C) { @@ -667,7 +667,7 @@ func (s *filesSuite) TestRemoveMultiple(c *C) { checkFileResult(c, r.Result[2], tmpDir+"/non-empty", "generic-file-error", ".*directory not empty") checkFileResult(c, r.Result[3], tmpDir+"/recursive", "", "") - c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false) + c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false) c.Check(osutil.IsDir(tmpDir+"/empty"), Equals, false) c.Check(osutil.IsDir(tmpDir+"/non-empty"), Equals, true) c.Check(osutil.IsDir(tmpDir+"/recursive"), Equals, false) @@ -1186,10 +1186,10 @@ group not found checkFileResult(c, r.Result[4], pathUserNotFound, "generic-file-error", ".*unknown user.*") checkFileResult(c, r.Result[5], pathGroupNotFound, "generic-file-error", ".*unknown group.*") - c.Check(osutil.CanStat(pathNoContent), Equals, false) - c.Check(osutil.CanStat(pathNotAbsolute), Equals, false) - c.Check(osutil.CanStat(pathNotFound), Equals, false) - c.Check(osutil.CanStat(pathPermissionDenied), Equals, false) + c.Check(osutil.FileExists(pathNoContent), Equals, false) + c.Check(osutil.FileExists(pathNotAbsolute), Equals, false) + c.Check(osutil.FileExists(pathNotFound), Equals, false) + c.Check(osutil.FileExists(pathPermissionDenied), Equals, false) } func assertFile(c *C, path string, perm os.FileMode, content string) { diff --git a/internals/daemon/api_test.go b/internals/daemon/api_test.go index ed7ebd7f..38735743 100644 --- a/internals/daemon/api_test.go +++ b/internals/daemon/api_test.go @@ -58,6 +58,7 @@ func (s *apiSuite) daemon(c *check.C) *Daemon { } d, err := New(&Options{Dir: s.pebbleDir}) c.Assert(err, check.IsNil) + c.Assert(d.Overlord().StartUp(), check.IsNil) d.addRoutes() s.d = d return d diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index b7405ee0..dc46030f 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -385,6 +385,10 @@ func (d *Daemon) Init() error { return nil } +func (d *Daemon) Overlord() *overlord.Overlord { + return d.overlord +} + // SetDegradedMode puts the daemon into a degraded mode which will the // error given in the "err" argument for commands that are not marked // as readonlyOK. @@ -452,13 +456,13 @@ func (d *Daemon) initStandbyHandling() { d.standbyOpinions.Start() } -func (d *Daemon) Start() { +func (d *Daemon) Start() error { if d.rebootIsMissing { // we need to schedule and wait for a system restart d.tomb.Kill(nil) // avoid systemd killing us again while we wait systemdSdNotify("READY=1") - return + return nil } if d.overlord == nil { panic("internal error: no Overlord") @@ -466,6 +470,10 @@ func (d *Daemon) Start() { d.StartTime = time.Now() + // now perform expensive overlord/manages initialization + if err := d.overlord.StartUp(); err != nil { + return err + } d.connTracker = &connTracker{conns: make(map[net.Conn]struct{})} d.serve = &http.Server{ Handler: logit(d.router), @@ -505,6 +513,7 @@ func (d *Daemon) Start() { // notify systemd that we are ready systemdSdNotify("READY=1") + return nil } // HandleRestart implements overlord.RestartBehavior. @@ -673,7 +682,7 @@ func (d *Daemon) rebootDelay() (time.Duration, error) { // see whether a reboot had already been scheduled var rebootAt time.Time err := d.state.Get("daemon-system-restart-at", &rebootAt) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return 0, err } rebootDelay := 1 * time.Minute @@ -758,7 +767,7 @@ var errExpectedReboot = errors.New("expected reboot did not happen") func (d *Daemon) RebootIsMissing(st *state.State) error { var nTentative int err := st.Get("daemon-system-restart-tentative", &nTentative) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } nTentative++ diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index c10d04d9..5e6c8f1e 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -103,7 +103,8 @@ func (s *daemonSuite) TestExplicitPaths(c *C) { d := s.newDaemon(c) d.Init() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) info, err := os.Stat(s.socketPath) @@ -459,7 +460,8 @@ func (s *daemonSuite) TestStartStop(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l2, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) generalDone := make(chan struct{}) go func() { @@ -500,7 +502,8 @@ func (s *daemonSuite) TestRestartWiring(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) generalDone := make(chan struct{}) @@ -567,7 +570,8 @@ func (s *daemonSuite) TestGracefulStop(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: untrustedL, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) generalAccepting := make(chan struct{}) go func() { @@ -634,7 +638,8 @@ func (s *daemonSuite) TestRestartSystemWiring(c *check.C) { untrustedAccept := make(chan struct{}) d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept} - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) defer d.Stop(nil) st := d.overlord.State() @@ -781,7 +786,8 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) st := d.overlord.State() st.Lock() @@ -791,7 +797,7 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) { ch := make(chan os.Signal, 2) ch <- syscall.SIGTERM // stop will check if we got a sigterm in between (which we did) - err := d.Stop(ch) + err = d.Stop(ch) c.Assert(err, check.IsNil) } @@ -813,7 +819,8 @@ func (s *daemonSuite) TestRestartShutdown(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) st := d.overlord.State() st.Lock() @@ -860,7 +867,8 @@ func (s *daemonSuite) TestRestartExpectedRebootIsMissing(c *check.C) { c.Check(err, check.IsNil) c.Check(n, check.Equals, 1) - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) c.Check(s.notified, check.DeepEquals, []string{"READY=1"}) @@ -896,8 +904,8 @@ func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) { @@ -920,9 +928,9 @@ func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) { defer st.Unlock() var v interface{} // these were cleared - c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) - c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState) + c.Check(st.Get("daemon-system-restart-tentative", &v), testutil.ErrorIs, state.ErrNoState) } func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { @@ -936,7 +944,8 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { d := s.newDaemon(c) makeDaemonListeners(c, d) - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) // pretend some ensure happened for i := 0; i < 5; i++ { @@ -955,7 +964,7 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { case <-time.After(15 * time.Second): c.Errorf("daemon did not stop after 15s") } - err := d.Stop(nil) + err = d.Stop(nil) c.Check(err, check.Equals, ErrRestartSocket) c.Check(d.restartSocket, check.Equals, true) } @@ -972,7 +981,8 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { st := d.overlord.State() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) // pretend some ensure happened for i := 0; i < 5; i++ { d.overlord.StateEngine().Ensure() @@ -998,7 +1008,7 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { c.Errorf("daemon did not stop after 5s") } // when the daemon got a pending change it just restarts - err := d.Stop(nil) + err = d.Stop(nil) c.Check(err, check.IsNil) c.Check(d.restartSocket, check.Equals, false) } @@ -1058,7 +1068,8 @@ func (s *daemonSuite) TestHTTPAPI(c *check.C) { s.httpAddress = ":0" // Go will choose port (use listener.Addr() to find it) d := s.newDaemon(c) d.Init() - d.Start() + err := d.Start() + c.Assert(err, check.IsNil) port := d.httpListener.Addr().(*net.TCPAddr).Port request, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d/v1/health", port), nil) @@ -1101,7 +1112,8 @@ services: d := s.newDaemon(c) err := d.Init() c.Assert(err, IsNil) - d.Start() + err = d.Start() + c.Assert(err, check.IsNil) // Start the test service. payload := bytes.NewBufferString(`{"action": "start", "services": ["test1"]}`) diff --git a/internals/jsonutil/json.go b/internals/jsonutil/json.go new file mode 100644 index 00000000..056ab0a4 --- /dev/null +++ b/internals/jsonutil/json.go @@ -0,0 +1,66 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2017 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package jsonutil + +import ( + "encoding/json" + "fmt" + "io" + "reflect" + "strings" + + "github.com/canonical/x-go/strutil" +) + +// DecodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved +// via json.Number data type. It errors out on invalid json or any excess input. +func DecodeWithNumber(r io.Reader, value interface{}) error { + dec := json.NewDecoder(r) + dec.UseNumber() + if err := dec.Decode(&value); err != nil { + return err + } + if dec.More() { + return fmt.Errorf("cannot parse json value") + } + return nil +} + +// StructFields takes a pointer to a struct and a list of exceptions, +// and returns a list of the fields in the struct that are JSON-tagged +// and whose tag is not in the list of exceptions. +// The struct can be nil. +func StructFields(s interface{}, exceptions ...string) []string { + st := reflect.TypeOf(s).Elem() + num := st.NumField() + fields := make([]string, 0, num) + for i := 0; i < num; i++ { + tag := st.Field(i).Tag.Get("json") + idx := strings.IndexByte(tag, ',') + if idx > -1 { + tag = tag[:idx] + } + if tag != "" && !strutil.ListContains(exceptions, tag) { + fields = append(fields, tag) + } + } + + return fields +} diff --git a/internals/jsonutil/json_test.go b/internals/jsonutil/json_test.go new file mode 100644 index 00000000..2fb35c0b --- /dev/null +++ b/internals/jsonutil/json_test.go @@ -0,0 +1,90 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2017 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package jsonutil_test + +import ( + "encoding/json" + "strings" + "testing" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/jsonutil" +) + +func Test(t *testing.T) { TestingT(t) } + +type utilSuite struct{} + +var _ = Suite(&utilSuite{}) + +func (s *utilSuite) TestDecodeError(c *C) { + input := "{]" + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, `invalid character ']' looking for beginning of object key string`) +} + +func (s *utilSuite) TestDecodeErrorOnExcessData(c *C) { + input := "1000000000[1,2]" + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, `cannot parse json value`) +} + +func (s *utilSuite) TestDecodeSuccess(c *C) { + input := `{"a":1000000000, "b": 1.2, "c": "foo", "d":null}` + var output interface{} + err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output) + c.Assert(err, IsNil) + c.Assert(output, DeepEquals, map[string]interface{}{ + "a": json.Number("1000000000"), + "b": json.Number("1.2"), + "c": "foo", + "d": nil, + }) +} + +func (utilSuite) TestStructFields(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int `json:"potato,stuff"` + } + c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello", "potato"}) +} + +func (utilSuite) TestStructFieldsExcept(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int `json:"potato,stuff"` + } + c.Assert(jsonutil.StructFields((*aStruct)(nil), "potato"), DeepEquals, []string{"hello"}) + c.Assert(jsonutil.StructFields((*aStruct)(nil), "hello"), DeepEquals, []string{"potato"}) +} + +func (utilSuite) TestStructFieldsSurvivesNoTag(c *C) { + type aStruct struct { + Foo int `json:"hello"` + Bar int + } + c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello"}) +} diff --git a/internals/jsonutil/safejson/safejson.go b/internals/jsonutil/safejson/safejson.go new file mode 100644 index 00000000..298e6945 --- /dev/null +++ b/internals/jsonutil/safejson/safejson.go @@ -0,0 +1,202 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package safejson + +import ( + "fmt" + "strconv" + "unicode" + "unicode/utf16" + "unicode/utf8" + + "github.com/canonical/x-go/strutil" +) + +// String accepts any valid JSON string. Its Clean method will remove +// characters that aren't expected in a short descriptive text. +// I.e.: Cc, Co, Cf, Cs, noncharacters, and � (U+FFFD, the replacement +// character) are removed. +type String struct { + s string +} + +func (str *String) UnmarshalJSON(in []byte) (err error) { + str.s, err = unmarshal(in, uOpt{}) + return +} + +// Clean returns the string, with Cc, Co, Cf, Cs, noncharacters, +// and � (U+FFFD) removed. +func (str String) Clean() string { + return str.s +} + +// Paragraph accepts any valid JSON string. Its Clean method will remove +// characters that aren't expected in a long descriptive text. +// I.e.: Cc (except for \n), Co, Cf, Cs, noncharacters, and � (U+FFFD, +// the replacement character) are removed. +type Paragraph struct { + s string +} + +func (par *Paragraph) UnmarshalJSON(in []byte) (err error) { + par.s, err = unmarshal(in, uOpt{nlOK: true}) + return +} + +// Clean returns the string, with Cc minus \n, Co, Cf, Cs, noncharacters, +// and � (U+FFFD) removed. +func (par Paragraph) Clean() string { + return par.s +} + +func unescapeUCS2(in []byte) (rune, bool) { + if len(in) < 6 || in[0] != '\\' || in[1] != 'u' { + return -1, false + } + u, err := strconv.ParseUint(string(in[2:6]), 16, 32) + if err != nil { + return -1, false + } + return rune(u), true +} + +type uOpt struct { + nlOK bool + simple bool +} + +func unmarshal(in []byte, o uOpt) (string, error) { + // heavily based on (inspired by?) unquoteBytes from encoding/json + + if len(in) < 2 || in[0] != '"' || in[len(in)-1] != '"' { + // maybe it's a null and that's alright + if len(in) == 4 && in[0] == 'n' && in[1] == 'u' && in[2] == 'l' && in[3] == 'l' { + return "", nil + } + return "", fmt.Errorf("missing string delimiters: %q", in) + } + + // prune the quotes + in = in[1 : len(in)-1] + i := 0 + // try the fast track + for i < len(in) { + // 0x00..0x19 is the first of Cc + // 0x20..0x7e is all of printable ASCII (minus control chars) + if in[i] < 0x20 || in[i] > 0x7e || in[i] == '\\' || in[i] == '"' { + break + } + i++ + } + if i == len(in) { + // wee + return string(in), nil + } + if o.simple { + return "", fmt.Errorf("character %q in string %q unsupported for this value", in[i], in) + } + // in[i] is the first problematic one + out := make([]byte, i, len(in)+2*utf8.UTFMax) + copy(out, in) + var r, r2 rune + var n int + var c byte + var ubuf [utf8.UTFMax]byte + var ok bool + for i < len(in) { + c = in[i] + switch { + case c == '"': + return "", fmt.Errorf("unexpected unescaped quote at %d in \"%s\"", i, in) + case c < 0x20: + return "", fmt.Errorf("unexpected control character at %d in %q", i, in) + case c == '\\': + // handle escapes + i++ + if i == len(in) { + return "", fmt.Errorf("unexpected end of string (trailing backslash) in \"%s\"", in) + } + switch in[i] { + case 'u': + // oh dear, a unicode wotsit + r, ok = unescapeUCS2(in[i-1:]) + if !ok { + x := in[i-1:] + if len(x) > 6 { + x = x[:6] + } + return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in) + } + i += 5 + if utf16.IsSurrogate(r) { + // sigh + r2, ok = unescapeUCS2(in[i:]) + if !ok { + x := in[i:] + if len(x) > 6 { + x = x[:6] + } + return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in) + } + i += 6 + r = utf16.DecodeRune(r, r2) + } + if r <= 0x9f { + // otherwise, it's Cc (both halves, as we're looking at runes) + if (o.nlOK && r == '\n') || (r >= 0x20 && r <= 0x7e) { + out = append(out, byte(r)) + } + } else if r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) { + n = utf8.EncodeRune(ubuf[:], r) + out = append(out, ubuf[:n]...) + } + case 'b', 'f', 'r', 't': + // do nothing + i++ + case 'n': + if o.nlOK { + out = append(out, '\n') + } + i++ + case '"', '/', '\\': + // the spec says just ", / and \ can be backslash-escaped + // but go adds ' to the list (in unquoteBytes) + out = append(out, in[i]) + i++ + default: + return "", fmt.Errorf(`unknown escape '%c' at %d of "%s"`, in[i], i, in) + } + case c <= 0x7e: + // printable ASCII, except " or \ + out = append(out, c) + i++ + default: + r, n = utf8.DecodeRune(in[i:]) + j := i + n + if r > 0x9f && r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) { + out = append(out, in[i:j]...) + } + i = j + } + } + + return string(out), nil +} diff --git a/internals/jsonutil/safejson/safejson_test.go b/internals/jsonutil/safejson/safejson_test.go new file mode 100644 index 00000000..aae12aea --- /dev/null +++ b/internals/jsonutil/safejson/safejson_test.go @@ -0,0 +1,148 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package safejson_test + +import ( + "encoding/json" + "testing" + + "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/jsonutil/safejson" +) + +func Test(t *testing.T) { check.TestingT(t) } + +type escapeSuite struct{} + +var _ = check.Suite(escapeSuite{}) + +var table = map[string]string{ + "null": "", + `"hello"`: "hello", + `"árbol"`: "árbol", + `"\u0020"`: " ", + `"\uD83D\uDE00"`: "😀", + `"a\b\r\tb"`: "ab", + `"\\\""`: `\"`, + // escape sequences (NOTE just the control char is stripped) + `"\u001b[3mhello\u001b[m"`: "[3mhello[m", + `"a\u0080z"`: "az", + "\"a\u0080z\"": "az", + "\"a\u007fz\"": "az", + "\"a\u009fz\"": "az", + // replacement char + `"a\uFFFDb"`: "ab", + // private unicode chars + `"a\uE000b"`: "ab", + `"a\uDB80\uDC00b"`: "ab", +} + +func (escapeSuite) TestStrings(c *check.C) { + var u safejson.String + for j, s := range table { + comm := check.Commentf(j) + c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + + c.Assert(u.UnmarshalJSON([]byte(j)), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + } +} + +func (escapeSuite) TestBadStrings(c *check.C) { + var u1 safejson.String + + cc0 := make([][]byte, 0x20) + for i := range cc0 { + cc0[i] = []byte{'"', byte(i), '"'} + } + badesc := make([][]byte, 0, 0x7f-0x21-9) + for c := byte('!'); c <= '~'; c++ { + switch c { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u': + continue + default: + badesc = append(badesc, []byte{'"', '\\', c, '"'}) + } + } + + table := map[string][][]byte{ + // these are from json itself (so we're not checking them): + "invalid character '.+' in string literal": cc0, + "invalid character '.+' in string escape code": badesc, + `invalid character '.+' in \\u .*`: {[]byte(`"\u02"`), []byte(`"\u02zz"`)}, + "invalid character '\"' after top-level value": {[]byte(`"""`)}, + "unexpected end of JSON input": {[]byte(`"\"`)}, + } + + for e, js := range table { + for _, j := range js { + comm := check.Commentf("%q", j) + c.Check(json.Unmarshal(j, &u1), check.ErrorMatches, e, comm) + } + } + + table = map[string][][]byte{ + // these are from our lib + `missing string delimiters.*`: {{}, {'"'}}, + `unexpected control character at 0 in "\\.+"`: cc0, + `unknown escape '.' at 1 of "\\."`: badesc, + `badly formed \\u escape.*`: { + []byte(`"\u02"`), []byte(`"\u02zz"`), []byte(`"a\u02xxz"`), + []byte(`"\uD83Da"`), []byte(`"\uD83Da\u20"`), []byte(`"\uD83Da\u20zzz"`), + }, + `unexpected unescaped quote at 0 in """`: {[]byte(`"""`)}, + `unexpected end of string \(trailing backslash\).*`: {[]byte(`"\"`)}, + } + + for e, js := range table { + for _, j := range js { + comm := check.Commentf("%q", j) + c.Check(u1.UnmarshalJSON(j), check.ErrorMatches, e, comm) + } + } +} + +func (escapeSuite) TestParagraph(c *check.C) { + var u safejson.Paragraph + for j1, v1 := range table { + for j2, v2 := range table { + if j2 == "null" && j1 != "null" { + continue + } + + var j, s string + + if j1 == "null" { + j = j2 + s = v2 + } else { + j = j1[:len(j1)-1] + "\\n" + j2[1:] + s = v1 + "\n" + v2 + } + + comm := check.Commentf(j) + c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm) + c.Check(u.Clean(), check.Equals, s, comm) + } + } + +} diff --git a/internals/osutil/io_test.go b/internals/osutil/io_test.go index a57d88a1..61c6e69c 100644 --- a/internals/osutil/io_test.go +++ b/internals/osutil/io_test.go @@ -245,9 +245,9 @@ func (ts *AtomicWriteTestSuite) TestAtomicFileCancel(c *C) { aw, err := osutil.NewAtomicFile(p, 0644, 0, osutil.NoChown, osutil.NoChown) c.Assert(err, IsNil) fn := aw.File.Name() - c.Check(osutil.CanStat(fn), Equals, true) + c.Check(osutil.FileExists(fn), Equals, true) c.Check(aw.Cancel(), IsNil) - c.Check(osutil.CanStat(fn), Equals, false) + c.Check(osutil.FileExists(fn), Equals, false) } // SafeIoAtomicWriteTestSuite runs all AtomicWrite with safe diff --git a/internals/osutil/squashfs/fstype.go b/internals/osutil/squashfs/fstype.go index 07a8b8d5..f03fd213 100644 --- a/internals/osutil/squashfs/fstype.go +++ b/internals/osutil/squashfs/fstype.go @@ -25,7 +25,7 @@ import ( var useFuse = useFuseImpl func useFuseImpl() bool { - if !osutil.CanStat("/dev/fuse") { + if !osutil.FileExists("/dev/fuse") { return false } diff --git a/internals/osutil/stat.go b/internals/osutil/stat.go index c806dcad..8387e4c9 100644 --- a/internals/osutil/stat.go +++ b/internals/osutil/stat.go @@ -25,9 +25,9 @@ import ( "syscall" ) -// CanStat returns true if stat succeeds on the given path. +// FileExists returns true if stat succeeds on the given path. // It may return false on permission issues. -func CanStat(path string) bool { +func FileExists(path string) bool { _, err := os.Stat(path) return err == nil } diff --git a/internals/osutil/stat_test.go b/internals/osutil/stat_test.go index a1235ced..cdf11672 100644 --- a/internals/osutil/stat_test.go +++ b/internals/osutil/stat_test.go @@ -34,8 +34,8 @@ func (ts *StatTestSuite) TestCanStat(c *C) { err := ioutil.WriteFile(fname, []byte(fname), 0644) c.Assert(err, IsNil) - c.Assert(CanStat(fname), Equals, true) - c.Assert(CanStat("/i-do-not-exist"), Equals, false) + c.Assert(FileExists(fname), Equals, true) + c.Assert(FileExists("/i-do-not-exist"), Equals, false) } func (ts *StatTestSuite) TestCanStatOddPerms(c *C) { @@ -43,7 +43,7 @@ func (ts *StatTestSuite) TestCanStatOddPerms(c *C) { err := ioutil.WriteFile(fname, []byte(fname), 0100) c.Assert(err, IsNil) - c.Assert(CanStat(fname), Equals, true) + c.Assert(FileExists(fname), Equals, true) } func (ts *StatTestSuite) TestIsDir(c *C) { diff --git a/internals/overlord/export_test.go b/internals/overlord/export_test.go index 22529702..8fe8f187 100644 --- a/internals/overlord/export_test.go +++ b/internals/overlord/export_test.go @@ -40,6 +40,14 @@ func FakePruneInterval(prunei, prunew, abortw time.Duration) (restore func()) { } } +func FakePruneTicker(f func(t *time.Ticker) <-chan time.Time) (restore func()) { + old := pruneTickerC + pruneTickerC = f + return func() { + pruneTickerC = old + } +} + // FakeEnsureNext sets o.ensureNext for tests. func FakeEnsureNext(o *Overlord, t time.Time) { o.ensureNext = t diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go index 771a670d..bba805be 100644 --- a/internals/overlord/overlord.go +++ b/internals/overlord/overlord.go @@ -16,6 +16,7 @@ package overlord import ( + "errors" "fmt" "io" "os" @@ -49,6 +50,10 @@ var ( defaultCachedDownloads = 5 ) +var pruneTickerC = func(t *time.Ticker) <-chan time.Time { + return t.C +} + // Overlord is the central manager of the system, keeping track // of all available state managers and related helpers. type Overlord struct { @@ -63,8 +68,11 @@ type Overlord struct { ensureRun int32 pruneTicker *time.Ticker + startOfOperationTime time.Time + // managers inited bool + startedUp bool runner *state.TaskRunner serviceMgr *servstate.ServiceManager commandMgr *cmdstate.CommandManager @@ -136,6 +144,47 @@ func New(pebbleDir string, restartHandler restart.Handler, serviceOutput io.Writ return o, nil } +var timeNow = time.Now + +// StartOfOperationTime returns the time when pebble started operating, +// and sets it in the state when called for the first time. +// The StartOfOperationTime time is seed-time if available, +// or current time otherwise. +func (m *Overlord) StartOfOperationTime() (time.Time, error) { + var opTime time.Time + err := m.State().Get("start-of-operation-time", &opTime) + if err == nil { + return opTime, nil + } + if err != nil && !errors.Is(err, state.ErrNoState) { + return opTime, err + } + + opTime = timeNow() + m.State().Set("start-of-operation-time", opTime) + return opTime, nil +} + +// StartUp proceeds to run any expensive Overlord or managers initialization. +// After this is done once it is a noop. +func (o *Overlord) StartUp() error { + if o.startedUp { + return nil + } + o.startedUp = true + + var err error + st := o.State() + st.Lock() + o.startOfOperationTime, err = o.StartOfOperationTime() + st.Unlock() + if err != nil { + return fmt.Errorf("cannot get start of operation time: %s", err) + } + + return o.stateEng.StartUp() +} + func (o *Overlord) addManager(mgr StateManager) { o.stateEng.AddManager(mgr) } @@ -159,7 +208,7 @@ func loadState(statePath string, restartHandler restart.Handler, backend state.B } } - if !osutil.CanStat(statePath) { + if !osutil.FileExists(statePath) { // fail fast, mostly interesting for tests, this dir is set up by pebble stateDir := filepath.Dir(statePath) if !osutil.IsDir(stateDir) { @@ -255,6 +304,9 @@ func (o *Overlord) ensureBefore(d time.Duration) { // Loop runs a loop in a goroutine to ensure the current state regularly through StateEngine Ensure. func (o *Overlord) Loop() { o.ensureTimerSetup() + if o.loopTomb == nil { + o.loopTomb = new(tomb.Tomb) + } o.loopTomb.Go(func() error { for { // TODO: pass a proper context into Ensure @@ -263,14 +315,15 @@ func (o *Overlord) Loop() { // continue to the next Ensure() try for now o.stateEng.Ensure() o.ensureDidRun() + pruneC := pruneTickerC(o.pruneTicker) select { case <-o.loopTomb.Dying(): return nil case <-o.ensureTimer.C: - case <-o.pruneTicker.C: + case <-pruneC: st := o.State() st.Lock() - st.Prune(pruneWait, abortWait, pruneMaxChanges) + st.Prune(o.startOfOperationTime, pruneWait, abortWait, pruneMaxChanges) st.Unlock() } } @@ -295,6 +348,10 @@ func (o *Overlord) Stop() error { } func (o *Overlord) settle(timeout time.Duration, beforeCleanups func()) error { + if err := o.StartUp(); err != nil { + return err + } + func() { o.ensureLock.Lock() defer o.ensureLock.Unlock() diff --git a/internals/overlord/overlord_test.go b/internals/overlord/overlord_test.go index c6b9cc7e..33696609 100644 --- a/internals/overlord/overlord_test.go +++ b/internals/overlord/overlord_test.go @@ -46,6 +46,26 @@ type overlordSuite struct { var _ = Suite(&overlordSuite{}) +type ticker struct { + tickerChannel chan time.Time +} + +func (w *ticker) tick(n int) { + for i := 0; i < n; i++ { + w.tickerChannel <- time.Now() + } +} + +func fakePruneTicker() (w *ticker, restore func()) { + w = &ticker{ + tickerChannel: make(chan time.Time), + } + restore = overlord.FakePruneTicker(func(t *time.Ticker) <-chan time.Time { + return w.tickerChannel + }) + return w, restore +} + func (ovs *overlordSuite) SetUpTest(c *C) { ovs.dir = c.MkDir() ovs.statePath = filepath.Join(ovs.dir, ".pebble.state") @@ -161,6 +181,12 @@ type witnessManager struct { expectedEnsure int ensureCalled chan struct{} ensureCallback func(s *state.State) error + startedUp int +} + +func (wm *witnessManager) StartUp() error { + wm.startedUp++ + return nil } func (wm *witnessManager) Ensure() error { @@ -178,6 +204,9 @@ func (ovs *overlordSuite) TestTrivialRunAndStop(c *C) { o, err := overlord.New(ovs.dir, nil, nil) c.Assert(err, IsNil) + err = o.StartUp() + c.Assert(err, IsNil) + o.Loop() err = o.Stop() @@ -216,6 +245,10 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) { } o.AddManager(witness) + err := o.StartUp() + + c.Assert(err, IsNil) + o.Loop() defer o.Stop() @@ -227,8 +260,9 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) { } c.Check(time.Since(t0) >= 10*time.Millisecond, Equals, true) - err := o.Stop() + err = o.Stop() c.Assert(err, IsNil) + c.Check(witness.startedUp, Equals, 1) } func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) { @@ -248,7 +282,7 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) { ensureCallback: ensure, } o.AddManager(witness) - + c.Assert(o.StartUp(), IsNil) o.Loop() defer o.Stop() @@ -277,6 +311,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBefore(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -306,6 +342,8 @@ func (ovs *overlordSuite) TestEnsureBeforeSleepy(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -335,6 +373,8 @@ func (ovs *overlordSuite) TestEnsureBeforeLater(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -364,6 +404,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeOutsideEnsure(c *C) } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() defer o.Stop() @@ -420,6 +462,8 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + o.Loop() select { @@ -441,7 +485,7 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) { } func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { - restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour) + restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 5*time.Millisecond, 1*time.Hour) defer restoreIntv() o := overlord.Fake() @@ -459,20 +503,26 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { c.Check(st.Changes(), HasLen, 2) st.Unlock() + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + // start the loop that runs the prune ticker o.Loop() - // ensure the first change is pruned - time.Sleep(1500 * time.Millisecond) - st.Lock() - c.Check(st.Changes(), HasLen, 1) - st.Unlock() + // this needs to be more than pruneWait=5ms mocked above + time.Sleep(10 * time.Millisecond) + w.tick(2) - // ensure the second is also purged after it is ready st.Lock() + c.Check(st.Changes(), HasLen, 1) chg2.SetStatus(state.DoneStatus) st.Unlock() - time.Sleep(1500 * time.Millisecond) + + // this needs to be more than pruneWait=5ms mocked above + time.Sleep(10 * time.Millisecond) + // tick twice for extra Ensure + w.tick(2) + st.Lock() c.Check(st.Changes(), HasLen, 0) st.Unlock() @@ -482,6 +532,134 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) { c.Assert(err, IsNil) } +func (ovs *overlordSuite) TestOverlordStartUpSetsStartOfOperation(c *C) { + restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour) + defer restoreIntv() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + st := o.State() + st.Lock() + defer st.Unlock() + + // validity check, not set + var opTime time.Time + c.Assert(st.Get("start-of-operation-time", &opTime), testutil.ErrorIs, state.ErrNoState) + st.Unlock() + + c.Assert(o.StartUp(), IsNil) + + st.Lock() + c.Assert(st.Get("start-of-operation-time", &opTime), IsNil) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation(c *C) { + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is 50min ago, this is less then abort limit + opTime := time.Now().Add(-50 * time.Minute) + st.Set("start-of-operation-time", opTime) + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.MockTime(spawnTime) + + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoingStatus) +} + +func (ovs *overlordSuite) TestEnsureLoopPruneAbortsOld(c *C) { + // Ensure interval is not relevant for this test + restoreEnsureIntv := overlord.FakeEnsureInterval(10 * time.Hour) + defer restoreEnsureIntv() + + w, restoreTicker := fakePruneTicker() + defer restoreTicker() + + o, err := overlord.New(ovs.dir, nil, nil) + c.Assert(err, IsNil) + + // avoid immediate transition to Done due to having unknown kind + o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error { + return &state.Retry{} + }, nil) + + st := o.State() + st.Lock() + + // start of operation time is a year ago + opTime := time.Now().AddDate(-1, 0, 0) + st.Set("start-of-operation-time", opTime) + + st.Unlock() + c.Assert(o.StartUp(), IsNil) + st.Lock() + + // spawn time one month ago + spawnTime := time.Now().AddDate(0, -1, 0) + restoreTimeNow := state.MockTime(spawnTime) + t := st.NewTask("bar", "...") + chg := st.NewChange("other-change", "...") + chg.AddTask(t) + + restoreTimeNow() + + // validity + c.Check(st.Changes(), HasLen, 1) + st.Unlock() + + // start the loop that runs the prune ticker + o.Loop() + w.tick(2) + + c.Assert(o.Stop(), IsNil) + + st.Lock() + defer st.Unlock() + + // validity + op, err := o.StartOfOperationTime() + c.Assert(err, IsNil) + c.Check(op.Equal(opTime), Equals, true) + + c.Assert(st.Changes(), HasLen, 1) + // change was aborted + c.Check(chg.Status(), Equals, state.HoldStatus) +} + func (ovs *overlordSuite) TestCheckpoint(c *C) { oldUmask := syscall.Umask(0) defer syscall.Umask(oldUmask) @@ -856,6 +1034,8 @@ func (ovs *overlordSuite) TestOverlordCanStandby(c *C) { } o.AddManager(witness) + c.Assert(o.StartUp(), IsNil) + // can only standby after loop ran once c.Assert(o.CanStandby(), Equals, false) diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go index afc93186..58514a8a 100644 --- a/internals/overlord/patch/patch.go +++ b/internals/overlord/patch/patch.go @@ -20,6 +20,7 @@ package patch import ( + "errors" "fmt" "github.com/canonical/pebble/internals/logger" @@ -43,12 +44,12 @@ var patches = make(map[int][]PatchFunc) func Init(s *state.State) { s.Lock() defer s.Unlock() - if s.Get("patch-level", new(int)) != state.ErrNoState { + if err := s.Get("patch-level", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-level without actual patching") } s.Set("patch-level", Level) - if s.Get("patch-sublevel", new(int)) != state.ErrNoState { + if err := s.Get("patch-sublevel", new(int)); !errors.Is(err, state.ErrNoState) { panic("internal error: expected empty state, attempting to override patch-sublevel without actual patching") } s.Set("patch-sublevel", Sublevel) @@ -76,12 +77,12 @@ func Apply(s *state.State) error { var stateLevel, stateSublevel int s.Lock() err := s.Get("patch-level", &stateLevel) - if err == nil || err == state.ErrNoState { + if err == nil || errors.Is(err, state.ErrNoState) { err = s.Get("patch-sublevel", &stateSublevel) } s.Unlock() - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } diff --git a/internals/overlord/restart/restart.go b/internals/overlord/restart/restart.go index 0c1a8840..6c11f62f 100644 --- a/internals/overlord/restart/restart.go +++ b/internals/overlord/restart/restart.go @@ -16,6 +16,8 @@ package restart import ( + "errors" + "github.com/canonical/pebble/internals/overlord/state" ) @@ -61,7 +63,7 @@ func Init(st *state.State, curBootID string, h Handler) error { } var fromBootID string err := st.Get("system-restart-from-boot-id", &fromBootID) - if err != nil && err != state.ErrNoState { + if err != nil && !errors.Is(err, state.ErrNoState) { return err } st.Cache(restartStateKey{}, rs) diff --git a/internals/overlord/restart/restart_test.go b/internals/overlord/restart/restart_test.go index 8a8617ed..74fc0764 100644 --- a/internals/overlord/restart/restart_test.go +++ b/internals/overlord/restart/restart_test.go @@ -21,6 +21,7 @@ import ( "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestRestart(t *testing.T) { TestingT(t) } @@ -134,5 +135,5 @@ func (s *restartSuite) TestRequestRestartSystemAndVerifyReboot(c *C) { err = restart.Init(st, "boot-id-2", h2) c.Assert(err, IsNil) c.Check(h2.rebootAsExpected, Equals, true) - c.Check(st.Get("system-restart-from-boot-id", &fromBootID), Equals, state.ErrNoState) + c.Check(st.Get("system-restart-from-boot-id", &fromBootID), testutil.ErrorIs, state.ErrNoState) } diff --git a/internals/overlord/state/change.go b/internals/overlord/state/change.go index bfb3b42d..22bfa123 100644 --- a/internals/overlord/state/change.go +++ b/internals/overlord/state/change.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2023 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -23,8 +23,11 @@ import ( "bytes" "encoding/json" "fmt" + "sort" "strings" "time" + + "github.com/canonical/pebble/internals/logger" ) // Status is used for status values for changes and tasks. @@ -37,7 +40,8 @@ const ( // to an aggregation of its tasks' statuses. See Change.Status for details. DefaultStatus Status = 0 - // HoldStatus means the task should not run, perhaps as a consequence of an error on another task. + // HoldStatus means the task should not run for the moment, perhaps as a + // consequence of an error on another task. HoldStatus Status = 1 // DoStatus means the change or task is ready to start. @@ -65,6 +69,12 @@ const ( // ErrorStatus means the change or task has errored out while running or being undone. ErrorStatus Status = 9 + // WaitStatus means the task was accomplished successfully but some + // external event needs to happen before work can progress further + // (e.g. on classic we require the user to reboot after a + // kernel snap update). + WaitStatus Status = 10 + nStatuses = iota ) @@ -88,6 +98,8 @@ func (s Status) String() string { return "Doing" case DoneStatus: return "Done" + case WaitStatus: + return "Wait" case AbortStatus: return "Abort" case UndoStatus: @@ -104,6 +116,18 @@ func (s Status) String() string { panic(fmt.Sprintf("internal error: unknown task status code: %d", s)) } +// taskWaitComputeStatus is used while computing the wait status of a +// change. It keeps track of whether a task is waiting or not waiting, or the +// computation for it is still in-progress to detect cyclic dependencies. +type taskWaitComputeStatus int + +const ( + taskWaitStatusNotComputed taskWaitComputeStatus = iota + taskWaitStatusComputing + taskWaitStatusNotWaiting + taskWaitStatusWaiting +) + // Change represents a tracked modification to the system state. // // The Change provides both the justification for individual tasks @@ -115,16 +139,16 @@ func (s Status) String() string { // while the individual Task values would track the running of // the hooks themselves. type Change struct { - state *State - id string - kind string - summary string - status Status - clean bool - data customData - taskIDs []string - lanes int - ready chan struct{} + state *State + id string + kind string + summary string + status Status + clean bool + data customData + taskIDs []string + ready chan struct{} + lastObservedStatus Status spawnTime time.Time readyTime time.Time @@ -157,7 +181,6 @@ type marshalledChange struct { Clean bool `json:"clean,omitempty"` Data map[string]*json.RawMessage `json:"data,omitempty"` TaskIDs []string `json:"task-ids,omitempty"` - Lanes int `json:"lanes,omitempty"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -178,7 +201,6 @@ func (c *Change) MarshalJSON() ([]byte, error) { Clean: c.clean, Data: c.data, TaskIDs: c.taskIDs, - Lanes: c.lanes, SpawnTime: c.spawnTime, ReadyTime: readyTime, @@ -206,7 +228,6 @@ func (c *Change) UnmarshalJSON(data []byte) error { } c.data = custData c.taskIDs = unmarshalled.TaskIDs - c.lanes = unmarshalled.Lanes c.ready = make(chan struct{}) c.spawnTime = unmarshalled.SpawnTime if unmarshalled.ReadyTime != nil { @@ -251,12 +272,19 @@ func (c *Change) Get(key string, value interface{}) error { return c.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (c *Change) Has(key string) bool { + c.state.reading() + return c.data.has(key) +} + var statusOrder = []Status{ AbortStatus, UndoingStatus, UndoStatus, DoingStatus, DoStatus, + WaitStatus, ErrorStatus, UndoneStatus, DoneStatus, @@ -269,32 +297,138 @@ func init() { } } +func (c *Change) isTaskWaiting(visited map[string]taskWaitComputeStatus, t *Task, deps []*Task) bool { + taskID := t.ID() + // Retrieve the compute status of the wait for the task, if not + // computed this defaults to 0 (taskWaitStatusNotComputed). + computeStatus := visited[taskID] + switch computeStatus { + case taskWaitStatusComputing: + // Cyclic dependency detected, return false to short-circuit. + logger.Noticef("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + // Make sure errors show up in "snap change " too + t.Logf("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind()) + return false + case taskWaitStatusWaiting, taskWaitStatusNotWaiting: + return computeStatus == taskWaitStatusWaiting + } + visited[taskID] = taskWaitStatusComputing + + var isWaiting bool +depscheck: + for _, wt := range deps { + switch wt.Status() { + case WaitStatus: + isWaiting = true + // States that can be valid when waiting + // - Done, Undone, ErrorStatus, HoldStatus + case DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + // For 'Do' and 'Undo' we have to check whether the task is waiting + // for any dependencies. The logic is the same, but the set of tasks + // varies. + case DoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.WaitTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + case UndoStatus: + isWaiting = c.isTaskWaiting(visited, wt, wt.HaltTasks()) + if !isWaiting { + // Cancel early if we detect something is runnable. + break depscheck + } + default: + // When we determine the change can not be in a wait-state then + // break early. + isWaiting = false + break depscheck + } + } + if isWaiting { + visited[taskID] = taskWaitStatusWaiting + } else { + visited[taskID] = taskWaitStatusNotWaiting + } + return isWaiting +} + +// isChangeWaiting should only ever return true iff it determines all tasks in Do/Undo +// are blocked by tasks in either of three states: 'DoneStatus', 'UndoneStatus' or 'WaitStatus', +// if this fails, we default to the normal status ordering logic. +func (c *Change) isChangeWaiting() bool { + // Since we might visit tasks more than once, we store results to avoid recomputing them. + visited := make(map[string]taskWaitComputeStatus) + for _, t := range c.Tasks() { + switch t.Status() { + case WaitStatus, DoneStatus, UndoneStatus, ErrorStatus, HoldStatus: + continue + case DoStatus: + if !c.isTaskWaiting(visited, t, t.WaitTasks()) { + return false + } + case UndoStatus: + if !c.isTaskWaiting(visited, t, t.HaltTasks()) { + return false + } + default: + return false + } + } + // If we end up here, then return true as we know we + // have at least one waiter in this change. + return true +} + // Status returns the current status of the change. // If the status was not explicitly set the result is derived from the status // of the individual tasks related to the change, according to the following // decision sequence: // +// - With all pending tasks blocked by other tasks in WaitStatus, return WaitStatus // - With at least one task in DoStatus, return DoStatus // - With at least one task in ErrorStatus, return ErrorStatus // - Otherwise, return DoneStatus func (c *Change) Status() Status { c.state.reading() - if c.status == DefaultStatus { - if len(c.taskIDs) == 0 { - return HoldStatus - } - statusStats := make([]int, nStatuses) - for _, tid := range c.taskIDs { - statusStats[c.state.tasks[tid].Status()]++ + if c.status != DefaultStatus { + return c.status + } + + if len(c.taskIDs) == 0 { + return HoldStatus + } + + statusStats := make([]int, nStatuses) + for _, tid := range c.taskIDs { + statusStats[c.state.tasks[tid].Status()]++ + } + + // If the change has any waiters, check for any runnable tasks + // or whether it's completely blocked by waiters. + if statusStats[WaitStatus] > 0 { + // Only if the change has all tasks blocked we return WaitStatus. + if c.isChangeWaiting() { + return WaitStatus } - for _, s := range statusOrder { - if statusStats[s] > 0 { - return s - } + } + + // Otherwise we return the current status with the highest priority. + for _, s := range statusOrder { + if statusStats[s] > 0 { + return s } - panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) } - return c.status + panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats)) +} + +func (c *Change) notifyStatusChange(new Status) { + if c.lastObservedStatus == new { + return + } + c.state.notifyChangeStatusChangedHandlers(c, c.lastObservedStatus, new) + c.lastObservedStatus = new } // SetStatus sets the change status, overriding the default behavior (see Status method). @@ -304,6 +438,7 @@ func (c *Change) SetStatus(s Status) { if s.Ready() { c.markReady() } + c.notifyStatusChange(c.Status()) } func (c *Change) markReady() { @@ -322,15 +457,10 @@ func (c *Change) Ready() <-chan struct{} { return c.ready } -// taskStatusChanged is called by tasks when their status is changed, -// to give the opportunity for the change to close its ready channel. -func (c *Change) taskStatusChanged(t *Task, old, new Status) { - if old.Ready() == new.Ready() { - return - } +func (c *Change) detectChangeReady(excludeTask *Task) { for _, tid := range c.taskIDs { task := c.state.tasks[tid] - if task != t && !task.status.Ready() { + if task != excludeTask && !task.status.Ready() { return } } @@ -343,6 +473,21 @@ func (c *Change) taskStatusChanged(t *Task, old, new Status) { c.markReady() } +// taskStatusChanged is called by tasks when their status is changed, +// to give the opportunity for the change to close its ready channel, and +// notify observers of Change changes. +func (c *Change) taskStatusChanged(t *Task, old, new Status) { + cs := c.Status() + // If the task changes from ready => unready or unready => ready, + // update the ready status for the change. + if old.Ready() == new.Ready() { + c.notifyStatusChange(cs) + return + } + c.detectChangeReady(t) + c.notifyStatusChange(cs) +} + // IsClean returns whether all tasks in the change have been cleaned. See SetClean. func (c *Change) IsClean() bool { c.state.reading() @@ -519,6 +664,44 @@ func (c *Change) AbortLanes(lanes []int) { c.abortLanes(lanes, make(map[int]bool), make(map[string]bool)) } +// AbortUnreadyLanes aborts the tasks from lanes that aren't fully ready, where +// a ready lane is one in which all tasks are ready. +func (c *Change) AbortUnreadyLanes() { + c.state.writing() + c.abortUnreadyLanes() +} + +func (c *Change) abortUnreadyLanes() { + lanesWithLiveTasks := map[int]bool{} + + for _, tid := range c.taskIDs { + t := c.state.tasks[tid] + if !t.Status().Ready() { + for _, tlane := range t.Lanes() { + lanesWithLiveTasks[tlane] = true + } + } + } + + abortLanes := []int{} + for lane := range lanesWithLiveTasks { + abortLanes = append(abortLanes, lane) + } + c.abortLanes(abortLanes, make(map[int]bool), make(map[string]bool)) +} + +// taskEffectiveStatus returns the 'effective' status. This means it accounts +// for tasks being in WaitStatus, and instead of returning the WaitStatus we +// return the actual status. (The status after the wait). +func taskEffectiveStatus(t *Task) Status { + status := t.Status() + if status == WaitStatus { + // If the task is waiting, then use the effective status instead. + status = t.WaitedStatus() + } + return status +} + func (c *Change) abortLanes(lanes []int, abortedLanes map[int]bool, seenTasks map[string]bool) { var hasLive = make(map[int]bool) var hasDead = make(map[int]bool) @@ -528,7 +711,7 @@ NextChangeTask: t := c.state.tasks[tid] var live bool - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus, DoingStatus, DoneStatus: live = true } @@ -579,7 +762,7 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks continue } seenTasks[t.id] = true - switch t.Status() { + switch taskEffectiveStatus(t) { case DoStatus: // Still pending so don't even start. t.SetStatus(HoldStatus) @@ -607,3 +790,87 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks c.abortLanes(lanes, abortedLanes, seenTasks) } } + +type TaskDependencyCycleError struct { + IDs []string + msg string +} + +func (e *TaskDependencyCycleError) Error() string { return e.msg } + +func (e *TaskDependencyCycleError) Is(err error) bool { + _, ok := err.(*TaskDependencyCycleError) + return ok +} + +// CheckTaskDependencies checks the tasks in the change for cyclic dependencies +// and returns an error in such case. +func (c *Change) CheckTaskDependencies() error { + tasks := c.Tasks() + // count how many tasks any given non-independent task waits for + predecessors := make(map[string]int, len(tasks)) + + taskByID := map[string]*Task{} + for _, t := range tasks { + taskByID[t.id] = t + if l := len(t.waitTasks); l > 0 { + // only add an entry if the task is not independent + predecessors[t.id] = l + } + } + + // Kahn topological sort: make our way starting with tasks that are + // independent (their predecessors count is 0), then visit their direct + // successors (halt tasks), and for each reduce their predecessors + // count; once the count drops to 0, all direct dependencies of a given + // task have been accounted for and the task becomes independent. + + // queue of tasks to check + queue := make([]string, 0, len(tasks)) + // identify all independent tasks + for _, t := range tasks { + if predecessors[t.id] == 0 { + queue = append(queue, t.id) + } + } + + for len(queue) > 0 { + // take the first independent task + id := queue[0] + queue = queue[1:] + // reduce the incoming edge of its successors + for _, successor := range taskByID[id].haltTasks { + predecessors[successor]-- + if predecessors[successor] == 0 { + // a task that was a successor has become + // independent + delete(predecessors, successor) + queue = append(queue, successor) + } + } + } + + if len(predecessors) != 0 { + // tasks that are left cannot have their dependencies satisfied + var unsatisfiedTasks []string + for id := range predecessors { + unsatisfiedTasks = append(unsatisfiedTasks, id) + } + sort.Strings(unsatisfiedTasks) + msg := strings.Builder{} + msg.WriteString("dependency cycle involving tasks [") + for i, id := range unsatisfiedTasks { + t := taskByID[id] + msg.WriteString(fmt.Sprintf("%v:%v", t.id, t.kind)) + if i < len(unsatisfiedTasks)-1 { + msg.WriteRune(' ') + } + } + msg.WriteRune(']') + return &TaskDependencyCycleError{ + IDs: unsatisfiedTasks, + msg: msg.String(), + } + } + return nil +} diff --git a/internals/overlord/state/change_test.go b/internals/overlord/state/change_test.go index 52b77bfc..e25c4f55 100644 --- a/internals/overlord/state/change_test.go +++ b/internals/overlord/state/change_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -20,15 +20,16 @@ package state_test import ( + "errors" "fmt" "sort" "strconv" "strings" "time" - "github.com/canonical/pebble/internals/overlord/state" - . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/overlord/state" ) type changeSuite struct{} @@ -68,7 +69,7 @@ func (cs *changeSuite) TestReadyTime(c *C) { } func (cs *changeSuite) TestStatusString(c *C) { - for s := state.Status(0); s < state.ErrorStatus+1; s++ { + for s := state.Status(0); s < state.WaitStatus+1; s++ { c.Assert(s.String(), Matches, ".+") } } @@ -88,6 +89,21 @@ func (cs *changeSuite) TestGetSet(c *C) { c.Check(v, Equals, 1) } +func (cs *changeSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("install", "...") + c.Check(chg.Has("a"), Equals, false) + + chg.Set("a", 1) + c.Check(chg.Has("a"), Equals, true) + + chg.Set("a", nil) + c.Check(chg.Has("a"), Equals, false) +} + // TODO Better testing of full change roundtripping via JSON. func (cs *changeSuite) TestNewTaskAddTaskAndTasks(c *C) { @@ -227,9 +243,13 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { tasks := make(map[state.Status]*state.Task) - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } chg.AddTask(t) tasks[s] = t } @@ -240,6 +260,7 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { state.UndoStatus, state.DoingStatus, state.DoStatus, + state.WaitStatus, state.ErrorStatus, state.UndoneStatus, state.DoneStatus, @@ -252,7 +273,11 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) { if s == s2 { break } - tasks[s2].SetStatus(s) + if s == state.WaitStatus { + tasks[s2].SetToWait(state.DoneStatus) + } else { + tasks[s2].SetStatus(s) + } } c.Assert(chg.Status(), Equals, s) } @@ -432,9 +457,13 @@ func (cs *changeSuite) TestAbort(c *C) { chg := st.NewChange("install", "...") - for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ { + for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ { t := st.NewTask("download", s.String()) - t.SetStatus(s) + if s == state.WaitStatus { + t.SetToWait(state.DoneStatus) + } else { + t.SetStatus(s) + } t.Set("old-status", s) chg.AddTask(t) } @@ -451,7 +480,7 @@ func (cs *changeSuite) TestAbort(c *C) { switch s { case state.DoStatus: c.Assert(t.Status(), Equals, state.HoldStatus) - case state.DoneStatus: + case state.DoneStatus, state.WaitStatus: c.Assert(t.Status(), Equals, state.UndoStatus) case state.DoingStatus: c.Assert(t.Status(), Equals, state.AbortStatus) @@ -526,11 +555,13 @@ func (cs *changeSuite) TestAbortKⁿ(c *C) { // Task wait order: // -// => t21 => t22 -// / \ -// t11 => t12 => t41 => t42 -// \ / -// => t31 => t32 +// => t21 => t22 +// / \ +// +// t11 => t12 => t41 => t42 +// +// \ / +// => t31 => t32 // // setup and result lines are :[:,...] // @@ -577,6 +608,10 @@ var abortLanesTests = []struct { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{2}, result: "t21:hold t22:hold t41:hold t42:hold *:do", + }, { + setup: "t11:done:1 t12:wait:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + abort: []int{2}, + result: "t21:hold t22:hold t41:hold t42:hold t11:done t12:wait *:do", }, { setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", abort: []int{3}, @@ -660,7 +695,7 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Logf("Testing setup: %s", test.setup) statuses := make(map[string]state.Status) - for s := state.DefaultStatus; s <= state.ErrorStatus; s++ { + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { statuses[strings.ToLower(s.String())] = s } @@ -680,7 +715,11 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { } seen[parts[0]] = true task := tasks[parts[0]] - task.SetStatus(statuses[parts[1]]) + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } if len(parts) > 2 { lanes := strings.Split(parts[2], ",") for _, lane := range lanes { @@ -723,3 +762,691 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) { c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) } } + +// setup and result lines are :[:,...] +// order is -> (implies task2 waits for task 1) +// "*" as task name means "all remaining". +var abortUnreadyLanesTests = []struct { + setup string + order string + result string +}{ + + // Some basics. + { + setup: "*:do", + result: "*:hold", + }, { + setup: "*:wait", + result: "*:undo", + }, { + setup: "*:done", + result: "*:done", + }, { + setup: "*:error", + result: "*:error", + }, + + // t11 (1) => t12 (1) => t21 (1) => t22 (1) + // t31 (2) => t32 (2) => t41 (2) => t42 (2) + { + setup: "t11:do:1 t12:do:1 t21:do:1 t22:do:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:hold t32:hold t41:hold t42:hold", + }, { + setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:done:2 t32:done:2 t41:done:2 t42:do:2", + order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:undo t32:undo t41:undo t42:hold", + }, + // => t21 (2) => t22 (2) + // / \ + // t11 (2,3) => t12 (2,3) => t41 (4) => t42 (4) + // \ / + // => t31 (3) => t32 (3) + { + setup: "t11:do:2,3 t12:do:2,3 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "*:hold", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:abort t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:wait:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + // lane 2 is fully complete so it does not get aborted + result: "t11:done t12:done t21:done t22:done t31:undo t32:hold t41:hold t42:hold *:undo", + }, { + setup: "t11:done:2,3 t12:done:2,3 t21:doing:2 t22:do:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + result: "t21:abort t22:hold t31:abort t32:hold t41:hold t42:hold *:undo", + }, + + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "*:hold", + }, { + setup: "t11:do:1 t12:do:1 t21:doing:2 t22:do:2 t31:done:3 t32:doing:3 t41:undone:4 t42:error:4", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + result: "t11:hold t12:hold t21:abort t22:hold t31:undo t32:abort t41:undone t42:error", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:done t22:done t31:abort *:hold", + }, { + // + setup: "t11:done:1 t12:done:1 t21:done:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + result: "t11:done t12:done t21:undo *:hold", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11:done:1 t12:done:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + result: "t11:done t12:done *:hold", + }, +} + +func (ts *taskRunnerSuite) TestAbortUnreadyLanes(c *C) { + + names := strings.Fields("t11 t12 t21 t22 t31 t32 t41 t42") + + for i, test := range abortUnreadyLanesTests { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + statuses := make(map[string]state.Status) + for s := state.DefaultStatus; s <= state.WaitStatus; s++ { + statuses[strings.ToLower(s.String())] = s + } + + items := strings.Fields(test.setup) + seen := make(map[string]bool) + for i := 0; i < len(items); i++ { + item := items[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(items)-1, Commentf("*: can only be used as the last entry")) + for _, name := range names { + if !seen[name] { + parts[0] = name + items = append(items, strings.Join(parts, ":")) + } + } + continue + } + seen[parts[0]] = true + task := tasks[parts[0]] + if statuses[parts[1]] == state.WaitStatus { + task.SetToWait(state.DoneStatus) + } else { + task.SetStatus(statuses[parts[1]]) + } + if len(parts) > 2 { + lanes := strings.Split(parts[2], ",") + for _, lane := range lanes { + n, err := strconv.Atoi(lane) + c.Assert(err, IsNil) + task.JoinLane(n) + } + } + } + + c.Logf("Aborting") + + chg.AbortUnreadyLanes() + + c.Logf("Expected result: %s", test.result) + + seen = make(map[string]bool) + var expected = strings.Fields(test.result) + var obtained []string + for i := 0; i < len(expected); i++ { + item := expected[i] + parts := strings.Split(item, ":") + if parts[0] == "*" { + c.Assert(i, Equals, len(expected)-1, Commentf("*: can only be used as the last entry")) + var expanded []string + for _, name := range names { + if !seen[name] { + parts[0] = name + expanded = append(expanded, strings.Join(parts, ":")) + } + } + expected = append(expected[:i], append(expanded, expected[i+1:]...)...) + i-- + continue + } + name := parts[0] + seen[parts[0]] = true + obtained = append(obtained, name+":"+strings.ToLower(tasks[name].Status().String())) + } + + c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup)) + } +} + +// setup is a list of tasks " ", order is -> +// (implies task2 waits for task 1) +var cyclicDependencyTests = []struct { + setup string + order string + err string + errIDs []string +}{ + + // Some basics. + { + setup: "t1", + }, { + setup: "", + }, { + // independent tasks + setup: "t1 t2 t3", + }, { + // some independent and some ordered tasks + setup: "t1 t2 t3 t4", + order: "t2->t3", + }, + // some independent, dependencies as if added by WaitAll() + // t1 => t2 + // t1,t2 => t3 + // t1,t2,t3 => t4 + { + setup: "t1 t2 t3 t4", + order: "t1->t2 t1->t3 t2->t3 t1->t4 t2->t4 t3->t4", + }, { + // simple loop + setup: "t1 t2", + order: "t1->t2 t2->t1", + err: `dependency cycle involving tasks \[1:t1 2:t2\]`, + errIDs: []string{"1", "2"}, + }, + + // t1 => t2 => t3 => t4 + // t5 => t6 => t7 => t8 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t5->t6 t6->t7 t7->t8", + }, + // => t21 => t22 + // / \ + // t11 => t12 => t41 => t42 + // \ / + // => t31 => t32 + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42", + }, + // t11 (1) => t12 (1) + // t21 (2) => t22 (2) + // t31 (3) => t32 (3) + // t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t21->t22 t31->t32 t41->t42", + }, + // auto refresh like arrangement + // + // (apps) + // => t31 (3) => t32 (3) + // (snapd) (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42", + }, + // arrangement with a cyclic dependency between tasks + // + // /-----------------------------------------\ + // | | + // | => t31 (3) => t32 (3) / + // (snapd) v (base) / + // t11 (1) => t12 (1) => t21 (2) => t22 (2) + // \ + // => t41 (4) => t42 (4) + { + setup: "t11 t12 t21 t22 t31 t32 t41 t42", + order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21", + err: `dependency cycle involving tasks \[3:t21 4:t22 5:t31 6:t32 7:t41 8:t42\]`, + errIDs: []string{"3", "4", "5", "6", "7", "8"}, + }, + // t1 => t2 => t3 => t4 --> t6 + // t5 => t6 => t7 => t8 --> t2 + { + setup: "t1 t2 t3 t4 t5 t6 t7 t8", + order: "t1->t2 t2->t3 t3->t4 t4->t6 t5->t6 t6->t7 t7->t8 t8->t2", + err: `dependency cycle involving tasks \[2:t2 3:t3 4:t4 6:t6 7:t7 8:t8\]`, + errIDs: []string{"2", "3", "4", "6", "7", "8"}, + }, +} + +func (ts *taskRunnerSuite) TestCheckTaskDependencies(c *C) { + + for i, test := range cyclicDependencyTests { + names := strings.Fields(test.setup) + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask(name, name) + chg.AddTask(tasks[name]) + } + + c.Logf("----- %v", i) + c.Logf("Testing setup: %s", test.setup) + + for _, wp := range strings.Fields(test.order) { + pair := strings.Split(wp, "->") + c.Assert(pair, HasLen, 2) + // task 2 waits for task 1 is denoted as: + // task1->task2 + tasks[pair[1]].WaitFor(tasks[pair[0]]) + } + + err := chg.CheckTaskDependencies() + + if test.err != "" { + c.Assert(err, ErrorMatches, test.err) + c.Assert(errors.Is(err, &state.TaskDependencyCycleError{}), Equals, true) + errTasksDepCycle := err.(*state.TaskDependencyCycleError) + c.Assert(errTasksDepCycle.IDs, DeepEquals, test.errIDs) + } else { + c.Assert(err, IsNil) + } + } +} + +func (cs *changeSuite) TestIsWaitingStatusOrderWithWaits(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t1.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Set the wait-task into WaitStatus, to ensure we trigger the isWaiting + // logic and that it doesn't return WaitStatus for statuses which are in + // higher order + t4.SetToWait(state.DoneStatus) + + // Test the following sequences: + // task1 (do) => task2 (done) => task3 (doing) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (done) => task2 (done) => task3 (undoing) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoingStatus) + c.Check(chg.Status(), Equals, state.UndoingStatus) + + // task1 (done) => task2 (done) => task3 (abort) + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.AbortStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingSingle(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + + chg.AddTask(t1) + c.Check(chg.Status(), Equals, state.DoStatus) + + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) no reboot + t1.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (wait) => task2 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingCircularDependency(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + + // Setup circular dependency between t1,t2 and t3, they should + // still act normally. + t2.WaitFor(t1) + t3.WaitFor(t2) + t1.WaitFor(t3) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // To trigger the cyclic dependency check, we must trigger the isWaiting logic + // and we do this by putting t4 into WaitStatus. + t4.SetToWait(state.DoneStatus) + + // task1 (do) => task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) => task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoingStatus) + c.Check(chg.Status(), Equals, state.DoingStatus) + + // task1 (wait) => task2 (do) => task3 (do) means need a reboot + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) => task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + + // Put t4 into wait-status to trigger the isWaiting logic each time + // for the change. + t4.SetToWait(state.DoneStatus) + + // task1 (do) + task2 (do) => task3 (do) no reboot + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (done) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // task1 (done) + task2 (do) => task3 (do) no reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.DoStatus) + + // For the next two cases we are testing that a task with dependencies + // which have completed, but in a non-successful way is handled correctly. + // task1 (error) + task2 (wait) => task3 (do) means need reboot + // to finalize task2 + t1.SetStatus(state.ErrorStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (error) => task3 (do) means need reboot + // to finalize task1 + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (wait) => task3 (do) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) => task3 (do) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (done) + task2 (done) => task3 (wait) means need a reboot + t1.SetStatus(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (abort) => task3 (do) + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.AbortStatus) + t3.SetStatus(state.DoStatus) + c.Check(chg.Status(), Equals, state.AbortStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoTwoTasks(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("wait-task", "...") + t2.WaitFor(t1) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + + // Put t3 into wait-status to trigger the isWaiting logic each time + // for the change. + t3.SetToWait(state.DoneStatus) + + // we use <=| to denote the reverse dependence relationship + // followed by undo logic + + // task1 (undo) <=| task2 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) <=| task2 (wait) means need a reboot + t1.SetStatus(state.UndoStatus) + t2.SetToWait(state.DoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) <=| task2 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} + +func (cs *changeSuite) TestIsWaitingUndoMultipleDependencies(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + + t1 := st.NewTask("task1", "...") + t2 := st.NewTask("task2", "...") + t3 := st.NewTask("task3", "...") + t4 := st.NewTask("task4", "...") + t5 := st.NewTask("wait-task", "...") + t3.WaitFor(t1) + t3.WaitFor(t2) + t4.WaitFor(t1) + t4.WaitFor(t2) + + chg.AddTask(t1) + chg.AddTask(t2) + chg.AddTask(t3) + chg.AddTask(t4) + chg.AddTask(t5) + + // Put t5 into wait-status to trigger the isWaiting logic each time + // for the change. + t5.SetToWait(state.DoneStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undo) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (undone) no reboot + t1.SetStatus(state.UndoStatus) + t2.SetStatus(state.UndoStatus) + t3.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (undo) + task2 (undo) <=| task3 (wait) + task4 (error) means + // need reboot to continue undoing 1 and 2 + t3.SetStatus(state.ErrorStatus) + t4.SetToWait(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (undo) + task2 (undo) => task3 (error) + task4 (wait) means + // need reboot to continue undoing 1 and 2 + t3.SetToWait(state.UndoneStatus) + t4.SetStatus(state.ErrorStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undo) no reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoStatus) + c.Check(chg.Status(), Equals, state.UndoStatus) + + // task1 (wait) + task2 (done) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetStatus(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) + + // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undone) means need a reboot + t1.SetToWait(state.DoneStatus) + t2.SetToWait(state.DoneStatus) + t3.SetStatus(state.UndoneStatus) + t4.SetStatus(state.UndoneStatus) + c.Check(chg.Status(), Equals, state.WaitStatus) +} diff --git a/internals/overlord/state/copy.go b/internals/overlord/state/copy.go new file mode 100644 index 00000000..f87c2e85 --- /dev/null +++ b/internals/overlord/state/copy.go @@ -0,0 +1,141 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2020 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package state + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/canonical/pebble/internals/jsonutil" + "github.com/canonical/pebble/internals/osutil" +) + +type checkpointOnlyBackend struct { + path string +} + +func (b *checkpointOnlyBackend) Checkpoint(data []byte) error { + if err := os.MkdirAll(filepath.Dir(b.path), 0755); err != nil { + return err + } + return osutil.AtomicWriteFile(b.path, data, 0600, 0) +} + +func (b *checkpointOnlyBackend) EnsureBefore(d time.Duration) { + panic("cannot use EnsureBefore in checkpointOnlyBackend") +} + +// copyData will copy the given subkeys specifier from srcData to dstData. +// +// The subkeys is constructed from a dotted path like "user.auth". This copy +// helper is recursive and the pos parameter tells the function the current +// position of the copy. +func copyData(subkeys []string, pos int, srcData map[string]*json.RawMessage, dstData map[string]interface{}) error { + if pos < 0 || pos > len(subkeys) { + return fmt.Errorf("internal error: copyData used with an out-of-bounds position: %v not in [0:%v]", pos, len(subkeys)) + } + raw, ok := srcData[subkeys[pos]] + if !ok { + return ErrNoState + } + + if pos+1 == len(subkeys) { + dstData[subkeys[pos]] = raw + return nil + } + + var srcDatam map[string]*json.RawMessage + if err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &srcDatam); err != nil { + return fmt.Errorf("cannot unmarshal state entry %q with value %q as a map while trying to copy over %q", strings.Join(subkeys[:pos+1], "."), *raw, strings.Join(subkeys, ".")) + } + + // no subkey entry -> create one + if _, ok := dstData[subkeys[pos]]; !ok { + dstData[subkeys[pos]] = make(map[string]interface{}) + } + // and use existing data + var dstDatam map[string]interface{} + switch dstDataEntry := dstData[subkeys[pos]].(type) { + case map[string]interface{}: + dstDatam = dstDataEntry + case *json.RawMessage: + dstDatam = make(map[string]interface{}) + if err := jsonutil.DecodeWithNumber(bytes.NewReader(*dstDataEntry), &dstDatam); err != nil { + return fmt.Errorf("internal error: cannot decode subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstDataEntry) + } + default: + return fmt.Errorf("internal error: cannot create subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstData[subkeys[pos]]) + } + + return copyData(subkeys, pos+1, srcDatam, dstDatam) +} + +// CopyState takes a state from the srcStatePath and copies all +// dataEntries to the dstPath. Note that srcStatePath should never +// point to a state that is in use. +func CopyState(srcStatePath, dstStatePath string, dataEntries []string) error { + if osutil.FileExists(dstStatePath) { + // XXX: TOCTOU - look into moving this check into + // checkpointOnlyBackend. The issue is right now State + // will simply panic if Commit() returns an error + return fmt.Errorf("cannot copy state: %q already exists", dstStatePath) + } + if len(dataEntries) == 0 { + return fmt.Errorf("cannot copy state: must provide at least one data entry to copy") + } + + f, err := os.Open(srcStatePath) + if err != nil { + return fmt.Errorf("cannot open state: %s", err) + } + defer f.Close() + + // No need to lock/unlock the state here, srcState should not be + // in use at all. + srcState, err := ReadState(nil, f) + if err != nil { + return err + } + + // copy relevant data + dstData := make(map[string]interface{}) + for _, dataEntry := range dataEntries { + subkeys := strings.Split(dataEntry, ".") + if err := copyData(subkeys, 0, srcState.data, dstData); err != nil && !errors.Is(err, ErrNoState) { + return err + } + } + + // write it out + dstState := New(&checkpointOnlyBackend{path: dstStatePath}) + dstState.Lock() + defer dstState.Unlock() + for k, v := range dstData { + dstState.Set(k, v) + } + + return nil +} diff --git a/internals/overlord/state/copy_test.go b/internals/overlord/state/copy_test.go new file mode 100644 index 00000000..e7130e88 --- /dev/null +++ b/internals/overlord/state/copy_test.go @@ -0,0 +1,146 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016-2020 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package state_test + +import ( + "io/ioutil" + "path/filepath" + + . "gopkg.in/check.v1" + + "github.com/canonical/pebble/internals/overlord/state" +) + +func (ss *stateSuite) TestCopyStateAlreadyExists(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err := ioutil.WriteFile(dstStateFile, nil, 0644) + c.Assert(err, IsNil) + + err = state.CopyState(srcStateFile, dstStateFile, []string{"some-data"}) + c.Assert(err, ErrorMatches, `cannot copy state: "/.*/dst-state.json" already exists`) +} +func (ss *stateSuite) TestCopyStateNoDataEntriesToCopy(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + + err := state.CopyState(srcStateFile, dstStateFile, nil) + c.Assert(err, ErrorMatches, `cannot copy state: must provide at least one data entry to copy`) +} + +var srcStateContent = []byte(` +{ + "data": { + "api-download-tokens-secret": "123", + "api-download-tokens-secret-time": "2020-02-21T10:32:37.916147296Z", + "auth": { + "last-id": 1, + "users": [ + { + "id": 1, + "email": "some@user.com", + "macaroon": "1234", + "store-macaroon": "5678", + "store-discharges": [ + "9012345" + ] + } + ], + "device": { + "brand": "generic", + "model": "generic-classic", + "serial": "xxxxx-yyyyy-", + "key-id": "xxxxxx", + "session-macaroon": "xxxx" + }, + "macaroon-key": "xxxx=" + }, + "config": { + } + } +} +`) + +const stateSuffix = `,"changes":{},"tasks":{},"last-change-id":0,"last-task-id":0,"last-lane-id":0}` + +func (ss *stateSuite) TestCopyStateIntegration(c *C) { + // create a mock srcState + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent, 0644) + c.Assert(err, IsNil) + + // copy + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"auth.users", "no-existing-does-not-error", "auth.last-id"}) + c.Assert(err, IsNil) + + // and check that the right bits got copied + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"auth":{"last-id":1,"users":[{"id":1,"email":"some@user.com","macaroon":"1234","store-macaroon":"5678","store-discharges":["9012345"]}]}}`+stateSuffix) +} + +var srcStateContent1 = []byte(`{ + "data": { + "A": {"B": [{"C": 1}, {"D": 2}]}, + "E": {"F": 2, "G": 3}, + "H": 4, + "I": null + } +}`) + +func (ss *stateSuite) TestCopyState(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"A.B", "no-existing-does-not-error", "E.F", "E", "I", "E.non-existing"}) + c.Assert(err, IsNil) + + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"A":{"B":[{"C":1},{"D":2}]},"E":{"F":2,"G":3},"I":null}`+stateSuffix) +} + +func (ss *stateSuite) TestCopyStateUnmarshalNotMap(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"E.F.subkey-not-in-a-map"}) + c.Assert(err, ErrorMatches, `cannot unmarshal state entry "E.F" with value "2" as a map while trying to copy over "E.F.subkey-not-in-a-map"`) +} + +func (ss *stateSuite) TestCopyStateDuplicatesInDataEntriesAreFine(c *C) { + srcStateFile := filepath.Join(c.MkDir(), "src-state.json") + err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644) + c.Assert(err, IsNil) + + dstStateFile := filepath.Join(c.MkDir(), "dst-state.json") + err = state.CopyState(srcStateFile, dstStateFile, []string{"E", "E"}) + c.Assert(err, IsNil) + + dstContent, err := ioutil.ReadFile(dstStateFile) + c.Assert(err, IsNil) + c.Check(string(dstContent), Equals, `{"data":{"E":{"F":2,"G":3}}`+stateSuffix) +} diff --git a/internals/overlord/state/export_test.go b/internals/overlord/state/export_test.go index adbcc5d5..3bf6f8ef 100644 --- a/internals/overlord/state/export_test.go +++ b/internals/overlord/state/export_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -23,8 +23,8 @@ import ( "time" ) -// FakeCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime. -func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) { +// MockCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime. +func MockCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) { oldInterval := unlockCheckpointRetryInterval oldMaxTime := unlockCheckpointRetryMaxTime unlockCheckpointRetryInterval = retryInterval @@ -35,12 +35,12 @@ func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restor } } -func FakeChangeTimes(chg *Change, spawnTime, readyTime time.Time) { +func MockChangeTimes(chg *Change, spawnTime, readyTime time.Time) { chg.spawnTime = spawnTime chg.readyTime = readyTime } -func FakeTaskTimes(t *Task, spawnTime, readyTime time.Time) { +func MockTaskTimes(t *Task, spawnTime, readyTime time.Time) { t.spawnTime = spawnTime t.readyTime = readyTime } diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go index 660a8a64..86733024 100644 --- a/internals/overlord/state/state.go +++ b/internals/overlord/state/state.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -46,7 +46,7 @@ type customData map[string]*json.RawMessage func (data customData) get(key string, value interface{}) error { entryJSON := data[key] if entryJSON == nil { - return ErrNoState + return &NoStateError{Key: key} } err := json.Unmarshal(*entryJSON, value) if err != nil { @@ -87,6 +87,9 @@ type State struct { lastTaskId int lastChangeId int lastLaneId int + // lastHandlerId is not serialized, it's only used during runtime + // for registering runtime callbacks + lastHandlerId int backend Backend data customData @@ -97,18 +100,27 @@ type State struct { modified bool cache map[interface{}]interface{} + + pendingChangeByAttr map[string]func(*Change) bool + + // task/changes observing + taskHandlers map[int]func(t *Task, old, new Status) + changeHandlers map[int]func(chg *Change, old, new Status) } // New returns a new empty state. func New(backend Backend) *State { return &State{ - backend: backend, - data: make(customData), - changes: make(map[string]*Change), - tasks: make(map[string]*Task), - warnings: make(map[string]*Warning), - modified: true, - cache: make(map[interface{}]interface{}), + backend: backend, + data: make(customData), + changes: make(map[string]*Change), + tasks: make(map[string]*Task), + warnings: make(map[string]*Warning), + modified: true, + cache: make(map[interface{}]interface{}), + pendingChangeByAttr: make(map[string]func(*Change) bool), + taskHandlers: make(map[int]func(t *Task, old Status, new Status)), + changeHandlers: make(map[int]func(chg *Change, old Status, new Status)), } } @@ -241,6 +253,28 @@ func (s *State) EnsureBefore(d time.Duration) { // ErrNoState represents the case of no state entry for a given key. var ErrNoState = errors.New("no state entry for key") +// NoStateError represents the case where no state could be found for a given key. +type NoStateError struct { + // Key is the key for which no state could be found. + Key string +} + +func (e *NoStateError) Error() string { + var keyMsg string + if e.Key != "" { + keyMsg = fmt.Sprintf(" %q", e.Key) + } + + return fmt.Sprintf("no state entry for key%s", keyMsg) +} + +// Is returns true if the error is of type *NoStateError or equal to ErrNoState. +// NoStateError's key isn't compared between errors. +func (e *NoStateError) Is(err error) bool { + _, ok := err.(*NoStateError) + return ok || errors.Is(err, ErrNoState) +} + // Get unmarshals the stored value associated with the provided key // into the value parameter. // It returns ErrNoState if there is no entry for key. @@ -249,6 +283,12 @@ func (s *State) Get(key string, value interface{}) error { return s.data.get(key, value) } +// Has returns whether the provided key has an associated value. +func (s *State) Has(key string) bool { + s.reading() + return s.data.has(key) +} + // Set associates value with key for future consulting by managers. // The provided value must properly marshal and unmarshal with encoding/json. func (s *State) Set(key string, value interface{}) { @@ -357,15 +397,25 @@ func (s *State) tasksIn(tids []string) []*Task { return res } +// RegisterPendingChangeByAttr registers predicates that will be invoked by +// Prune on changes with the specified attribute set to check whether even if +// they meet the time criteria they must not be aborted yet. +func (s *State) RegisterPendingChangeByAttr(attr string, f func(*Change) bool) { + s.pendingChangeByAttr[attr] = f +} + // Prune does several cleanup tasks to the in-memory state: // // - it removes changes that became ready for more than pruneWait and aborts -// tasks spawned for more than abortWait. +// tasks spawned for more than abortWait unless prevented by predicates +// registered with RegisterPendingChangeByAttr. +// // - it removes tasks unlinked to changes after pruneWait. When there are more // changes than the limit set via "maxReadyChanges" those changes in ready // state will also removed even if they are below the pruneWait duration. +// // - it removes expired warnings. -func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { +func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) { now := time.Now() pruneLimit := now.Add(-pruneWait) abortLimit := now.Add(-abortWait) @@ -392,15 +442,24 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +NextChange: for _, chg := range changes { - spawnTime := chg.SpawnTime() readyTime := chg.ReadyTime() + spawnTime := chg.SpawnTime() + if spawnTime.Before(startOfOperation) { + spawnTime = startOfOperation + } if readyTime.IsZero() { if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 { chg.Abort() delete(s.changes, chg.ID()) } else if spawnTime.Before(abortLimit) { - chg.Abort() + for attr, pending := range s.pendingChangeByAttr { + if chg.Has(attr) && pending(chg) { + continue NextChange + } + } + chg.AbortUnreadyLanes() } continue } @@ -424,6 +483,75 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { } } +// GetMaybeTimings implements snapcore/snapd/timings.GetSaver +func (s *State) GetMaybeTimings(timings interface{}) error { + if err := s.Get("timings", timings); err != nil && !errors.Is(err, ErrNoState) { + return err + } + return nil +} + +// AddTaskStatusChangedHandler adds a callback function that will be invoked +// whenever tasks change status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddTaskStatusChangedHandler(f func(t *Task, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.taskHandlers[id] = f + return id +} + +func (s *State) RemoveTaskStatusChangedHandler(id int) { + s.reading() + delete(s.taskHandlers, id) +} + +func (s *State) notifyTaskStatusChangedHandlers(t *Task, old, new Status) { + s.reading() + for _, f := range s.taskHandlers { + f(t, old, new) + } +} + +// AddChangeStatusChangedHandler adds a callback function that will be invoked +// whenever a Change changes status. +// NOTE: Callbacks registered this way may be invoked in the context +// of the taskrunner, so the callbacks should be as simple as possible, and return +// as quickly as possible, and should avoid the use of i/o code or blocking, as this +// will stop the entire task system. +func (s *State) AddChangeStatusChangedHandler(f func(chg *Change, old, new Status)) (id int) { + // We are reading here as we want to ensure access to the state is serialized, + // and not writing as we are not changing the part of state that goes on the disk. + s.reading() + id = s.lastHandlerId + s.lastHandlerId++ + s.changeHandlers[id] = f + return id +} + +func (s *State) RemoveChangeStatusChangedHandler(id int) { + s.reading() + delete(s.changeHandlers, id) +} + +func (s *State) notifyChangeStatusChangedHandlers(chg *Change, old, new Status) { + s.reading() + for _, f := range s.changeHandlers { + f(chg, old, new) + } +} + +// SaveTimings implements snapcore/snapd/timings.GetSaver +func (s *State) SaveTimings(timings interface{}) { + s.Set("timings", timings) +} + // ReadState returns the state deserialized from r. func ReadState(backend Backend, r io.Reader) (*State, error) { s := new(State) @@ -437,5 +565,6 @@ func ReadState(backend Backend, r io.Reader) (*State, error) { s.backend = backend s.modified = false s.cache = make(map[interface{}]interface{}) + s.pendingChangeByAttr = make(map[string]func(*Change) bool) return s, err } diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go index e1e8a707..2ab46026 100644 --- a/internals/overlord/state/state_test.go +++ b/internals/overlord/state/state_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -29,6 +29,7 @@ import ( . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) func TestState(t *testing.T) { TestingT(t) } @@ -76,6 +77,37 @@ func (ss *stateSuite) TestGetAndSet(c *C) { c.Check(&mSt2B, DeepEquals, mSt2) } +func (ss *stateSuite) TestHas(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + c.Check(st.Has("a"), Equals, false) + + st.Set("a", 1) + c.Check(st.Has("a"), Equals, true) + + st.Set("a", nil) + c.Check(st.Has("a"), Equals, false) +} + +func (ss *stateSuite) TestStrayTaskWithNoChange(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("change", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + _ = st.NewTask("bar", "...") + + // only the task with associate change is returned + c.Assert(st.Tasks(), HasLen, 1) + c.Assert(st.Tasks()[0].ID(), Equals, t1.ID()) + // but count includes all tasks + c.Assert(st.TaskCount(), Equals, 2) +} + func (ss *stateSuite) TestSetPanic(c *C) { st := state.New(nil) st.Lock() @@ -94,7 +126,7 @@ func (ss *stateSuite) TestGetNoState(c *C) { var mSt1B mgrState1 err := st.Get("mgr9", &mSt1B) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) } func (ss *stateSuite) TestSetToNilDeletes(c *C) { @@ -112,7 +144,7 @@ func (ss *stateSuite) TestSetToNilDeletes(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -126,7 +158,7 @@ func (ss *stateSuite) TestNullMeansNoState(c *C) { var v1 map[string]int err = st.Get("a", &v1) - c.Check(err, Equals, state.ErrNoState) + c.Check(err, testutil.ErrorIs, state.ErrNoState) c.Check(v1, HasLen, 0) } @@ -227,7 +259,7 @@ func (ss *stateSuite) TestImplicitCheckpointAndRead(c *C) { } func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) defer restore() retries := 0 @@ -250,7 +282,7 @@ func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) { } func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond) defer restore() boom := errors.New("boom") @@ -272,7 +304,7 @@ func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) { } func (ss *stateSuite) TestImplicitCheckpointModifiedOnly(c *C) { - restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) + restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second) defer restore() b := &fakeStateBackend{} @@ -421,7 +453,7 @@ func (ss *stateSuite) TestNewTaskAndCheckpoint(c *C) { t1ID := t1.ID() t1.Set("a", 1) t1.SetStatus(state.DoneStatus) - t1.SetProgress("foo", 5, 10) + t1.SetProgress("snap", 5, 10) t1.JoinLane(42) t1.JoinLane(43) @@ -702,7 +734,7 @@ func (ss *stateSuite) TestMethodEntrance(c *C) { func() { st.Tasks() }, func() { st.Task("foo") }, func() { st.MarshalJSON() }, - func() { st.Prune(time.Hour, time.Hour, 100) }, + func() { st.Prune(time.Now(), time.Hour, time.Hour, 100) }, func() { st.TaskCount() }, func() { st.AllWarnings() }, func() { st.PendingWarnings() }, @@ -744,31 +776,32 @@ func (ss *stateSuite) TestPrune(c *C) { chg1 := st.NewChange("abort", "...") chg1.AddTask(t1) - state.FakeChangeTimes(chg1, now.Add(-abortWait), unset) + state.MockChangeTimes(chg1, now.Add(-abortWait), unset) chg2 := st.NewChange("prune", "...") chg2.AddTask(t2) c.Assert(chg2.Status(), Equals, state.DoStatus) - state.FakeChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait)) + state.MockChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait)) chg3 := st.NewChange("ready-but-recent", "...") chg3.AddTask(t3) - state.FakeChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2)) + state.MockChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2)) chg4 := st.NewChange("old-but-not-ready", "...") chg4.AddTask(t4) - state.FakeChangeTimes(chg4, now.Add(-pruneWait/2), unset) + state.MockChangeTimes(chg4, now.Add(-pruneWait/2), unset) // unlinked task t5 := st.NewTask("unliked", "...") c.Check(st.Task(t5.ID()), IsNil) - state.FakeTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait)) + state.MockTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait)) // two warnings, one expired st.AddWarning("hello", now, never, time.Nanosecond, state.DefaultRepeatAfter) st.Warnf("hello again") - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg1.ID()), Equals, chg1) c.Assert(st.Change(chg2.ID()), IsNil) @@ -793,6 +826,55 @@ func (ss *stateSuite) TestPrune(c *C) { c.Check(st.AllWarnings(), HasLen, 1) } +func (ss *stateSuite) TestRegisterPendingChangeByAttr(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + unset := time.Time{} + + t1 := st.NewTask("foo", "...") + t2 := st.NewTask("foo", "...") + t3 := st.NewTask("foo", "...") + t4 := st.NewTask("foo", "...") + + chg1 := st.NewChange("abort", "...") + chg1.AddTask(t1) + chg1.AddTask(t2) + state.MockChangeTimes(chg1, now.Add(-abortWait), unset) + + chg2 := st.NewChange("pending", "...") + chg2.AddTask(t3) + chg2.AddTask(t4) + state.MockChangeTimes(chg2, now.Add(-abortWait), unset) + chg2.Set("pending-flag", true) + t3.SetStatus(state.HoldStatus) + + st.RegisterPendingChangeByAttr("pending-flag", func(chg *state.Change) bool { + c.Check(chg.ID(), Equals, chg2.ID()) + return true + }) + + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) + + c.Assert(st.Change(chg1.ID()), Equals, chg1) + c.Assert(st.Change(chg2.ID()), Equals, chg2) + c.Assert(st.Task(t1.ID()), Equals, t1) + c.Assert(st.Task(t2.ID()), Equals, t2) + c.Assert(st.Task(t3.ID()), Equals, t3) + c.Assert(st.Task(t4.ID()), Equals, t4) + + c.Assert(t1.Status(), Equals, state.HoldStatus) + c.Assert(t2.Status(), Equals, state.HoldStatus) + c.Assert(t3.Status(), Equals, state.HoldStatus) + c.Assert(t4.Status(), Equals, state.DoStatus) +} + func (ss *stateSuite) TestPruneEmptyChange(c *C) { // Empty changes are a bit special because they start out on Hold // which is a Ready status, but the change itself is not considered Ready @@ -807,9 +889,10 @@ func (ss *stateSuite) TestPruneEmptyChange(c *C) { abortWait := 3 * time.Hour chg := st.NewChange("abort", "...") - state.FakeChangeTimes(chg, now.Add(-pruneWait), time.Time{}) + state.MockChangeTimes(chg, now.Add(-pruneWait), time.Time{}) - st.Prune(pruneWait, abortWait, 100) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, pruneWait, abortWait, 100) c.Assert(st.Change(chg.ID()), IsNil) } @@ -831,7 +914,7 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) { t.SetStatus(state.DoneStatus) when := time.Duration(i) * time.Second - state.FakeChangeTimes(chg, now.Add(-when), now.Add(-when)) + state.MockChangeTimes(chg, now.Add(-when), now.Add(-when)) } c.Assert(st.Changes(), HasLen, 10) @@ -844,13 +927,14 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) { // test that nothing is done when we are within pruneWait and // maxReadyChanges + past := time.Now().AddDate(-1, 0, 0) maxReadyChanges := 100 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 15) // but with maxReadyChanges we remove the ready ones maxReadyChanges = 5 - st.Prune(pruneWait, abortWait, maxReadyChanges) + st.Prune(past, pruneWait, abortWait, maxReadyChanges) c.Assert(st.Changes(), HasLen, 10) remaining := map[string]bool{} for _, chg := range st.Changes() { @@ -886,8 +970,9 @@ func (ss *stateSuite) TestPruneMaxChangesSomeNotReady(c *C) { c.Assert(st.Changes(), HasLen, 10) // nothing can be pruned + past := time.Now().AddDate(-1, 0, 0) maxChanges := 5 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 10) } @@ -905,10 +990,10 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { c.Assert(st.Changes(), HasLen, 10) // one extra change that just now entered ready state - chg := st.NewChange(fmt.Sprintf("chg99"), "so-ready") + chg := st.NewChange("chg99", "so-ready") t := st.NewTask("foo", "so-ready") when := 1 * time.Second - state.FakeChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when)) + state.MockChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when)) t.SetStatus(state.DoneStatus) chg.AddTask(t) @@ -916,11 +1001,45 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) { // // this test we do not purge the freshly ready change maxChanges := 10 - st.Prune(1*time.Hour, 3*time.Hour, maxChanges) + past := time.Now().AddDate(-1, 0, 0) + st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges) c.Assert(st.Changes(), HasLen, 11) } -func (ss *stateSuite) TestReadStateInitsCache(c *C) { +func (ss *stateSuite) TestPruneHonorsStartOperationTime(c *C) { + st := state.New(&fakeStateBackend{}) + st.Lock() + defer st.Unlock() + + now := time.Now() + + startTime := 2 * time.Hour + spawnTime := 10 * time.Hour + pruneWait := 1 * time.Hour + abortWait := 3 * time.Hour + + chg := st.NewChange("change", "...") + t := st.NewTask("foo", "") + chg.AddTask(t) + // change spawned 10h ago + state.MockChangeTimes(chg, now.Add(-spawnTime), time.Time{}) + + // start operation time is 2h ago, change is not aborted because + // it's less than abortWait limit. + opTime := now.Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.DoStatus) + + // start operation time is 9h ago, change is aborted. + startTime = 9 * time.Hour + opTime = time.Now().Add(-startTime) + st.Prune(opTime, pruneWait, abortWait, 100) + c.Assert(st.Changes(), HasLen, 1) + c.Check(chg.Status(), Equals, state.HoldStatus) +} + +func (ss *stateSuite) TestReadStateInitsTransientMapFields(c *C) { st, err := state.ReadState(nil, bytes.NewBufferString("{}")) c.Assert(err, IsNil) st.Lock() @@ -928,4 +1047,195 @@ func (ss *stateSuite) TestReadStateInitsCache(c *C) { st.Cache("key", "value") c.Assert(st.Cached("key"), Equals, "value") + st.RegisterPendingChangeByAttr("attr", func(*state.Change) bool { return false }) +} + +func (ss *stateSuite) TestTimingsSupport(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var tims []int + + err := st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, IsNil) + + st.SaveTimings([]int{1, 2, 3}) + + err = st.GetMaybeTimings(&tims) + c.Assert(err, IsNil) + c.Check(tims, DeepEquals, []int{1, 2, 3}) +} + +func (ss *stateSuite) TestNoStateErrorIs(c *C) { + err := &state.NoStateError{Key: "foo"} + c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) + c.Assert(err, testutil.ErrorIs, &state.NoStateError{Key: "bar"}) + c.Assert(err, testutil.ErrorIs, state.ErrNoState) +} + +func (ss *stateSuite) TestNoStateErrorString(c *C) { + err := &state.NoStateError{} + c.Assert(err.Error(), Equals, `no state entry for key`) + err.Key = "foo" + c.Assert(err.Error(), Equals, `no state entry for key "foo"`) +} + +type taskAndStatus struct { + t *state.Task + old, new state.Status +} + +func (ss *stateSuite) TestTaskChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var taskObservedChanges []taskAndStatus + oId := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) { + taskObservedChanges = append(taskObservedChanges, taskAndStatus{ + t: t, + old: old, + new: new, + }) + }) + + t1 := st.NewTask("foo", "...") + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // task events when task don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to done. + t1.SetStatus(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveTaskStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoingStatus) + + c.Check(taskObservedChanges, DeepEquals, []taskAndStatus{ + { + t: t1, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + t: t1, + old: state.DoingStatus, + new: state.DoneStatus, + }, + }) +} + +type changeAndStatus struct { + chg *state.Change + old, new state.Status +} + +func (ss *stateSuite) TestChangeChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // Set task status to identical status, we don't want + // change events when changes don't actually change status. + t1.SetStatus(state.DoingStatus) + + // Set task to waiting + t1.SetToWait(state.DoneStatus) + + // Unregister us, and make sure we do not receive more events. + st.RemoveChangeStatusChangedHandler(oId) + + // must not appear in list. + t1.SetStatus(state.DoneStatus) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + }) +} + +func (ss *stateSuite) TestChangeSetStatusChangedHandler(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + var observedChanges []changeAndStatus + oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) { + observedChanges = append(observedChanges, changeAndStatus{ + chg: chg, + old: old, + new: new, + }) + }) + + chg := st.NewChange("test-chg", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + + t1.SetStatus(state.DoingStatus) + + // We have a single task in Doing, now we manipulate the status + // of the change to ensure we are receiving correct events + chg.SetStatus(state.WaitStatus) + + // Change to a new status + chg.SetStatus(state.ErrorStatus) + + // Now return the status back to Default, which should result + // in the change reporting Doing + chg.SetStatus(state.DefaultStatus) + st.RemoveChangeStatusChangedHandler(oId) + + c.Check(observedChanges, DeepEquals, []changeAndStatus{ + { + chg: chg, + old: state.DefaultStatus, + new: state.DoingStatus, + }, + { + chg: chg, + old: state.DoingStatus, + new: state.WaitStatus, + }, + { + chg: chg, + old: state.WaitStatus, + new: state.ErrorStatus, + }, + { + chg: chg, + old: state.ErrorStatus, + new: state.DoingStatus, + }, + }) } diff --git a/internals/overlord/state/task.go b/internals/overlord/state/task.go index a3e2a486..fb5ec826 100644 --- a/internals/overlord/state/task.go +++ b/internals/overlord/state/task.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -38,19 +38,22 @@ type progress struct { // // See Change for more details. type Task struct { - state *State - id string - kind string - summary string - status Status - clean bool - progress *progress - data customData - waitTasks []string - haltTasks []string - lanes []int - log []string - change string + state *State + id string + kind string + summary string + status Status + // waitedStatus is the Status that should be used instead of + // WaitStatus once the wait is complete (i.e post reboot). + waitedStatus Status + clean bool + progress *progress + data customData + waitTasks []string + haltTasks []string + lanes []int + log []string + change string spawnTime time.Time readyTime time.Time @@ -77,18 +80,19 @@ func newTask(state *State, id, kind, summary string) *Task { } type marshalledTask struct { - ID string `json:"id"` - Kind string `json:"kind"` - Summary string `json:"summary"` - Status Status `json:"status"` - Clean bool `json:"clean,omitempty"` - Progress *progress `json:"progress,omitempty"` - Data map[string]*json.RawMessage `json:"data,omitempty"` - WaitTasks []string `json:"wait-tasks,omitempty"` - HaltTasks []string `json:"halt-tasks,omitempty"` - Lanes []int `json:"lanes,omitempty"` - Log []string `json:"log,omitempty"` - Change string `json:"change"` + ID string `json:"id"` + Kind string `json:"kind"` + Summary string `json:"summary"` + Status Status `json:"status"` + WaitedStatus Status `json:"waited-status"` + Clean bool `json:"clean,omitempty"` + Progress *progress `json:"progress,omitempty"` + Data map[string]*json.RawMessage `json:"data,omitempty"` + WaitTasks []string `json:"wait-tasks,omitempty"` + HaltTasks []string `json:"halt-tasks,omitempty"` + Lanes []int `json:"lanes,omitempty"` + Log []string `json:"log,omitempty"` + Change string `json:"change"` SpawnTime time.Time `json:"spawn-time"` ReadyTime *time.Time `json:"ready-time,omitempty"` @@ -111,18 +115,19 @@ func (t *Task) MarshalJSON() ([]byte, error) { atTime = &t.atTime } return json.Marshal(marshalledTask{ - ID: t.id, - Kind: t.kind, - Summary: t.summary, - Status: t.status, - Clean: t.clean, - Progress: t.progress, - Data: t.data, - WaitTasks: t.waitTasks, - HaltTasks: t.haltTasks, - Lanes: t.lanes, - Log: t.log, - Change: t.change, + ID: t.id, + Kind: t.kind, + Summary: t.summary, + Status: t.status, + WaitedStatus: t.waitedStatus, + Clean: t.clean, + Progress: t.progress, + Data: t.data, + WaitTasks: t.waitTasks, + HaltTasks: t.haltTasks, + Lanes: t.lanes, + Log: t.log, + Change: t.change, SpawnTime: t.spawnTime, ReadyTime: readyTime, @@ -148,6 +153,13 @@ func (t *Task) UnmarshalJSON(data []byte) error { t.kind = unmarshalled.Kind t.summary = unmarshalled.Summary t.status = unmarshalled.Status + t.waitedStatus = unmarshalled.WaitedStatus + if t.waitedStatus == DefaultStatus { + // For backwards-compatibility, default the waitStatus, which is + // the result status after a wait, to DoneStatus to keep any previous + // behaviour before any upgrade. + t.waitedStatus = DoneStatus + } t.clean = unmarshalled.Clean t.progress = unmarshalled.Progress custData := unmarshalled.Data @@ -188,6 +200,40 @@ func (t *Task) Summary() string { } // Status returns the current task status. +// +// Possible state transitions: +// +// /----aborting lane--Do +// | | +// V V +// Hold Doing-->Wait +// ^ / | \ +// | abort / V V +// no undo / Done Error +// | V | +// \----------Abort aborting lane +// / | | +// | finished or | +// running not running | +// V \------->| +// kill goroutine | +// | V +// / \ ----->Undo +// / no error / | +// | from goroutine | +// error | +// from goroutine | +// | V +// | Undoing-->Wait +// V | \ +// Error V V +// Undone Error +// +// Do -> Doing -> Done is the direct succcess scenario. +// +// Wait can transition to its waited status, +// usually Done|Undone or back to Doing. +// See Wait struct, SetToWait and WaitedStatus. func (t *Task) Status() Status { t.state.reading() if t.status == DefaultStatus { @@ -196,10 +242,10 @@ func (t *Task) Status() Status { return t.status } -// SetStatus sets the task status, overriding the default behavior (see Status method). -func (t *Task) SetStatus(new Status) { - t.state.writing() - old := t.status +func (t *Task) changeStatus(old, new Status) { + if old == new { + return + } t.status = new if !old.Ready() && new.Ready() { t.readyTime = timeNow() @@ -208,6 +254,55 @@ func (t *Task) SetStatus(new Status) { if chg != nil { chg.taskStatusChanged(t, old, new) } + t.state.notifyTaskStatusChangedHandlers(t, old, new) +} + +// SetStatus sets the task status, overriding the default behavior (see Status method). +func (t *Task) SetStatus(new Status) { + if new == WaitStatus { + panic("Task.SetStatus() called with WaitStatus, which is not allowed. Use SetToWait() instead") + } + + t.state.writing() + old := t.status + if new == DoneStatus && old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // DoneStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.changeStatus(old, new) +} + +// SetToWait puts the task into WaitStatus, and sets the status the task should be restored +// to after the SetToWait. +func (t *Task) SetToWait(resultStatus Status) { + switch resultStatus { + case DefaultStatus, WaitStatus: + panic("Task.SetToWait() cannot be invoked with either of DefaultStatus or WaitStatus") + } + + t.state.writing() + old := t.status + if old == AbortStatus { + // if the task is in AbortStatus (because some other task ran + // in parallel and had an error so the change is aborted) and + // WaitStatus was requested (which can happen if the + // task handler sets its status explicitly) then keep it at + // aborted so it can transition to Undo. + return + } + t.waitedStatus = resultStatus + t.changeStatus(old, WaitStatus) +} + +// WaitedStatus returns the status the Task should return to once the current WaitStatus +// has been resolved. +func (t *Task) WaitedStatus() Status { + t.state.reading() + return t.waitedStatus } // IsClean returns whether the task has been cleaned. See SetClean. @@ -320,7 +415,7 @@ const ( var timeNow = time.Now -func FakeTime(now time.Time) (restore func()) { +func MockTime(now time.Time) (restore func()) { timeNow = func() time.Time { return now } return func() { timeNow = time.Now } } @@ -332,7 +427,7 @@ func (t *Task) addLog(kind, format string, args []interface{}) { } tstr := timeNow().Format(time.RFC3339) - msg := fmt.Sprintf(tstr+" "+kind+" "+format, args...) + msg := tstr + " " + kind + " " + fmt.Sprintf(format, args...) t.log = append(t.log, msg) logger.Debugf(msg) } @@ -468,7 +563,7 @@ func (t *Task) At(when time.Time) { // TaskSetEdge designates tasks inside a TaskSet for outside reference. // // This is useful to give tasks inside TaskSets a special meaning. It -// is used to mark e.g. the last task in a task set. +// is used to mark e.g. the last task used for downloading a snap. type TaskSetEdge string // A TaskSet holds a set of tasks. @@ -484,10 +579,16 @@ func NewTaskSet(tasks ...*Task) *TaskSet { return &TaskSet{tasks, nil} } -// Edge returns the task marked with the given edge name. +// MaybeEdge returns the task marked with the given edge name or nil if no such +// task exists. +func (ts TaskSet) MaybeEdge(e TaskSetEdge) *Task { + return ts.edges[e] +} + +// Edge returns the task marked with the given edge name or an error. func (ts TaskSet) Edge(e TaskSetEdge) (*Task, error) { - t, ok := ts.edges[e] - if !ok { + t := ts.MaybeEdge(e) + if t == nil { return nil, fmt.Errorf("internal error: missing %q edge in task set", e) } return t, nil @@ -509,7 +610,7 @@ func (ts *TaskSet) WaitAll(anotherTs *TaskSet) { } } -// AddTask adds the the task to the task set. +// AddTask adds the task to the task set. func (ts *TaskSet) AddTask(task *Task) { for _, t := range ts.tasks { if t == task { @@ -522,6 +623,9 @@ func (ts *TaskSet) AddTask(task *Task) { // MarkEdge marks the given task as a specific edge. Any pre-existing // edge mark will be overridden. func (ts *TaskSet) MarkEdge(task *Task, edge TaskSetEdge) { + if task == nil { + panic(fmt.Sprintf("cannot set edge %q with nil task", edge)) + } if ts.edges == nil { ts.edges = make(map[TaskSetEdge]*Task) } diff --git a/internals/overlord/state/task_test.go b/internals/overlord/state/task_test.go index 9c073d80..d1e26e54 100644 --- a/internals/overlord/state/task_test.go +++ b/internals/overlord/state/task_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -127,7 +127,7 @@ func (ts *taskSuite) TestClear(c *C) { t.Clear("a") - c.Check(t.Get("a", &v), Equals, state.ErrNoState) + c.Check(t.Get("a", &v), testutil.ErrorIs, state.ErrNoState) } func (ts *taskSuite) TestStatusAndSetStatus(c *C) { @@ -144,6 +144,60 @@ func (ts *taskSuite) TestStatusAndSetStatus(c *C) { c.Check(t.Status(), Equals, state.DoneStatus) } +func (ts *taskSuite) TestSetDoneAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetStatus(state.DoneStatus) + c.Check(t.Status(), Equals, state.AbortStatus) +} + +func (ts *taskSuite) TestSetWaitAfterAbortNoop(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetStatus(state.AbortStatus) + c.Check(t.Status(), Equals, state.AbortStatus) + t.SetToWait(state.DoneStatus) // noop + c.Check(t.Status(), Equals, state.AbortStatus) + c.Check(t.WaitedStatus(), Equals, state.DefaultStatus) +} + +func (ts *taskSuite) TestSetWait(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t := st.NewTask("download", "1...") + t.SetToWait(state.DoneStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.DoneStatus) + t.SetToWait(state.UndoStatus) + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, state.UndoStatus) +} + +func (ts *taskSuite) TestTaskMarshalsWaitStatus(c *C) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + t1 := st.NewTask("download", "1...") + t1.SetToWait(state.UndoStatus) + + d, err := t1.MarshalJSON() + c.Assert(err, IsNil) + + needle := fmt.Sprintf(`"waited-status":%d`, t1.WaitedStatus()) + c.Assert(string(d), testutil.Contains, needle) +} + func (ts *taskSuite) TestIsCleanAndSetClean(c *C) { st := state.New(nil) st.Lock() @@ -174,9 +228,9 @@ func (ts *taskSuite) TestProgressAndSetProgress(c *C) { t := st.NewTask("download", "1...") - t.SetProgress("foo", 2, 99) + t.SetProgress("snap", 2, 99) label, cur, tot := t.Progress() - c.Check(label, Equals, "foo") + c.Check(label, Equals, "snap") c.Check(cur, Equals, 2) c.Check(tot, Equals, 99) @@ -305,7 +359,7 @@ func (ts *taskSuite) TestAt(c *C) { t := st.NewTask("download", "1...") now := time.Now() - restore := state.FakeTime(now) + restore := state.MockTime(now) defer restore() when := now.Add(10 * time.Second) t.At(when) @@ -561,6 +615,10 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { // edges are just typed strings edge1 := state.TaskSetEdge("on-edge") edge2 := state.TaskSetEdge("eddie") + edge3 := state.TaskSetEdge("not-found") + + // nil task causes panic + c.Check(func() { ts.MarkEdge(nil, edge1) }, PanicMatches, `cannot set edge "on-edge" with nil task`) // no edge marked yet t, err := ts.Edge(edge1) @@ -590,6 +648,12 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) { t, err = ts.Edge(edge1) c.Assert(t, Equals, t3) c.Assert(err, IsNil) + + // it is possible to check if edge exists without failing + t = ts.MaybeEdge(edge1) + c.Assert(t, Equals, t3) + t = ts.MaybeEdge(edge3) + c.Assert(t, IsNil) } func (cs *taskSuite) TestTaskAddAllWithEdges(c *C) { diff --git a/internals/overlord/state/taskrunner.go b/internals/overlord/state/taskrunner.go index 5d9ff244..c6a2b835 100644 --- a/internals/overlord/state/taskrunner.go +++ b/internals/overlord/state/taskrunner.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -46,6 +46,23 @@ func (r *Retry) Error() string { return "task should be retried" } +// Wait is returned from a handler to signal that the task cannot +// proceed at the moment maybe because some manual action from the +// user required at this point or because of errors. The task +// will be set to WaitStatus, and it's wait complete status will be +// set to WaitedStatus. +type Wait struct { + Reason string + // If not explicitly set, then WaitedStatus will default to + // DoneStatus, meaning that the task will be set to DoneStatus + // after the wait has resolved. + WaitedStatus Status +} + +func (r *Wait) Error() string { + return "task set to wait, manual action required" +} + type blockedFunc func(t *Task, running []*Task) bool // TaskRunner controls the running of goroutines to execute known task kinds. @@ -62,6 +79,9 @@ type TaskRunner struct { blocked []blockedFunc someBlocked bool + // optional callback executed on task errors + taskErrorCallback func(err error) + // go-routines lifecycle tombs map[string]*tomb.Tomb } @@ -85,6 +105,11 @@ func NewTaskRunner(s *State) *TaskRunner { } } +// OnTaskError sets an error callback executed when any task errors out. +func (r *TaskRunner) OnTaskError(f func(err error)) { + r.taskErrorCallback = f +} + // AddHandler registers the functions to concurrently call for doing and // undoing tasks of the given kind. The undo handler may be nil. func (r *TaskRunner) AddHandler(kind string, do, undo HandlerFunc) { @@ -214,7 +239,7 @@ func (r *TaskRunner) run(t *Task) { switch err.(type) { case nil: // we are ok - case *Retry: + case *Retry, *Wait: // preserve default: if r.stopped { @@ -227,13 +252,24 @@ func (r *TaskRunner) run(t *Task) { switch x := err.(type) { case *Retry: // Handler asked to be called again later. - // TODO Allow postponing retries past the next Ensure. if t.Status() == AbortStatus { // Would work without it but might take two ensures. r.tryUndo(t) } else if x.After != 0 { t.At(timeNow().Add(x.After)) } + case *Wait: + if t.Status() == AbortStatus { + // Would work without it but might take two ensures. + r.tryUndo(t) + } else { + // Default to DoneStatus if no status is set in Wait + waitedStatus := x.WaitedStatus + if waitedStatus == DefaultStatus { + waitedStatus = DoneStatus + } + t.SetToWait(waitedStatus) + } case nil: var next []*Task switch t.Status() { @@ -259,6 +295,11 @@ func (r *TaskRunner) run(t *Task) { r.abortLanes(t.Change(), t.Lanes()) t.SetStatus(ErrorStatus) t.Errorf("%s", err) + // ensure the error is available in the global log too + logger.Noticef("[change %s %q task] failed: %v", t.Change().ID(), t.Summary(), err) + if r.taskErrorCallback != nil { + r.taskErrorCallback(err) + } } return nil @@ -388,6 +429,10 @@ ConsiderTasks: } continue } + if status == WaitStatus { + // nothing more to run + continue + } if mustWait(t) { // Dependencies still unhandled. diff --git a/internals/overlord/state/taskrunner_test.go b/internals/overlord/state/taskrunner_test.go index 38c5685a..2b9de82a 100644 --- a/internals/overlord/state/taskrunner_test.go +++ b/internals/overlord/state/taskrunner_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2016 Canonical Ltd + * Copyright (C) 2016-2022 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -31,11 +31,14 @@ import ( . "gopkg.in/check.v1" "gopkg.in/tomb.v2" - "github.com/canonical/pebble/internals/overlord/restart" + "github.com/canonical/pebble/internals/logger" "github.com/canonical/pebble/internals/overlord/state" + "github.com/canonical/pebble/internals/testutil" ) -type taskRunnerSuite struct{} +type taskRunnerSuite struct { + testutil.BaseTest +} var _ = Suite(&taskRunnerSuite{}) @@ -58,8 +61,6 @@ func (b *stateBackend) EnsureBefore(d time.Duration) { } } -func (b *stateBackend) RequestRestart(t restart.RestartType) {} - func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) { for i := 0; i < 20; i++ { sb.ensureBefore = time.Hour @@ -142,6 +143,10 @@ var sequenceTests = []struct{ setup, result string }{{ result: "t31:undo t32:do t32:do-error t21:undo", }} +func (ts *taskRunnerSuite) SetUpTest(c *C) { + ts.BaseTest.SetUpTest(c) +} + func (ts *taskRunnerSuite) TestSequenceTests(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -185,11 +190,12 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { r.AddHandler("do", fn("do"), nil) r.AddHandler("do-undo", fn("do"), fn("undo")) + past := time.Now().AddDate(-1, 0, 0) for _, test := range sequenceTests { st.Lock() // Delete previous changes. - st.Prune(1, 1, 1) + st.Prune(past, 1, 1, 1) chg := st.NewChange("install", "...") tasks := make(map[string]*state.Task) @@ -342,6 +348,206 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) { } } +func (ts *taskRunnerSuite) TestAbortAcrossLanesDescendantTask(c *C) { + + // () + // t11(1) -> t12(1) => t15(1) + // \ / + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / \ + // t21(2) -> t22(2) => t25(2) + // + names := strings.Fields("t11 t12 t13 t14 t15 t21 t22 t23 t24 t25") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + tasks["t15"].WaitFor(tasks["t14"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t15", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t25", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + tasks["t25"].WaitFor(tasks["t24"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t15"].WaitFor(tasks["t24"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t15" { + ch <- "t15:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:do", + }) + c.Assert(sequence[8:10], testutil.DeepUnsortedMatches, []string{ + "t25:do", + "t15:error", + }) + c.Assert(sequence[10:11], testutil.DeepUnsortedMatches, []string{ + "t25:undo", + }) + c.Assert(sequence[11:15], DeepEquals, []string{ + "t24:undo", "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[15:19], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + +func (ts *taskRunnerSuite) TestAbortAcrossLanesStriclyOrderedTasks(c *C) { + + // () + // t11(1) -> t12(1) + // \ + // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2) + // / + // t21(2) -> t22(2) + // + names := strings.Fields("t11 t12 t13 t14 t21 t22 t23 t24") + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + st.Lock() + defer st.Unlock() + + c.Assert(len(st.Tasks()), Equals, 0) + + chg := st.NewChange("install", "...") + tasks := make(map[string]*state.Task) + for _, name := range names { + tasks[name] = st.NewTask("do", name) + chg.AddTask(tasks[name]) + } + tasks["t12"].WaitFor(tasks["t11"]) + tasks["t13"].WaitFor(tasks["t12"]) + tasks["t14"].WaitFor(tasks["t13"]) + for lane, names := range map[int][]string{ + 1: {"t11", "t12", "t13", "t14", "t23", "t24"}, + 2: {"t21", "t22", "t23", "t24", "t13", "t14"}, + } { + for _, name := range names { + tasks[name].JoinLane(lane) + } + } + + tasks["t22"].WaitFor(tasks["t21"]) + tasks["t23"].WaitFor(tasks["t22"]) + tasks["t24"].WaitFor(tasks["t23"]) + + tasks["t13"].WaitFor(tasks["t22"]) + tasks["t23"].WaitFor(tasks["t14"]) + + ch := make(chan string, 256) + do := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("do %q", task.Summary()) + label := task.Summary() + if label == "t24" { + ch <- "t24:error" + return fmt.Errorf("mock error") + } + ch <- fmt.Sprintf("%s:do", label) + return nil + } + undo := func(task *state.Task, tomb *tomb.Tomb) error { + c.Logf("undo %q", task.Summary()) + label := task.Summary() + ch <- fmt.Sprintf("%s:undo", label) + return nil + } + r.AddHandler("do", do, undo) + + c.Logf("-----") + + st.Unlock() + ensureChange(c, r, sb, chg) + st.Lock() + close(ch) + var sequence []string + for event := range ch { + sequence = append(sequence, event) + } + for _, name := range names { + task := tasks[name] + c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status()) + } + c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{ + "t11:do", "t12:do", + "t21:do", "t22:do", + }) + c.Assert(sequence[4:8], DeepEquals, []string{ + "t13:do", "t14:do", "t23:do", "t24:error", + }) + c.Assert(sequence[8:11], DeepEquals, []string{ + "t23:undo", "t14:undo", "t13:undo", + }) + c.Assert(sequence[11:], testutil.DeepUnsortedMatches, []string{ + "t21:undo", "t22:undo", + "t12:undo", "t11:undo", + }) +} + func (ts *taskRunnerSuite) TestExternalAbort(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -372,6 +578,101 @@ func (ts *taskRunnerSuite) TestExternalAbort(c *C) { ensureChange(c, r, sb, chg) } +func (ts *taskRunnerSuite) TestUndoSingleLane(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("noop", func(t *state.Task, tb *tomb.Tomb) error { + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("noop-slow", func(t *state.Task, tb *tomb.Tomb) error { + time.Sleep(10 * time.Millisecond) + t.State().Lock() + defer t.State().Unlock() + // critical + t.SetStatus(state.DoneStatus) + return nil + }, func(t *state.Task, tb *tomb.Tomb) error { + return nil + }) + + r.AddHandler("fail", func(t *state.Task, tb *tomb.Tomb) error { + return fmt.Errorf("fail") + }, nil) + + st.Lock() + + lane := st.NewLane() + chg := st.NewChange("install", "...") + + // first taskset + var prev *state.Task + for i := 0; i < 10; i++ { + t := st.NewTask("noop-slow", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + + prev = t + } + + // second taskset with a failing task that triggers undo of the change + prev = nil + for i := 0; i < 10; i++ { + t := st.NewTask("noop", "...") + if prev != nil { + t.WaitFor(prev) + } + chg.AddTask(t) + t.JoinLane(lane) + prev = t + } + + // error trigger + t := st.NewTask("fail", "...") + t.WaitFor(prev) + chg.AddTask(t) + t.JoinLane(lane) + + st.Unlock() + + var done bool + for !done { + c.Assert(r.Ensure(), Equals, nil) + st.Lock() + done = chg.IsReady() && chg.IsClean() + st.Unlock() + } + + st.Lock() + defer st.Unlock() + + // make sure all tasks are either undone or on hold (except for "fail" task which + // is in error). + for _, t := range st.Tasks() { + switch t.Kind() { + case "fail": + c.Assert(t.Status(), Equals, state.ErrorStatus) + case "noop", "noop-slow": + if t.Status() != state.UndoneStatus && t.Status() != state.HoldStatus { + for _, tsk := range st.Tasks() { + fmt.Printf("%s -> %s\n", tsk.Kind(), tsk.Status()) + } + c.Fatalf("unexpected status: %s", t.Status()) + } + default: + c.Fatalf("unexpected kind: %s", t.Kind()) + } + } +} + func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) { sb := &stateBackend{} st := state.New(sb) @@ -508,6 +809,55 @@ func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) { c.Check(t.AtTime().IsZero(), Equals, false) } +func (ts *taskRunnerSuite) testTaskReturningWait(c *C, waitedStatus, expectedStatus state.Status) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + defer r.Stop() + + r.AddHandler("ask-for-wait", func(t *state.Task, tb *tomb.Tomb) error { + // ask for wait + return &state.Wait{WaitedStatus: waitedStatus} + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t := st.NewTask("ask-for-wait", "...") + chg.AddTask(t) + st.Unlock() + + r.Ensure() + // wait for handler to finish + r.Wait() + + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(t.WaitedStatus(), Equals, expectedStatus) + c.Check(chg.Status().Ready(), Equals, false) + + st.Unlock() + defer st.Lock() + // does nothing + r.Ensure() + + // state is unchanged + st.Lock() + defer st.Unlock() + c.Check(t.Status(), Equals, state.WaitStatus) + c.Check(chg.Status().Ready(), Equals, false) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitNormal(c *C) { + ts.testTaskReturningWait(c, state.UndoneStatus, state.UndoneStatus) +} + +func (ts *taskRunnerSuite) TestTaskReturningWaitDefaultStatus(c *C) { + // If no state was set (DefaultStatus), then it should default to + // DoneStatus instead. + ts.testTaskReturningWait(c, state.DefaultStatus, state.DoneStatus) +} + func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { ensureBeforeTick := make(chan bool, 1) sb := &stateBackend{ @@ -536,7 +886,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { st.Unlock() tock := time.Now() - restore := state.FakeTime(tock) + restore := state.MockTime(tock) defer restore() r.Ensure() // will run and be rescheduled in a minute select { @@ -554,7 +904,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { schedule := t.AtTime() c.Check(schedule.IsZero(), Equals, false) - state.FakeTime(tock.Add(5 * time.Second)) + state.MockTime(tock.Add(5 * time.Second)) sb.ensureBefore = time.Hour st.Unlock() r.Ensure() // too soon @@ -565,7 +915,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) { c.Check(sb.ensureBefore, Equals, 55*time.Second) c.Check(t.AtTime().Equal(schedule), Equals, true) - state.FakeTime(schedule) + state.MockTime(schedule) sb.ensureBefore = time.Hour st.Unlock() r.Ensure() // time to run again @@ -823,7 +1173,7 @@ func (ts *taskRunnerSuite) TestUndoSequence(c *C) { terr.WaitFor(prev) chg.AddTask(terr) - c.Check(chg.Tasks(), HasLen, 9) // sanity check + c.Check(chg.Tasks(), HasLen, 9) // validity check st.Unlock() @@ -910,3 +1260,71 @@ func (ts *taskRunnerSuite) TestCleanup(c *C) { c.Assert(chgIsClean(), Equals, true) c.Assert(called, Equals, 2) } + +func (ts *taskRunnerSuite) TestErrorCallbackCalledOnError(c *C) { + logbuf, restore := logger.MockLogger("PREFIX: ") + defer restore() + + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return fmt.Errorf("handler error for %q", t.Kind()) + }, nil) + + st.Lock() + chg := st.NewChange("install", "change summary") + t1 := st.NewTask("foo", "task summary") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.ErrorStatus) + c.Check(strings.Join(t1.Log(), ""), Matches, `.*handler error for "foo"`) + c.Check(called, Equals, true) + + c.Check(logbuf.String(), Matches, `(?m).*: \[change 1 "task summary" task\] failed: handler error for "foo".*`) +} + +func (ts *taskRunnerSuite) TestErrorCallbackNotCalled(c *C) { + sb := &stateBackend{} + st := state.New(sb) + r := state.NewTaskRunner(st) + + var called bool + r.OnTaskError(func(err error) { + called = true + }) + + r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error { + return nil + }, nil) + + st.Lock() + chg := st.NewChange("install", "...") + t1 := st.NewTask("foo", "...") + chg.AddTask(t1) + st.Unlock() + + // Mark tasks as done. + ensureChange(c, r, sb, chg) + r.Stop() + + st.Lock() + defer st.Unlock() + + c.Check(t1.Status(), Equals, state.DoneStatus) + c.Check(called, Equals, false) +} diff --git a/internals/overlord/state/warning.go b/internals/overlord/state/warning.go index 69dac7c8..6211bf37 100644 --- a/internals/overlord/state/warning.go +++ b/internals/overlord/state/warning.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2018 Canonical Ltd + * Copyright (C) 2018 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as diff --git a/internals/overlord/state/warning_test.go b/internals/overlord/state/warning_test.go index bf30f69e..26479694 100644 --- a/internals/overlord/state/warning_test.go +++ b/internals/overlord/state/warning_test.go @@ -1,7 +1,7 @@ // -*- Mode: Go; indent-tabs-mode: t -*- /* - * Copyright (c) 2018 Canonical Ltd + * Copyright (C) 2018 Canonical Ltd * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 3 as @@ -99,7 +99,7 @@ func (stateSuite) TestUnmarshalErrors(c *check.C) { } for _, t := range []T1{ - // sanity check + // validity check {`{"message": "x", "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, nil}, // remove one field at a time: {`{ "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, state.ErrNoWarningMessage}, diff --git a/internals/overlord/stateengine.go b/internals/overlord/stateengine.go index 25710333..b468aea5 100644 --- a/internals/overlord/stateengine.go +++ b/internals/overlord/stateengine.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package overlord @@ -30,10 +35,18 @@ type StateManager interface { Ensure() error } +// StateStarterUp is optionally implemented by StateManager that have expensive +// initialization to perform before the main Overlord loop. +type StateStarterUp interface { + // StartUp asks manager to perform any expensive initialization. + StartUp() error +} + // StateWaiter is optionally implemented by StateManagers that have running // activities that can be waited. type StateWaiter interface { - // Wait asks manager to wait for all running activities to finish. + // Wait asks manager to wait for all running activities to + // finish. Wait() } @@ -53,8 +66,9 @@ type StateStopper interface { // cope with Ensure calls in any order, coordinating among themselves // solely via the state. type StateEngine struct { - state *state.State - stopped bool + state *state.State + stopped bool + startedUp bool // managers in use mgrLock sync.Mutex managers []StateManager @@ -72,6 +86,37 @@ func (se *StateEngine) State() *state.State { return se.state } +type startupError struct { + errs []error +} + +func (e *startupError) Error() string { + return fmt.Sprintf("state startup errors: %v", e.errs) +} + +// StartUp asks all managers to perform any expensive initialization. It is a noop after the first invocation. +func (se *StateEngine) StartUp() error { + se.mgrLock.Lock() + defer se.mgrLock.Unlock() + if se.startedUp { + return nil + } + se.startedUp = true + var errs []error + for _, m := range se.managers { + if starterUp, ok := m.(StateStarterUp); ok { + err := starterUp.StartUp() + if err != nil { + errs = append(errs, err) + } + } + } + if len(errs) != 0 { + return &startupError{errs} + } + return nil +} + type ensureError struct { errs []error } @@ -91,6 +136,9 @@ func (e *ensureError) Error() string { func (se *StateEngine) Ensure() error { se.mgrLock.Lock() defer se.mgrLock.Unlock() + if !se.startedUp { + return fmt.Errorf("state engine skipped startup") + } if se.stopped { return fmt.Errorf("state engine already stopped") } diff --git a/internals/overlord/stateengine_test.go b/internals/overlord/stateengine_test.go index f12ed98b..a427e661 100644 --- a/internals/overlord/stateengine_test.go +++ b/internals/overlord/stateengine_test.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2016 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package overlord_test @@ -35,9 +40,14 @@ func (ses *stateEngineSuite) TestNewAndState(c *C) { } type fakeManager struct { - name string - calls *[]string - ensureError, stopError error + name string + calls *[]string + ensureError, startupError error +} + +func (fm *fakeManager) StartUp() error { + *fm.calls = append(*fm.calls, "startup:"+fm.name) + return fm.startupError } func (fm *fakeManager) Ensure() error { @@ -55,6 +65,48 @@ func (fm *fakeManager) Wait() { var _ overlord.StateManager = (*fakeManager)(nil) +func (ses *stateEngineSuite) TestStartUp(c *C) { + s := state.New(nil) + se := overlord.NewStateEngine(s) + + calls := []string{} + + mgr1 := &fakeManager{name: "mgr1", calls: &calls} + mgr2 := &fakeManager{name: "mgr2", calls: &calls} + + se.AddManager(mgr1) + se.AddManager(mgr2) + + err := se.StartUp() + c.Assert(err, IsNil) + c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"}) + + // noop + err = se.StartUp() + c.Assert(err, IsNil) + c.Check(calls, HasLen, 2) +} + +func (ses *stateEngineSuite) TestStartUpError(c *C) { + s := state.New(nil) + se := overlord.NewStateEngine(s) + + calls := []string{} + + err1 := errors.New("boom1") + err2 := errors.New("boom2") + + mgr1 := &fakeManager{name: "mgr1", calls: &calls, startupError: err1} + mgr2 := &fakeManager{name: "mgr2", calls: &calls, startupError: err2} + + se.AddManager(mgr1) + se.AddManager(mgr2) + + err := se.StartUp() + c.Check(err.Error(), DeepEquals, "state startup errors: [boom1 boom2]") + c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"}) +} + func (ses *stateEngineSuite) TestEnsure(c *C) { s := state.New(nil) se := overlord.NewStateEngine(s) @@ -68,6 +120,11 @@ func (ses *stateEngineSuite) TestEnsure(c *C) { se.AddManager(mgr2) err := se.Ensure() + c.Check(err, ErrorMatches, "state engine skipped startup") + c.Assert(se.StartUp(), IsNil) + calls = []string{} + + err = se.Ensure() c.Assert(err, IsNil) c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"}) @@ -91,6 +148,9 @@ func (ses *stateEngineSuite) TestEnsureError(c *C) { se.AddManager(mgr1) se.AddManager(mgr2) + c.Assert(se.StartUp(), IsNil) + calls = []string{} + err := se.Ensure() c.Check(err.Error(), DeepEquals, "state ensure errors: [boom1 boom2]") c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"}) @@ -108,6 +168,9 @@ func (ses *stateEngineSuite) TestStop(c *C) { se.AddManager(mgr1) se.AddManager(mgr2) + c.Assert(se.StartUp(), IsNil) + calls = []string{} + se.Stop() c.Check(calls, DeepEquals, []string{"stop:mgr1", "stop:mgr2"}) se.Stop() diff --git a/internals/systemd/sdnotify.go b/internals/systemd/sdnotify.go index 5f1279cf..ed16aceb 100644 --- a/internals/systemd/sdnotify.go +++ b/internals/systemd/sdnotify.go @@ -27,7 +27,7 @@ var osGetenv = os.Getenv func SocketAvailable() bool { notifySocket := osGetenv("NOTIFY_SOCKET") - return notifySocket != "" && osutil.CanStat(notifySocket) + return notifySocket != "" && osutil.FileExists(notifySocket) } // SdNotify sends the given state string notification to systemd. diff --git a/internals/systemd/systemd.go b/internals/systemd/systemd.go index c5e11114..0811fbee 100644 --- a/internals/systemd/systemd.go +++ b/internals/systemd/systemd.go @@ -713,7 +713,7 @@ func (s *systemd) RemoveMountUnitFile(mountedDir string) error { } unit := MountUnitPath(unitNamePath) - if !osutil.CanStat(unit) { + if !osutil.FileExists(unit) { return nil } diff --git a/internals/systemd/systemd_test.go b/internals/systemd/systemd_test.go index f32c363d..20967271 100644 --- a/internals/systemd/systemd_test.go +++ b/internals/systemd/systemd_test.go @@ -567,7 +567,7 @@ WantedBy=multi-user.target } func (s *SystemdTestSuite) TestFuseInContainer(c *C) { - if !osutil.CanStat("/dev/fuse") { + if !osutil.FileExists("/dev/fuse") { c.Skip("No /dev/fuse on the system") } @@ -713,7 +713,7 @@ func (s *SystemdTestSuite) TestRemoveMountUnit(c *C) { c.Assert(err, IsNil) // the file is gone - c.Check(osutil.CanStat(mountUnit), Equals, false) + c.Check(osutil.FileExists(mountUnit), Equals, false) // and the unit is disabled and the daemon reloaded c.Check(s.argses, DeepEquals, [][]string{ {"--root", s.rootDir, "disable", "snap-foo-42.mount"}, diff --git a/internals/testutil/containschecker.go b/internals/testutil/containschecker.go index 82ff2ee5..644ef158 100644 --- a/internals/testutil/containschecker.go +++ b/internals/testutil/containschecker.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2015-2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package testutil @@ -81,7 +86,7 @@ func (c *containsChecker) Check(params []interface{}, names []string) (result bo var container interface{} = params[0] var elem interface{} = params[1] if commonEquals(container, elem, &result, &error) { - return + return result, error } // Do the actual test using == switch containerV := reflect.ValueOf(container); containerV.Kind() { @@ -111,7 +116,7 @@ type deepContainsChecker struct { } // DeepContains is a Checker that looks for a elem in a container using -// DeepEqual. The elem can be any object. The container can be an array, slice +// DeepEqual. The elem can be any object. The container can be an array, slice // or string. var DeepContains check.Checker = &deepContainsChecker{ &check.CheckerInfo{Name: "DeepContains", Params: []string{"container", "elem"}}, @@ -121,7 +126,7 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul var container interface{} = params[0] var elem interface{} = params[1] if commonEquals(container, elem, &result, &error) { - return + return result, error } // Do the actual test using reflect.DeepEqual switch containerV := reflect.ValueOf(container); containerV.Kind() { @@ -145,3 +150,126 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul return false, fmt.Sprintf("%T is not a supported container", container) } } + +type deepUnsortedMatchChecker struct { + *check.CheckerInfo +} + +// DeepUnsortedMatches checks if two containers contain the same elements in +// the same number (but possibly different order) using DeepEqual. The container +// can be an array, a slice or a map. +var DeepUnsortedMatches check.Checker = &deepUnsortedMatchChecker{ + &check.CheckerInfo{Name: "DeepUnsortedMatches", Params: []string{"container1", "container2"}}, +} + +func (c *deepUnsortedMatchChecker) Check(params []interface{}, _ []string) (bool, string) { + container1 := reflect.ValueOf(params[0]) + container2 := reflect.ValueOf(params[1]) + + // if both args are nil, return true + if container1.Kind() == reflect.Invalid && container2.Kind() == reflect.Invalid { + return true, "" + } + + if container1.Kind() == reflect.Invalid || container2.Kind() == reflect.Invalid { + return false, "only one container was nil" + } + + if container1.Kind() != container2.Kind() { + return false, fmt.Sprintf("containers are of different types: %s != %s", container1.Kind(), container2.Kind()) + } + + switch container1.Kind() { + case reflect.Array, reflect.Slice: + return deepSequenceMatch(container1, container2) + case reflect.Map: + return deepMapMatch(container1, container2) + default: + return false, fmt.Sprintf("'%s' is not a supported type: must be slice, array or map", container1.Kind().String()) + } +} + +func deepMapMatch(container1, container2 reflect.Value) (bool, string) { + if valid, output := validateContainerTypesAndLengths(container1, container2); !valid { + return false, output + } + + switch container1.Type().Elem().Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + // only run the unsorted match if the map values are containers + default: + if !reflect.DeepEqual(container1.Interface(), container2.Interface()) { + return false, "maps don't match" + } + return true, "" + } + + for _, key := range container1.MapKeys() { + el1 := container1.MapIndex(key) + el2 := container2.MapIndex(key) + + absent := el2 == reflect.Value{} + if absent { + return false, fmt.Sprintf("key %q from one map is absent from the other map", key) + } + + var ok bool + var msg string + switch el1.Kind() { + case reflect.Array, reflect.Slice: + ok, msg = deepSequenceMatch(el1, el2) + case reflect.Map: + ok, msg = deepMapMatch(el1, el2) + } + + if !ok { + return false, msg + } + } + + return true, "" +} + +func deepSequenceMatch(container1, container2 reflect.Value) (bool, string) { + if valid, output := validateContainerTypesAndLengths(container1, container2); !valid { + return false, output + } + + matched := make([]bool, container1.Len()) +out: + for i := 0; i < container1.Len(); i++ { + el1 := container1.Index(i).Interface() + + for e := 0; e < container2.Len(); e++ { + el2 := container2.Index(e).Interface() + + if !matched[e] && reflect.DeepEqual(el1, el2) { + // mark already matched elements, so that duplicate elements in + // one container can't be matched to the same element in the other. + matched[e] = true + continue out + } + } + + return false, fmt.Sprintf("element [%d]=%s was unmatched in the second container", i, el1) + } + + return true, "" +} + +func validateContainerTypesAndLengths(container1, container2 reflect.Value) (bool, string) { + if container1.Len() != container2.Len() { + return false, fmt.Sprintf("containers have different lengths: %d != %d", container1.Len(), container2.Len()) + } else if container1.Type().Elem() != container2.Type().Elem() { + return false, fmt.Sprintf("containers have different element types: %s != %s", container1.Type().Elem(), container2.Type().Elem()) + } + + if container1.Kind() == reflect.Map && container2.Kind() == reflect.Map { + keyType1, keyType2 := container1.Type().Key(), container2.Type().Key() + if keyType1 != keyType2 { + return false, fmt.Sprintf("maps have different key types: %s != %s", keyType1, keyType2) + } + } + + return true, "" +} diff --git a/internals/testutil/containschecker_test.go b/internals/testutil/containschecker_test.go index 86c5c0cc..8a94c897 100644 --- a/internals/testutil/containschecker_test.go +++ b/internals/testutil/containschecker_test.go @@ -1,16 +1,21 @@ -// Copyright (c) 2014-2020 Canonical Ltd -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License version 3 as -// published by the Free Software Foundation. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2015-2018 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ package testutil_test @@ -216,3 +221,152 @@ func (*containsCheckerSuite) TestDeepContainsUncomparableType(c *check.C) { testCheck(c, DeepContains, true, "", containerSlice, elem) testCheck(c, DeepContains, true, "", containerMap, elem) } + +type example struct { + a string + b map[string]int +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceSuccess(c *check.C) { + slice1 := []example{ + {a: "one", b: map[string]int{"a": 1}}, + {a: "two", b: map[string]int{"b": 2}}, + } + slice2 := []example{ + {a: "two", b: map[string]int{"b": 2}}, + {a: "one", b: map[string]int{"a": 1}}, + } + + c.Check(slice1, DeepUnsortedMatches, slice2) + c.Check(slice2, DeepUnsortedMatches, slice1) + c.Check([]string{"a", "a"}, DeepUnsortedMatches, []string{"a", "a"}) + c.Check([]string{"a", "b", "a"}, DeepUnsortedMatches, []string{"b", "a", "a"}) + slice := [1]int{1} + c.Check(slice, DeepUnsortedMatches, slice) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceFailure(c *check.C) { + slice1 := []string{"a", "a", "b"} + slice2 := []string{"b", "a", "c"} + + testCheck(c, DeepUnsortedMatches, false, "element [1]=a was unmatched in the second container", slice1, slice2) + testCheck(c, DeepUnsortedMatches, false, "element [2]=c was unmatched in the second container", slice2, slice1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapSuccess(c *check.C) { + map1 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + "c": {a: "c", b: map[string]int{"c": 3, "d": 4}}, + } + map2 := map[string]example{ + "c": {a: "c", b: map[string]int{"c": 3, "d": 4}}, + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + } + + c.Check(map1, DeepUnsortedMatches, map2) + c.Check(map2, DeepUnsortedMatches, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapStructFail(c *check.C) { + map1 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 2, "b": 1}}, + } + map2 := map[string]example{ + "a": {a: "a", b: map[string]int{"a": 1, "b": 2}}, + } + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedKeyFailure(c *check.C) { + map1 := map[string]int{"a": 1, "c": 2} + map2 := map[string]int{"a": 1, "b": 2} + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedValueFailure(c *check.C) { + map1 := map[string]int{"a": 1, "b": 2} + map2 := map[string]int{"a": 1, "b": 3} + + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2) + testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentTypeFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers are of different types: slice != array", []int{}, [1]int{}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentElementType(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers have different element types: int != string", []int{1}, []string{"a"}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentLengthFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 1 != 2", []int{1}, []int{1, 1}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNilArgFailure(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "only one container was nil", nil, []int{1}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesBothNilArgSuccess(c *check.C) { + c.Check(nil, DeepUnsortedMatches, nil) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNonContainerValues(c *check.C) { + testCheck(c, DeepUnsortedMatches, false, "'string' is not a supported type: must be slice, array or map", "a", "a") + testCheck(c, DeepUnsortedMatches, false, "'int' is not a supported type: must be slice, array or map", 1, 2) + testCheck(c, DeepUnsortedMatches, false, "'bool' is not a supported type: must be slice, array or map", true, false) + testCheck(c, DeepUnsortedMatches, false, "'ptr' is not a supported type: must be slice, array or map", &[]string{"a", "b"}, &[]string{"a", "b"}) + testCheck(c, DeepUnsortedMatches, false, "'func' is not a supported type: must be slice, array or map", func() {}, func() {}) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsOfSlices(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}} + map2 := map[string][]string{"a": {"bar", "foo"}, "b": {"bar", "foo"}} + + c.Check(map1, DeepUnsortedMatches, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentKeyTypes(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[int][]string{1: {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "maps have different key types: string != int", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentValueTypes(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[string][2]string{"a": {"foo", "bar"}} + + testCheck(c, DeepUnsortedMatches, false, "containers have different element types: []string != [2]string", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentLengths(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}} + map2 := map[string][]string{"a": {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 2 != 1", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsMissingKey(c *check.C) { + map1 := map[string][]string{"a": {"foo", "bar"}} + map2 := map[string][]string{"b": {"bar", "foo"}} + + testCheck(c, DeepUnsortedMatches, false, "key \"a\" from one map is absent from the other map", map1, map2) +} + +func (*containsCheckerSuite) TestDeepUnsortedMatchesNestedMaps(c *check.C) { + map1 := map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}} + map2 := map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}} + c.Check(map1, DeepUnsortedMatches, map2) + + map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}} + map2 = map[string]map[string][]string{"a": {"c": []string{"bar", "foo"}}} + testCheck(c, DeepUnsortedMatches, false, "key \"b\" from one map is absent from the other map", map1, map2) + + map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}, "c": {"b": []string{"foo"}}} + map2 = map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}, "c": {"b": []string{"bar"}}} + testCheck(c, DeepUnsortedMatches, false, "element [0]=foo was unmatched in the second container", map1, map2) +} diff --git a/internals/testutil/errorischecker.go b/internals/testutil/errorischecker.go new file mode 100644 index 00000000..e22030c6 --- /dev/null +++ b/internals/testutil/errorischecker.go @@ -0,0 +1,53 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2022 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package testutil + +import ( + "errors" + + "gopkg.in/check.v1" +) + +// ErrorIs calls errors.Is with the provided arguments. +var ErrorIs = &errorIsChecker{ + &check.CheckerInfo{Name: "ErrorIs", Params: []string{"error", "target"}}, +} + +type errorIsChecker struct { + *check.CheckerInfo +} + +func (*errorIsChecker) Check(params []interface{}, names []string) (result bool, errMsg string) { + if params[0] == nil { + return params[1] == nil, "" + } + + err, ok := params[0].(error) + if !ok { + return false, "first argument must be an error" + } + + target, ok := params[1].(error) + if !ok { + return false, "second argument must be an error" + } + + return errors.Is(err, target), "" +} diff --git a/internals/testutil/errorischecker_test.go b/internals/testutil/errorischecker_test.go new file mode 100644 index 00000000..5f2fc28c --- /dev/null +++ b/internals/testutil/errorischecker_test.go @@ -0,0 +1,72 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2022 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package testutil_test + +import ( + "errors" + + . "github.com/canonical/pebble/internals/testutil" + + . "gopkg.in/check.v1" +) + +type errorIsCheckerSuite struct{} + +var _ = Suite(&errorIsCheckerSuite{}) + +type baseError struct{} + +func (baseError) Error() string { return "" } + +func (baseError) Is(err error) bool { + _, ok := err.(baseError) + return ok +} + +type wrapperError struct { + err error +} + +func (*wrapperError) Error() string { return "" } + +func (e *wrapperError) Unwrap() error { return e.err } + +func (*errorIsCheckerSuite) TestErrorIsCheckSucceeds(c *C) { + testInfo(c, ErrorIs, "ErrorIs", []string{"error", "target"}) + + c.Assert(baseError{}, ErrorIs, baseError{}) + err := &wrapperError{err: baseError{}} + c.Assert(err, ErrorIs, baseError{}) +} + +func (*errorIsCheckerSuite) TestErrorIsCheckFails(c *C) { + c.Assert(nil, Not(ErrorIs), baseError{}) + c.Assert(errors.New(""), Not(ErrorIs), baseError{}) +} + +func (*errorIsCheckerSuite) TestErrorIsWithInvalidArguments(c *C) { + res, errMsg := ErrorIs.Check([]interface{}{"", errors.New("")}, []string{"error", "target"}) + c.Assert(res, Equals, false) + c.Assert(errMsg, Equals, "first argument must be an error") + + res, errMsg = ErrorIs.Check([]interface{}{errors.New(""), ""}, []string{"error", "target"}) + c.Assert(res, Equals, false) + c.Assert(errMsg, Equals, "second argument must be an error") +}