diff --git a/internals/cli/cmd_run.go b/internals/cli/cmd_run.go
index 2c1f2083..62d513a5 100644
--- a/internals/cli/cmd_run.go
+++ b/internals/cli/cmd_run.go
@@ -182,7 +182,10 @@ func runDaemon(rcmd *cmdRun, ch chan os.Signal, ready chan<- func()) error {
}
d.Version = cmd.Version
- d.Start()
+ err = d.Start()
+ if err != nil {
+ return err
+ }
watchdog, err := runWatchdog(d)
if err != nil {
diff --git a/internals/daemon/api_changes_test.go b/internals/daemon/api_changes_test.go
index 0f5f6228..f52788ff 100644
--- a/internals/daemon/api_changes_test.go
+++ b/internals/daemon/api_changes_test.go
@@ -46,7 +46,7 @@ func setupChanges(st *state.State) []string {
}
func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -75,7 +75,7 @@ func (s *apiSuite) TestStateChangesDefaultToInProgress(c *check.C) {
}
func (s *apiSuite) TestStateChangesInProgress(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -104,7 +104,7 @@ func (s *apiSuite) TestStateChangesInProgress(c *check.C) {
}
func (s *apiSuite) TestStateChangesAll(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -133,7 +133,7 @@ func (s *apiSuite) TestStateChangesAll(c *check.C) {
}
func (s *apiSuite) TestStateChangesReady(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -161,7 +161,7 @@ func (s *apiSuite) TestStateChangesReady(c *check.C) {
}
func (s *apiSuite) TestStateChangesForServiceName(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -192,7 +192,7 @@ func (s *apiSuite) TestStateChangesForServiceName(c *check.C) {
}
func (s *apiSuite) TestStateChange(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
@@ -261,7 +261,7 @@ func (s *apiSuite) TestStateChange(c *check.C) {
}
func (s *apiSuite) TestStateChangeAbort(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
soon := 0
@@ -334,7 +334,7 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) {
}
func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) {
- restore := state.FakeTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
+ restore := state.MockTime(time.Date(2016, 04, 21, 1, 2, 3, 0, time.UTC))
defer restore()
// Setup
diff --git a/internals/daemon/api_files_test.go b/internals/daemon/api_files_test.go
index ab41554d..26157f2b 100644
--- a/internals/daemon/api_files_test.go
+++ b/internals/daemon/api_files_test.go
@@ -625,7 +625,7 @@ func (s *filesSuite) TestRemoveSingle(c *C) {
c.Check(r.Result, HasLen, 1)
checkFileResult(c, r.Result[0], tmpDir+"/file", "", "")
- c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false)
+ c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false)
}
func (s *filesSuite) TestRemoveMultiple(c *C) {
@@ -667,7 +667,7 @@ func (s *filesSuite) TestRemoveMultiple(c *C) {
checkFileResult(c, r.Result[2], tmpDir+"/non-empty", "generic-file-error", ".*directory not empty")
checkFileResult(c, r.Result[3], tmpDir+"/recursive", "", "")
- c.Check(osutil.CanStat(tmpDir+"/file"), Equals, false)
+ c.Check(osutil.FileExists(tmpDir+"/file"), Equals, false)
c.Check(osutil.IsDir(tmpDir+"/empty"), Equals, false)
c.Check(osutil.IsDir(tmpDir+"/non-empty"), Equals, true)
c.Check(osutil.IsDir(tmpDir+"/recursive"), Equals, false)
@@ -1186,10 +1186,10 @@ group not found
checkFileResult(c, r.Result[4], pathUserNotFound, "generic-file-error", ".*unknown user.*")
checkFileResult(c, r.Result[5], pathGroupNotFound, "generic-file-error", ".*unknown group.*")
- c.Check(osutil.CanStat(pathNoContent), Equals, false)
- c.Check(osutil.CanStat(pathNotAbsolute), Equals, false)
- c.Check(osutil.CanStat(pathNotFound), Equals, false)
- c.Check(osutil.CanStat(pathPermissionDenied), Equals, false)
+ c.Check(osutil.FileExists(pathNoContent), Equals, false)
+ c.Check(osutil.FileExists(pathNotAbsolute), Equals, false)
+ c.Check(osutil.FileExists(pathNotFound), Equals, false)
+ c.Check(osutil.FileExists(pathPermissionDenied), Equals, false)
}
func assertFile(c *C, path string, perm os.FileMode, content string) {
diff --git a/internals/daemon/api_test.go b/internals/daemon/api_test.go
index ed7ebd7f..38735743 100644
--- a/internals/daemon/api_test.go
+++ b/internals/daemon/api_test.go
@@ -58,6 +58,7 @@ func (s *apiSuite) daemon(c *check.C) *Daemon {
}
d, err := New(&Options{Dir: s.pebbleDir})
c.Assert(err, check.IsNil)
+ c.Assert(d.Overlord().StartUp(), check.IsNil)
d.addRoutes()
s.d = d
return d
diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go
index b7405ee0..dc46030f 100644
--- a/internals/daemon/daemon.go
+++ b/internals/daemon/daemon.go
@@ -385,6 +385,10 @@ func (d *Daemon) Init() error {
return nil
}
+func (d *Daemon) Overlord() *overlord.Overlord {
+ return d.overlord
+}
+
// SetDegradedMode puts the daemon into a degraded mode which will the
// error given in the "err" argument for commands that are not marked
// as readonlyOK.
@@ -452,13 +456,13 @@ func (d *Daemon) initStandbyHandling() {
d.standbyOpinions.Start()
}
-func (d *Daemon) Start() {
+func (d *Daemon) Start() error {
if d.rebootIsMissing {
// we need to schedule and wait for a system restart
d.tomb.Kill(nil)
// avoid systemd killing us again while we wait
systemdSdNotify("READY=1")
- return
+ return nil
}
if d.overlord == nil {
panic("internal error: no Overlord")
@@ -466,6 +470,10 @@ func (d *Daemon) Start() {
d.StartTime = time.Now()
+ // now perform expensive overlord/manages initialization
+ if err := d.overlord.StartUp(); err != nil {
+ return err
+ }
d.connTracker = &connTracker{conns: make(map[net.Conn]struct{})}
d.serve = &http.Server{
Handler: logit(d.router),
@@ -505,6 +513,7 @@ func (d *Daemon) Start() {
// notify systemd that we are ready
systemdSdNotify("READY=1")
+ return nil
}
// HandleRestart implements overlord.RestartBehavior.
@@ -673,7 +682,7 @@ func (d *Daemon) rebootDelay() (time.Duration, error) {
// see whether a reboot had already been scheduled
var rebootAt time.Time
err := d.state.Get("daemon-system-restart-at", &rebootAt)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return 0, err
}
rebootDelay := 1 * time.Minute
@@ -758,7 +767,7 @@ var errExpectedReboot = errors.New("expected reboot did not happen")
func (d *Daemon) RebootIsMissing(st *state.State) error {
var nTentative int
err := st.Get("daemon-system-restart-tentative", &nTentative)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
nTentative++
diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go
index c10d04d9..5e6c8f1e 100644
--- a/internals/daemon/daemon_test.go
+++ b/internals/daemon/daemon_test.go
@@ -103,7 +103,8 @@ func (s *daemonSuite) TestExplicitPaths(c *C) {
d := s.newDaemon(c)
d.Init()
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
defer d.Stop(nil)
info, err := os.Stat(s.socketPath)
@@ -459,7 +460,8 @@ func (s *daemonSuite) TestStartStop(c *check.C) {
untrustedAccept := make(chan struct{})
d.untrustedListener = &witnessAcceptListener{Listener: l2, accept: untrustedAccept}
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
generalDone := make(chan struct{})
go func() {
@@ -500,7 +502,8 @@ func (s *daemonSuite) TestRestartWiring(c *check.C) {
untrustedAccept := make(chan struct{})
d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept}
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
defer d.Stop(nil)
generalDone := make(chan struct{})
@@ -567,7 +570,8 @@ func (s *daemonSuite) TestGracefulStop(c *check.C) {
untrustedAccept := make(chan struct{})
d.untrustedListener = &witnessAcceptListener{Listener: untrustedL, accept: untrustedAccept}
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
generalAccepting := make(chan struct{})
go func() {
@@ -634,7 +638,8 @@ func (s *daemonSuite) TestRestartSystemWiring(c *check.C) {
untrustedAccept := make(chan struct{})
d.untrustedListener = &witnessAcceptListener{Listener: l, accept: untrustedAccept}
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
defer d.Stop(nil)
st := d.overlord.State()
@@ -781,7 +786,8 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
d := s.newDaemon(c)
makeDaemonListeners(c, d)
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
st := d.overlord.State()
st.Lock()
@@ -791,7 +797,7 @@ func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
ch := make(chan os.Signal, 2)
ch <- syscall.SIGTERM
// stop will check if we got a sigterm in between (which we did)
- err := d.Stop(ch)
+ err = d.Stop(ch)
c.Assert(err, check.IsNil)
}
@@ -813,7 +819,8 @@ func (s *daemonSuite) TestRestartShutdown(c *check.C) {
d := s.newDaemon(c)
makeDaemonListeners(c, d)
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
st := d.overlord.State()
st.Lock()
@@ -860,7 +867,8 @@ func (s *daemonSuite) TestRestartExpectedRebootIsMissing(c *check.C) {
c.Check(err, check.IsNil)
c.Check(n, check.Equals, 1)
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
c.Check(s.notified, check.DeepEquals, []string{"READY=1"})
@@ -896,8 +904,8 @@ func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) {
defer st.Unlock()
var v interface{}
// these were cleared
- c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
- c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState)
}
func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) {
@@ -920,9 +928,9 @@ func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) {
defer st.Unlock()
var v interface{}
// these were cleared
- c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
- c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
- c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-at", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &v), testutil.ErrorIs, state.ErrNoState)
+ c.Check(st.Get("daemon-system-restart-tentative", &v), testutil.ErrorIs, state.ErrNoState)
}
func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
@@ -936,7 +944,8 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
d := s.newDaemon(c)
makeDaemonListeners(c, d)
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
// pretend some ensure happened
for i := 0; i < 5; i++ {
@@ -955,7 +964,7 @@ func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
case <-time.After(15 * time.Second):
c.Errorf("daemon did not stop after 15s")
}
- err := d.Stop(nil)
+ err = d.Stop(nil)
c.Check(err, check.Equals, ErrRestartSocket)
c.Check(d.restartSocket, check.Equals, true)
}
@@ -972,7 +981,8 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) {
st := d.overlord.State()
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
// pretend some ensure happened
for i := 0; i < 5; i++ {
d.overlord.StateEngine().Ensure()
@@ -998,7 +1008,7 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) {
c.Errorf("daemon did not stop after 5s")
}
// when the daemon got a pending change it just restarts
- err := d.Stop(nil)
+ err = d.Stop(nil)
c.Check(err, check.IsNil)
c.Check(d.restartSocket, check.Equals, false)
}
@@ -1058,7 +1068,8 @@ func (s *daemonSuite) TestHTTPAPI(c *check.C) {
s.httpAddress = ":0" // Go will choose port (use listener.Addr() to find it)
d := s.newDaemon(c)
d.Init()
- d.Start()
+ err := d.Start()
+ c.Assert(err, check.IsNil)
port := d.httpListener.Addr().(*net.TCPAddr).Port
request, err := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d/v1/health", port), nil)
@@ -1101,7 +1112,8 @@ services:
d := s.newDaemon(c)
err := d.Init()
c.Assert(err, IsNil)
- d.Start()
+ err = d.Start()
+ c.Assert(err, check.IsNil)
// Start the test service.
payload := bytes.NewBufferString(`{"action": "start", "services": ["test1"]}`)
diff --git a/internals/jsonutil/json.go b/internals/jsonutil/json.go
new file mode 100644
index 00000000..056ab0a4
--- /dev/null
+++ b/internals/jsonutil/json.go
@@ -0,0 +1,66 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2017 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package jsonutil
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+
+ "github.com/canonical/x-go/strutil"
+)
+
+// DecodeWithNumber decodes input data using json.Decoder, ensuring numbers are preserved
+// via json.Number data type. It errors out on invalid json or any excess input.
+func DecodeWithNumber(r io.Reader, value interface{}) error {
+ dec := json.NewDecoder(r)
+ dec.UseNumber()
+ if err := dec.Decode(&value); err != nil {
+ return err
+ }
+ if dec.More() {
+ return fmt.Errorf("cannot parse json value")
+ }
+ return nil
+}
+
+// StructFields takes a pointer to a struct and a list of exceptions,
+// and returns a list of the fields in the struct that are JSON-tagged
+// and whose tag is not in the list of exceptions.
+// The struct can be nil.
+func StructFields(s interface{}, exceptions ...string) []string {
+ st := reflect.TypeOf(s).Elem()
+ num := st.NumField()
+ fields := make([]string, 0, num)
+ for i := 0; i < num; i++ {
+ tag := st.Field(i).Tag.Get("json")
+ idx := strings.IndexByte(tag, ',')
+ if idx > -1 {
+ tag = tag[:idx]
+ }
+ if tag != "" && !strutil.ListContains(exceptions, tag) {
+ fields = append(fields, tag)
+ }
+ }
+
+ return fields
+}
diff --git a/internals/jsonutil/json_test.go b/internals/jsonutil/json_test.go
new file mode 100644
index 00000000..2fb35c0b
--- /dev/null
+++ b/internals/jsonutil/json_test.go
@@ -0,0 +1,90 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2017 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package jsonutil_test
+
+import (
+ "encoding/json"
+ "strings"
+ "testing"
+
+ . "gopkg.in/check.v1"
+
+ "github.com/canonical/pebble/internals/jsonutil"
+)
+
+func Test(t *testing.T) { TestingT(t) }
+
+type utilSuite struct{}
+
+var _ = Suite(&utilSuite{})
+
+func (s *utilSuite) TestDecodeError(c *C) {
+ input := "{]"
+ var output interface{}
+ err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output)
+ c.Assert(err, NotNil)
+ c.Assert(err, ErrorMatches, `invalid character ']' looking for beginning of object key string`)
+}
+
+func (s *utilSuite) TestDecodeErrorOnExcessData(c *C) {
+ input := "1000000000[1,2]"
+ var output interface{}
+ err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output)
+ c.Assert(err, NotNil)
+ c.Assert(err, ErrorMatches, `cannot parse json value`)
+}
+
+func (s *utilSuite) TestDecodeSuccess(c *C) {
+ input := `{"a":1000000000, "b": 1.2, "c": "foo", "d":null}`
+ var output interface{}
+ err := jsonutil.DecodeWithNumber(strings.NewReader(input), &output)
+ c.Assert(err, IsNil)
+ c.Assert(output, DeepEquals, map[string]interface{}{
+ "a": json.Number("1000000000"),
+ "b": json.Number("1.2"),
+ "c": "foo",
+ "d": nil,
+ })
+}
+
+func (utilSuite) TestStructFields(c *C) {
+ type aStruct struct {
+ Foo int `json:"hello"`
+ Bar int `json:"potato,stuff"`
+ }
+ c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello", "potato"})
+}
+
+func (utilSuite) TestStructFieldsExcept(c *C) {
+ type aStruct struct {
+ Foo int `json:"hello"`
+ Bar int `json:"potato,stuff"`
+ }
+ c.Assert(jsonutil.StructFields((*aStruct)(nil), "potato"), DeepEquals, []string{"hello"})
+ c.Assert(jsonutil.StructFields((*aStruct)(nil), "hello"), DeepEquals, []string{"potato"})
+}
+
+func (utilSuite) TestStructFieldsSurvivesNoTag(c *C) {
+ type aStruct struct {
+ Foo int `json:"hello"`
+ Bar int
+ }
+ c.Assert(jsonutil.StructFields((*aStruct)(nil)), DeepEquals, []string{"hello"})
+}
diff --git a/internals/jsonutil/safejson/safejson.go b/internals/jsonutil/safejson/safejson.go
new file mode 100644
index 00000000..298e6945
--- /dev/null
+++ b/internals/jsonutil/safejson/safejson.go
@@ -0,0 +1,202 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package safejson
+
+import (
+ "fmt"
+ "strconv"
+ "unicode"
+ "unicode/utf16"
+ "unicode/utf8"
+
+ "github.com/canonical/x-go/strutil"
+)
+
+// String accepts any valid JSON string. Its Clean method will remove
+// characters that aren't expected in a short descriptive text.
+// I.e.: Cc, Co, Cf, Cs, noncharacters, and � (U+FFFD, the replacement
+// character) are removed.
+type String struct {
+ s string
+}
+
+func (str *String) UnmarshalJSON(in []byte) (err error) {
+ str.s, err = unmarshal(in, uOpt{})
+ return
+}
+
+// Clean returns the string, with Cc, Co, Cf, Cs, noncharacters,
+// and � (U+FFFD) removed.
+func (str String) Clean() string {
+ return str.s
+}
+
+// Paragraph accepts any valid JSON string. Its Clean method will remove
+// characters that aren't expected in a long descriptive text.
+// I.e.: Cc (except for \n), Co, Cf, Cs, noncharacters, and � (U+FFFD,
+// the replacement character) are removed.
+type Paragraph struct {
+ s string
+}
+
+func (par *Paragraph) UnmarshalJSON(in []byte) (err error) {
+ par.s, err = unmarshal(in, uOpt{nlOK: true})
+ return
+}
+
+// Clean returns the string, with Cc minus \n, Co, Cf, Cs, noncharacters,
+// and � (U+FFFD) removed.
+func (par Paragraph) Clean() string {
+ return par.s
+}
+
+func unescapeUCS2(in []byte) (rune, bool) {
+ if len(in) < 6 || in[0] != '\\' || in[1] != 'u' {
+ return -1, false
+ }
+ u, err := strconv.ParseUint(string(in[2:6]), 16, 32)
+ if err != nil {
+ return -1, false
+ }
+ return rune(u), true
+}
+
+type uOpt struct {
+ nlOK bool
+ simple bool
+}
+
+func unmarshal(in []byte, o uOpt) (string, error) {
+ // heavily based on (inspired by?) unquoteBytes from encoding/json
+
+ if len(in) < 2 || in[0] != '"' || in[len(in)-1] != '"' {
+ // maybe it's a null and that's alright
+ if len(in) == 4 && in[0] == 'n' && in[1] == 'u' && in[2] == 'l' && in[3] == 'l' {
+ return "", nil
+ }
+ return "", fmt.Errorf("missing string delimiters: %q", in)
+ }
+
+ // prune the quotes
+ in = in[1 : len(in)-1]
+ i := 0
+ // try the fast track
+ for i < len(in) {
+ // 0x00..0x19 is the first of Cc
+ // 0x20..0x7e is all of printable ASCII (minus control chars)
+ if in[i] < 0x20 || in[i] > 0x7e || in[i] == '\\' || in[i] == '"' {
+ break
+ }
+ i++
+ }
+ if i == len(in) {
+ // wee
+ return string(in), nil
+ }
+ if o.simple {
+ return "", fmt.Errorf("character %q in string %q unsupported for this value", in[i], in)
+ }
+ // in[i] is the first problematic one
+ out := make([]byte, i, len(in)+2*utf8.UTFMax)
+ copy(out, in)
+ var r, r2 rune
+ var n int
+ var c byte
+ var ubuf [utf8.UTFMax]byte
+ var ok bool
+ for i < len(in) {
+ c = in[i]
+ switch {
+ case c == '"':
+ return "", fmt.Errorf("unexpected unescaped quote at %d in \"%s\"", i, in)
+ case c < 0x20:
+ return "", fmt.Errorf("unexpected control character at %d in %q", i, in)
+ case c == '\\':
+ // handle escapes
+ i++
+ if i == len(in) {
+ return "", fmt.Errorf("unexpected end of string (trailing backslash) in \"%s\"", in)
+ }
+ switch in[i] {
+ case 'u':
+ // oh dear, a unicode wotsit
+ r, ok = unescapeUCS2(in[i-1:])
+ if !ok {
+ x := in[i-1:]
+ if len(x) > 6 {
+ x = x[:6]
+ }
+ return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in)
+ }
+ i += 5
+ if utf16.IsSurrogate(r) {
+ // sigh
+ r2, ok = unescapeUCS2(in[i:])
+ if !ok {
+ x := in[i:]
+ if len(x) > 6 {
+ x = x[:6]
+ }
+ return "", fmt.Errorf(`badly formed \u escape %q at %d of "%s"`, x, i, in)
+ }
+ i += 6
+ r = utf16.DecodeRune(r, r2)
+ }
+ if r <= 0x9f {
+ // otherwise, it's Cc (both halves, as we're looking at runes)
+ if (o.nlOK && r == '\n') || (r >= 0x20 && r <= 0x7e) {
+ out = append(out, byte(r))
+ }
+ } else if r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) {
+ n = utf8.EncodeRune(ubuf[:], r)
+ out = append(out, ubuf[:n]...)
+ }
+ case 'b', 'f', 'r', 't':
+ // do nothing
+ i++
+ case 'n':
+ if o.nlOK {
+ out = append(out, '\n')
+ }
+ i++
+ case '"', '/', '\\':
+ // the spec says just ", / and \ can be backslash-escaped
+ // but go adds ' to the list (in unquoteBytes)
+ out = append(out, in[i])
+ i++
+ default:
+ return "", fmt.Errorf(`unknown escape '%c' at %d of "%s"`, in[i], i, in)
+ }
+ case c <= 0x7e:
+ // printable ASCII, except " or \
+ out = append(out, c)
+ i++
+ default:
+ r, n = utf8.DecodeRune(in[i:])
+ j := i + n
+ if r > 0x9f && r != unicode.ReplacementChar && !unicode.Is(strutil.Ctrl, r) {
+ out = append(out, in[i:j]...)
+ }
+ i = j
+ }
+ }
+
+ return string(out), nil
+}
diff --git a/internals/jsonutil/safejson/safejson_test.go b/internals/jsonutil/safejson/safejson_test.go
new file mode 100644
index 00000000..aae12aea
--- /dev/null
+++ b/internals/jsonutil/safejson/safejson_test.go
@@ -0,0 +1,148 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package safejson_test
+
+import (
+ "encoding/json"
+ "testing"
+
+ "gopkg.in/check.v1"
+
+ "github.com/canonical/pebble/internals/jsonutil/safejson"
+)
+
+func Test(t *testing.T) { check.TestingT(t) }
+
+type escapeSuite struct{}
+
+var _ = check.Suite(escapeSuite{})
+
+var table = map[string]string{
+ "null": "",
+ `"hello"`: "hello",
+ `"árbol"`: "árbol",
+ `"\u0020"`: " ",
+ `"\uD83D\uDE00"`: "😀",
+ `"a\b\r\tb"`: "ab",
+ `"\\\""`: `\"`,
+ // escape sequences (NOTE just the control char is stripped)
+ `"\u001b[3mhello\u001b[m"`: "[3mhello[m",
+ `"a\u0080z"`: "az",
+ "\"a\u0080z\"": "az",
+ "\"a\u007fz\"": "az",
+ "\"a\u009fz\"": "az",
+ // replacement char
+ `"a\uFFFDb"`: "ab",
+ // private unicode chars
+ `"a\uE000b"`: "ab",
+ `"a\uDB80\uDC00b"`: "ab",
+}
+
+func (escapeSuite) TestStrings(c *check.C) {
+ var u safejson.String
+ for j, s := range table {
+ comm := check.Commentf(j)
+ c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm)
+ c.Check(u.Clean(), check.Equals, s, comm)
+
+ c.Assert(u.UnmarshalJSON([]byte(j)), check.IsNil, comm)
+ c.Check(u.Clean(), check.Equals, s, comm)
+ }
+}
+
+func (escapeSuite) TestBadStrings(c *check.C) {
+ var u1 safejson.String
+
+ cc0 := make([][]byte, 0x20)
+ for i := range cc0 {
+ cc0[i] = []byte{'"', byte(i), '"'}
+ }
+ badesc := make([][]byte, 0, 0x7f-0x21-9)
+ for c := byte('!'); c <= '~'; c++ {
+ switch c {
+ case '"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u':
+ continue
+ default:
+ badesc = append(badesc, []byte{'"', '\\', c, '"'})
+ }
+ }
+
+ table := map[string][][]byte{
+ // these are from json itself (so we're not checking them):
+ "invalid character '.+' in string literal": cc0,
+ "invalid character '.+' in string escape code": badesc,
+ `invalid character '.+' in \\u .*`: {[]byte(`"\u02"`), []byte(`"\u02zz"`)},
+ "invalid character '\"' after top-level value": {[]byte(`"""`)},
+ "unexpected end of JSON input": {[]byte(`"\"`)},
+ }
+
+ for e, js := range table {
+ for _, j := range js {
+ comm := check.Commentf("%q", j)
+ c.Check(json.Unmarshal(j, &u1), check.ErrorMatches, e, comm)
+ }
+ }
+
+ table = map[string][][]byte{
+ // these are from our lib
+ `missing string delimiters.*`: {{}, {'"'}},
+ `unexpected control character at 0 in "\\.+"`: cc0,
+ `unknown escape '.' at 1 of "\\."`: badesc,
+ `badly formed \\u escape.*`: {
+ []byte(`"\u02"`), []byte(`"\u02zz"`), []byte(`"a\u02xxz"`),
+ []byte(`"\uD83Da"`), []byte(`"\uD83Da\u20"`), []byte(`"\uD83Da\u20zzz"`),
+ },
+ `unexpected unescaped quote at 0 in """`: {[]byte(`"""`)},
+ `unexpected end of string \(trailing backslash\).*`: {[]byte(`"\"`)},
+ }
+
+ for e, js := range table {
+ for _, j := range js {
+ comm := check.Commentf("%q", j)
+ c.Check(u1.UnmarshalJSON(j), check.ErrorMatches, e, comm)
+ }
+ }
+}
+
+func (escapeSuite) TestParagraph(c *check.C) {
+ var u safejson.Paragraph
+ for j1, v1 := range table {
+ for j2, v2 := range table {
+ if j2 == "null" && j1 != "null" {
+ continue
+ }
+
+ var j, s string
+
+ if j1 == "null" {
+ j = j2
+ s = v2
+ } else {
+ j = j1[:len(j1)-1] + "\\n" + j2[1:]
+ s = v1 + "\n" + v2
+ }
+
+ comm := check.Commentf(j)
+ c.Assert(json.Unmarshal([]byte(j), &u), check.IsNil, comm)
+ c.Check(u.Clean(), check.Equals, s, comm)
+ }
+ }
+
+}
diff --git a/internals/osutil/io_test.go b/internals/osutil/io_test.go
index a57d88a1..61c6e69c 100644
--- a/internals/osutil/io_test.go
+++ b/internals/osutil/io_test.go
@@ -245,9 +245,9 @@ func (ts *AtomicWriteTestSuite) TestAtomicFileCancel(c *C) {
aw, err := osutil.NewAtomicFile(p, 0644, 0, osutil.NoChown, osutil.NoChown)
c.Assert(err, IsNil)
fn := aw.File.Name()
- c.Check(osutil.CanStat(fn), Equals, true)
+ c.Check(osutil.FileExists(fn), Equals, true)
c.Check(aw.Cancel(), IsNil)
- c.Check(osutil.CanStat(fn), Equals, false)
+ c.Check(osutil.FileExists(fn), Equals, false)
}
// SafeIoAtomicWriteTestSuite runs all AtomicWrite with safe
diff --git a/internals/osutil/squashfs/fstype.go b/internals/osutil/squashfs/fstype.go
index 07a8b8d5..f03fd213 100644
--- a/internals/osutil/squashfs/fstype.go
+++ b/internals/osutil/squashfs/fstype.go
@@ -25,7 +25,7 @@ import (
var useFuse = useFuseImpl
func useFuseImpl() bool {
- if !osutil.CanStat("/dev/fuse") {
+ if !osutil.FileExists("/dev/fuse") {
return false
}
diff --git a/internals/osutil/stat.go b/internals/osutil/stat.go
index c806dcad..8387e4c9 100644
--- a/internals/osutil/stat.go
+++ b/internals/osutil/stat.go
@@ -25,9 +25,9 @@ import (
"syscall"
)
-// CanStat returns true if stat succeeds on the given path.
+// FileExists returns true if stat succeeds on the given path.
// It may return false on permission issues.
-func CanStat(path string) bool {
+func FileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
diff --git a/internals/osutil/stat_test.go b/internals/osutil/stat_test.go
index a1235ced..cdf11672 100644
--- a/internals/osutil/stat_test.go
+++ b/internals/osutil/stat_test.go
@@ -34,8 +34,8 @@ func (ts *StatTestSuite) TestCanStat(c *C) {
err := ioutil.WriteFile(fname, []byte(fname), 0644)
c.Assert(err, IsNil)
- c.Assert(CanStat(fname), Equals, true)
- c.Assert(CanStat("/i-do-not-exist"), Equals, false)
+ c.Assert(FileExists(fname), Equals, true)
+ c.Assert(FileExists("/i-do-not-exist"), Equals, false)
}
func (ts *StatTestSuite) TestCanStatOddPerms(c *C) {
@@ -43,7 +43,7 @@ func (ts *StatTestSuite) TestCanStatOddPerms(c *C) {
err := ioutil.WriteFile(fname, []byte(fname), 0100)
c.Assert(err, IsNil)
- c.Assert(CanStat(fname), Equals, true)
+ c.Assert(FileExists(fname), Equals, true)
}
func (ts *StatTestSuite) TestIsDir(c *C) {
diff --git a/internals/overlord/export_test.go b/internals/overlord/export_test.go
index 22529702..8fe8f187 100644
--- a/internals/overlord/export_test.go
+++ b/internals/overlord/export_test.go
@@ -40,6 +40,14 @@ func FakePruneInterval(prunei, prunew, abortw time.Duration) (restore func()) {
}
}
+func FakePruneTicker(f func(t *time.Ticker) <-chan time.Time) (restore func()) {
+ old := pruneTickerC
+ pruneTickerC = f
+ return func() {
+ pruneTickerC = old
+ }
+}
+
// FakeEnsureNext sets o.ensureNext for tests.
func FakeEnsureNext(o *Overlord, t time.Time) {
o.ensureNext = t
diff --git a/internals/overlord/overlord.go b/internals/overlord/overlord.go
index 771a670d..bba805be 100644
--- a/internals/overlord/overlord.go
+++ b/internals/overlord/overlord.go
@@ -16,6 +16,7 @@
package overlord
import (
+ "errors"
"fmt"
"io"
"os"
@@ -49,6 +50,10 @@ var (
defaultCachedDownloads = 5
)
+var pruneTickerC = func(t *time.Ticker) <-chan time.Time {
+ return t.C
+}
+
// Overlord is the central manager of the system, keeping track
// of all available state managers and related helpers.
type Overlord struct {
@@ -63,8 +68,11 @@ type Overlord struct {
ensureRun int32
pruneTicker *time.Ticker
+ startOfOperationTime time.Time
+
// managers
inited bool
+ startedUp bool
runner *state.TaskRunner
serviceMgr *servstate.ServiceManager
commandMgr *cmdstate.CommandManager
@@ -136,6 +144,47 @@ func New(pebbleDir string, restartHandler restart.Handler, serviceOutput io.Writ
return o, nil
}
+var timeNow = time.Now
+
+// StartOfOperationTime returns the time when pebble started operating,
+// and sets it in the state when called for the first time.
+// The StartOfOperationTime time is seed-time if available,
+// or current time otherwise.
+func (m *Overlord) StartOfOperationTime() (time.Time, error) {
+ var opTime time.Time
+ err := m.State().Get("start-of-operation-time", &opTime)
+ if err == nil {
+ return opTime, nil
+ }
+ if err != nil && !errors.Is(err, state.ErrNoState) {
+ return opTime, err
+ }
+
+ opTime = timeNow()
+ m.State().Set("start-of-operation-time", opTime)
+ return opTime, nil
+}
+
+// StartUp proceeds to run any expensive Overlord or managers initialization.
+// After this is done once it is a noop.
+func (o *Overlord) StartUp() error {
+ if o.startedUp {
+ return nil
+ }
+ o.startedUp = true
+
+ var err error
+ st := o.State()
+ st.Lock()
+ o.startOfOperationTime, err = o.StartOfOperationTime()
+ st.Unlock()
+ if err != nil {
+ return fmt.Errorf("cannot get start of operation time: %s", err)
+ }
+
+ return o.stateEng.StartUp()
+}
+
func (o *Overlord) addManager(mgr StateManager) {
o.stateEng.AddManager(mgr)
}
@@ -159,7 +208,7 @@ func loadState(statePath string, restartHandler restart.Handler, backend state.B
}
}
- if !osutil.CanStat(statePath) {
+ if !osutil.FileExists(statePath) {
// fail fast, mostly interesting for tests, this dir is set up by pebble
stateDir := filepath.Dir(statePath)
if !osutil.IsDir(stateDir) {
@@ -255,6 +304,9 @@ func (o *Overlord) ensureBefore(d time.Duration) {
// Loop runs a loop in a goroutine to ensure the current state regularly through StateEngine Ensure.
func (o *Overlord) Loop() {
o.ensureTimerSetup()
+ if o.loopTomb == nil {
+ o.loopTomb = new(tomb.Tomb)
+ }
o.loopTomb.Go(func() error {
for {
// TODO: pass a proper context into Ensure
@@ -263,14 +315,15 @@ func (o *Overlord) Loop() {
// continue to the next Ensure() try for now
o.stateEng.Ensure()
o.ensureDidRun()
+ pruneC := pruneTickerC(o.pruneTicker)
select {
case <-o.loopTomb.Dying():
return nil
case <-o.ensureTimer.C:
- case <-o.pruneTicker.C:
+ case <-pruneC:
st := o.State()
st.Lock()
- st.Prune(pruneWait, abortWait, pruneMaxChanges)
+ st.Prune(o.startOfOperationTime, pruneWait, abortWait, pruneMaxChanges)
st.Unlock()
}
}
@@ -295,6 +348,10 @@ func (o *Overlord) Stop() error {
}
func (o *Overlord) settle(timeout time.Duration, beforeCleanups func()) error {
+ if err := o.StartUp(); err != nil {
+ return err
+ }
+
func() {
o.ensureLock.Lock()
defer o.ensureLock.Unlock()
diff --git a/internals/overlord/overlord_test.go b/internals/overlord/overlord_test.go
index c6b9cc7e..33696609 100644
--- a/internals/overlord/overlord_test.go
+++ b/internals/overlord/overlord_test.go
@@ -46,6 +46,26 @@ type overlordSuite struct {
var _ = Suite(&overlordSuite{})
+type ticker struct {
+ tickerChannel chan time.Time
+}
+
+func (w *ticker) tick(n int) {
+ for i := 0; i < n; i++ {
+ w.tickerChannel <- time.Now()
+ }
+}
+
+func fakePruneTicker() (w *ticker, restore func()) {
+ w = &ticker{
+ tickerChannel: make(chan time.Time),
+ }
+ restore = overlord.FakePruneTicker(func(t *time.Ticker) <-chan time.Time {
+ return w.tickerChannel
+ })
+ return w, restore
+}
+
func (ovs *overlordSuite) SetUpTest(c *C) {
ovs.dir = c.MkDir()
ovs.statePath = filepath.Join(ovs.dir, ".pebble.state")
@@ -161,6 +181,12 @@ type witnessManager struct {
expectedEnsure int
ensureCalled chan struct{}
ensureCallback func(s *state.State) error
+ startedUp int
+}
+
+func (wm *witnessManager) StartUp() error {
+ wm.startedUp++
+ return nil
}
func (wm *witnessManager) Ensure() error {
@@ -178,6 +204,9 @@ func (ovs *overlordSuite) TestTrivialRunAndStop(c *C) {
o, err := overlord.New(ovs.dir, nil, nil)
c.Assert(err, IsNil)
+ err = o.StartUp()
+ c.Assert(err, IsNil)
+
o.Loop()
err = o.Stop()
@@ -216,6 +245,10 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) {
}
o.AddManager(witness)
+ err := o.StartUp()
+
+ c.Assert(err, IsNil)
+
o.Loop()
defer o.Stop()
@@ -227,8 +260,9 @@ func (ovs *overlordSuite) TestEnsureLoopRunAndStop(c *C) {
}
c.Check(time.Since(t0) >= 10*time.Millisecond, Equals, true)
- err := o.Stop()
+ err = o.Stop()
c.Assert(err, IsNil)
+ c.Check(witness.startedUp, Equals, 1)
}
func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) {
@@ -248,7 +282,7 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeImmediate(c *C) {
ensureCallback: ensure,
}
o.AddManager(witness)
-
+ c.Assert(o.StartUp(), IsNil)
o.Loop()
defer o.Stop()
@@ -277,6 +311,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBefore(c *C) {
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
o.Loop()
defer o.Stop()
@@ -306,6 +342,8 @@ func (ovs *overlordSuite) TestEnsureBeforeSleepy(c *C) {
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
o.Loop()
defer o.Stop()
@@ -335,6 +373,8 @@ func (ovs *overlordSuite) TestEnsureBeforeLater(c *C) {
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
o.Loop()
defer o.Stop()
@@ -364,6 +404,8 @@ func (ovs *overlordSuite) TestEnsureLoopMediatedEnsureBeforeOutsideEnsure(c *C)
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
o.Loop()
defer o.Stop()
@@ -420,6 +462,8 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) {
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
o.Loop()
select {
@@ -441,7 +485,7 @@ func (ovs *overlordSuite) TestEnsureLoopPrune(c *C) {
}
func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) {
- restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour)
+ restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 5*time.Millisecond, 1*time.Hour)
defer restoreIntv()
o := overlord.Fake()
@@ -459,20 +503,26 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) {
c.Check(st.Changes(), HasLen, 2)
st.Unlock()
+ w, restoreTicker := fakePruneTicker()
+ defer restoreTicker()
+
// start the loop that runs the prune ticker
o.Loop()
- // ensure the first change is pruned
- time.Sleep(1500 * time.Millisecond)
- st.Lock()
- c.Check(st.Changes(), HasLen, 1)
- st.Unlock()
+ // this needs to be more than pruneWait=5ms mocked above
+ time.Sleep(10 * time.Millisecond)
+ w.tick(2)
- // ensure the second is also purged after it is ready
st.Lock()
+ c.Check(st.Changes(), HasLen, 1)
chg2.SetStatus(state.DoneStatus)
st.Unlock()
- time.Sleep(1500 * time.Millisecond)
+
+ // this needs to be more than pruneWait=5ms mocked above
+ time.Sleep(10 * time.Millisecond)
+ // tick twice for extra Ensure
+ w.tick(2)
+
st.Lock()
c.Check(st.Changes(), HasLen, 0)
st.Unlock()
@@ -482,6 +532,134 @@ func (ovs *overlordSuite) TestEnsureLoopPruneRunsMultipleTimes(c *C) {
c.Assert(err, IsNil)
}
+func (ovs *overlordSuite) TestOverlordStartUpSetsStartOfOperation(c *C) {
+ restoreIntv := overlord.FakePruneInterval(100*time.Millisecond, 1000*time.Millisecond, 1*time.Hour)
+ defer restoreIntv()
+
+ o, err := overlord.New(ovs.dir, nil, nil)
+ c.Assert(err, IsNil)
+
+ st := o.State()
+ st.Lock()
+ defer st.Unlock()
+
+ // validity check, not set
+ var opTime time.Time
+ c.Assert(st.Get("start-of-operation-time", &opTime), testutil.ErrorIs, state.ErrNoState)
+ st.Unlock()
+
+ c.Assert(o.StartUp(), IsNil)
+
+ st.Lock()
+ c.Assert(st.Get("start-of-operation-time", &opTime), IsNil)
+}
+
+func (ovs *overlordSuite) TestEnsureLoopPruneDoesntAbortShortlyAfterStartOfOperation(c *C) {
+ w, restoreTicker := fakePruneTicker()
+ defer restoreTicker()
+
+ o, err := overlord.New(ovs.dir, nil, nil)
+ c.Assert(err, IsNil)
+
+ // avoid immediate transition to Done due to unknown kind
+ o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error {
+ return &state.Retry{}
+ }, nil)
+
+ st := o.State()
+ st.Lock()
+
+ // start of operation time is 50min ago, this is less then abort limit
+ opTime := time.Now().Add(-50 * time.Minute)
+ st.Set("start-of-operation-time", opTime)
+
+ // spawn time one month ago
+ spawnTime := time.Now().AddDate(0, -1, 0)
+ restoreTimeNow := state.MockTime(spawnTime)
+
+ t := st.NewTask("bar", "...")
+ chg := st.NewChange("other-change", "...")
+ chg.AddTask(t)
+
+ restoreTimeNow()
+
+ // validity
+ c.Check(st.Changes(), HasLen, 1)
+
+ st.Unlock()
+ c.Assert(o.StartUp(), IsNil)
+
+ // start the loop that runs the prune ticker
+ o.Loop()
+ w.tick(2)
+
+ c.Assert(o.Stop(), IsNil)
+
+ st.Lock()
+ defer st.Unlock()
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+}
+
+func (ovs *overlordSuite) TestEnsureLoopPruneAbortsOld(c *C) {
+ // Ensure interval is not relevant for this test
+ restoreEnsureIntv := overlord.FakeEnsureInterval(10 * time.Hour)
+ defer restoreEnsureIntv()
+
+ w, restoreTicker := fakePruneTicker()
+ defer restoreTicker()
+
+ o, err := overlord.New(ovs.dir, nil, nil)
+ c.Assert(err, IsNil)
+
+ // avoid immediate transition to Done due to having unknown kind
+ o.TaskRunner().AddHandler("bar", func(t *state.Task, _ *tomb.Tomb) error {
+ return &state.Retry{}
+ }, nil)
+
+ st := o.State()
+ st.Lock()
+
+ // start of operation time is a year ago
+ opTime := time.Now().AddDate(-1, 0, 0)
+ st.Set("start-of-operation-time", opTime)
+
+ st.Unlock()
+ c.Assert(o.StartUp(), IsNil)
+ st.Lock()
+
+ // spawn time one month ago
+ spawnTime := time.Now().AddDate(0, -1, 0)
+ restoreTimeNow := state.MockTime(spawnTime)
+ t := st.NewTask("bar", "...")
+ chg := st.NewChange("other-change", "...")
+ chg.AddTask(t)
+
+ restoreTimeNow()
+
+ // validity
+ c.Check(st.Changes(), HasLen, 1)
+ st.Unlock()
+
+ // start the loop that runs the prune ticker
+ o.Loop()
+ w.tick(2)
+
+ c.Assert(o.Stop(), IsNil)
+
+ st.Lock()
+ defer st.Unlock()
+
+ // validity
+ op, err := o.StartOfOperationTime()
+ c.Assert(err, IsNil)
+ c.Check(op.Equal(opTime), Equals, true)
+
+ c.Assert(st.Changes(), HasLen, 1)
+ // change was aborted
+ c.Check(chg.Status(), Equals, state.HoldStatus)
+}
+
func (ovs *overlordSuite) TestCheckpoint(c *C) {
oldUmask := syscall.Umask(0)
defer syscall.Umask(oldUmask)
@@ -856,6 +1034,8 @@ func (ovs *overlordSuite) TestOverlordCanStandby(c *C) {
}
o.AddManager(witness)
+ c.Assert(o.StartUp(), IsNil)
+
// can only standby after loop ran once
c.Assert(o.CanStandby(), Equals, false)
diff --git a/internals/overlord/patch/patch.go b/internals/overlord/patch/patch.go
index afc93186..58514a8a 100644
--- a/internals/overlord/patch/patch.go
+++ b/internals/overlord/patch/patch.go
@@ -20,6 +20,7 @@
package patch
import (
+ "errors"
"fmt"
"github.com/canonical/pebble/internals/logger"
@@ -43,12 +44,12 @@ var patches = make(map[int][]PatchFunc)
func Init(s *state.State) {
s.Lock()
defer s.Unlock()
- if s.Get("patch-level", new(int)) != state.ErrNoState {
+ if err := s.Get("patch-level", new(int)); !errors.Is(err, state.ErrNoState) {
panic("internal error: expected empty state, attempting to override patch-level without actual patching")
}
s.Set("patch-level", Level)
- if s.Get("patch-sublevel", new(int)) != state.ErrNoState {
+ if err := s.Get("patch-sublevel", new(int)); !errors.Is(err, state.ErrNoState) {
panic("internal error: expected empty state, attempting to override patch-sublevel without actual patching")
}
s.Set("patch-sublevel", Sublevel)
@@ -76,12 +77,12 @@ func Apply(s *state.State) error {
var stateLevel, stateSublevel int
s.Lock()
err := s.Get("patch-level", &stateLevel)
- if err == nil || err == state.ErrNoState {
+ if err == nil || errors.Is(err, state.ErrNoState) {
err = s.Get("patch-sublevel", &stateSublevel)
}
s.Unlock()
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
diff --git a/internals/overlord/restart/restart.go b/internals/overlord/restart/restart.go
index 0c1a8840..6c11f62f 100644
--- a/internals/overlord/restart/restart.go
+++ b/internals/overlord/restart/restart.go
@@ -16,6 +16,8 @@
package restart
import (
+ "errors"
+
"github.com/canonical/pebble/internals/overlord/state"
)
@@ -61,7 +63,7 @@ func Init(st *state.State, curBootID string, h Handler) error {
}
var fromBootID string
err := st.Get("system-restart-from-boot-id", &fromBootID)
- if err != nil && err != state.ErrNoState {
+ if err != nil && !errors.Is(err, state.ErrNoState) {
return err
}
st.Cache(restartStateKey{}, rs)
diff --git a/internals/overlord/restart/restart_test.go b/internals/overlord/restart/restart_test.go
index 8a8617ed..74fc0764 100644
--- a/internals/overlord/restart/restart_test.go
+++ b/internals/overlord/restart/restart_test.go
@@ -21,6 +21,7 @@ import (
"github.com/canonical/pebble/internals/overlord/restart"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
func TestRestart(t *testing.T) { TestingT(t) }
@@ -134,5 +135,5 @@ func (s *restartSuite) TestRequestRestartSystemAndVerifyReboot(c *C) {
err = restart.Init(st, "boot-id-2", h2)
c.Assert(err, IsNil)
c.Check(h2.rebootAsExpected, Equals, true)
- c.Check(st.Get("system-restart-from-boot-id", &fromBootID), Equals, state.ErrNoState)
+ c.Check(st.Get("system-restart-from-boot-id", &fromBootID), testutil.ErrorIs, state.ErrNoState)
}
diff --git a/internals/overlord/state/change.go b/internals/overlord/state/change.go
index bfb3b42d..22bfa123 100644
--- a/internals/overlord/state/change.go
+++ b/internals/overlord/state/change.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2023 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -23,8 +23,11 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "sort"
"strings"
"time"
+
+ "github.com/canonical/pebble/internals/logger"
)
// Status is used for status values for changes and tasks.
@@ -37,7 +40,8 @@ const (
// to an aggregation of its tasks' statuses. See Change.Status for details.
DefaultStatus Status = 0
- // HoldStatus means the task should not run, perhaps as a consequence of an error on another task.
+ // HoldStatus means the task should not run for the moment, perhaps as a
+ // consequence of an error on another task.
HoldStatus Status = 1
// DoStatus means the change or task is ready to start.
@@ -65,6 +69,12 @@ const (
// ErrorStatus means the change or task has errored out while running or being undone.
ErrorStatus Status = 9
+ // WaitStatus means the task was accomplished successfully but some
+ // external event needs to happen before work can progress further
+ // (e.g. on classic we require the user to reboot after a
+ // kernel snap update).
+ WaitStatus Status = 10
+
nStatuses = iota
)
@@ -88,6 +98,8 @@ func (s Status) String() string {
return "Doing"
case DoneStatus:
return "Done"
+ case WaitStatus:
+ return "Wait"
case AbortStatus:
return "Abort"
case UndoStatus:
@@ -104,6 +116,18 @@ func (s Status) String() string {
panic(fmt.Sprintf("internal error: unknown task status code: %d", s))
}
+// taskWaitComputeStatus is used while computing the wait status of a
+// change. It keeps track of whether a task is waiting or not waiting, or the
+// computation for it is still in-progress to detect cyclic dependencies.
+type taskWaitComputeStatus int
+
+const (
+ taskWaitStatusNotComputed taskWaitComputeStatus = iota
+ taskWaitStatusComputing
+ taskWaitStatusNotWaiting
+ taskWaitStatusWaiting
+)
+
// Change represents a tracked modification to the system state.
//
// The Change provides both the justification for individual tasks
@@ -115,16 +139,16 @@ func (s Status) String() string {
// while the individual Task values would track the running of
// the hooks themselves.
type Change struct {
- state *State
- id string
- kind string
- summary string
- status Status
- clean bool
- data customData
- taskIDs []string
- lanes int
- ready chan struct{}
+ state *State
+ id string
+ kind string
+ summary string
+ status Status
+ clean bool
+ data customData
+ taskIDs []string
+ ready chan struct{}
+ lastObservedStatus Status
spawnTime time.Time
readyTime time.Time
@@ -157,7 +181,6 @@ type marshalledChange struct {
Clean bool `json:"clean,omitempty"`
Data map[string]*json.RawMessage `json:"data,omitempty"`
TaskIDs []string `json:"task-ids,omitempty"`
- Lanes int `json:"lanes,omitempty"`
SpawnTime time.Time `json:"spawn-time"`
ReadyTime *time.Time `json:"ready-time,omitempty"`
@@ -178,7 +201,6 @@ func (c *Change) MarshalJSON() ([]byte, error) {
Clean: c.clean,
Data: c.data,
TaskIDs: c.taskIDs,
- Lanes: c.lanes,
SpawnTime: c.spawnTime,
ReadyTime: readyTime,
@@ -206,7 +228,6 @@ func (c *Change) UnmarshalJSON(data []byte) error {
}
c.data = custData
c.taskIDs = unmarshalled.TaskIDs
- c.lanes = unmarshalled.Lanes
c.ready = make(chan struct{})
c.spawnTime = unmarshalled.SpawnTime
if unmarshalled.ReadyTime != nil {
@@ -251,12 +272,19 @@ func (c *Change) Get(key string, value interface{}) error {
return c.data.get(key, value)
}
+// Has returns whether the provided key has an associated value.
+func (c *Change) Has(key string) bool {
+ c.state.reading()
+ return c.data.has(key)
+}
+
var statusOrder = []Status{
AbortStatus,
UndoingStatus,
UndoStatus,
DoingStatus,
DoStatus,
+ WaitStatus,
ErrorStatus,
UndoneStatus,
DoneStatus,
@@ -269,32 +297,138 @@ func init() {
}
}
+func (c *Change) isTaskWaiting(visited map[string]taskWaitComputeStatus, t *Task, deps []*Task) bool {
+ taskID := t.ID()
+ // Retrieve the compute status of the wait for the task, if not
+ // computed this defaults to 0 (taskWaitStatusNotComputed).
+ computeStatus := visited[taskID]
+ switch computeStatus {
+ case taskWaitStatusComputing:
+ // Cyclic dependency detected, return false to short-circuit.
+ logger.Noticef("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind())
+ // Make sure errors show up in "snap change " too
+ t.Logf("detected cyclic dependencies for task %q in change %q", t.Kind(), t.Change().Kind())
+ return false
+ case taskWaitStatusWaiting, taskWaitStatusNotWaiting:
+ return computeStatus == taskWaitStatusWaiting
+ }
+ visited[taskID] = taskWaitStatusComputing
+
+ var isWaiting bool
+depscheck:
+ for _, wt := range deps {
+ switch wt.Status() {
+ case WaitStatus:
+ isWaiting = true
+ // States that can be valid when waiting
+ // - Done, Undone, ErrorStatus, HoldStatus
+ case DoneStatus, UndoneStatus, ErrorStatus, HoldStatus:
+ continue
+ // For 'Do' and 'Undo' we have to check whether the task is waiting
+ // for any dependencies. The logic is the same, but the set of tasks
+ // varies.
+ case DoStatus:
+ isWaiting = c.isTaskWaiting(visited, wt, wt.WaitTasks())
+ if !isWaiting {
+ // Cancel early if we detect something is runnable.
+ break depscheck
+ }
+ case UndoStatus:
+ isWaiting = c.isTaskWaiting(visited, wt, wt.HaltTasks())
+ if !isWaiting {
+ // Cancel early if we detect something is runnable.
+ break depscheck
+ }
+ default:
+ // When we determine the change can not be in a wait-state then
+ // break early.
+ isWaiting = false
+ break depscheck
+ }
+ }
+ if isWaiting {
+ visited[taskID] = taskWaitStatusWaiting
+ } else {
+ visited[taskID] = taskWaitStatusNotWaiting
+ }
+ return isWaiting
+}
+
+// isChangeWaiting should only ever return true iff it determines all tasks in Do/Undo
+// are blocked by tasks in either of three states: 'DoneStatus', 'UndoneStatus' or 'WaitStatus',
+// if this fails, we default to the normal status ordering logic.
+func (c *Change) isChangeWaiting() bool {
+ // Since we might visit tasks more than once, we store results to avoid recomputing them.
+ visited := make(map[string]taskWaitComputeStatus)
+ for _, t := range c.Tasks() {
+ switch t.Status() {
+ case WaitStatus, DoneStatus, UndoneStatus, ErrorStatus, HoldStatus:
+ continue
+ case DoStatus:
+ if !c.isTaskWaiting(visited, t, t.WaitTasks()) {
+ return false
+ }
+ case UndoStatus:
+ if !c.isTaskWaiting(visited, t, t.HaltTasks()) {
+ return false
+ }
+ default:
+ return false
+ }
+ }
+ // If we end up here, then return true as we know we
+ // have at least one waiter in this change.
+ return true
+}
+
// Status returns the current status of the change.
// If the status was not explicitly set the result is derived from the status
// of the individual tasks related to the change, according to the following
// decision sequence:
//
+// - With all pending tasks blocked by other tasks in WaitStatus, return WaitStatus
// - With at least one task in DoStatus, return DoStatus
// - With at least one task in ErrorStatus, return ErrorStatus
// - Otherwise, return DoneStatus
func (c *Change) Status() Status {
c.state.reading()
- if c.status == DefaultStatus {
- if len(c.taskIDs) == 0 {
- return HoldStatus
- }
- statusStats := make([]int, nStatuses)
- for _, tid := range c.taskIDs {
- statusStats[c.state.tasks[tid].Status()]++
+ if c.status != DefaultStatus {
+ return c.status
+ }
+
+ if len(c.taskIDs) == 0 {
+ return HoldStatus
+ }
+
+ statusStats := make([]int, nStatuses)
+ for _, tid := range c.taskIDs {
+ statusStats[c.state.tasks[tid].Status()]++
+ }
+
+ // If the change has any waiters, check for any runnable tasks
+ // or whether it's completely blocked by waiters.
+ if statusStats[WaitStatus] > 0 {
+ // Only if the change has all tasks blocked we return WaitStatus.
+ if c.isChangeWaiting() {
+ return WaitStatus
}
- for _, s := range statusOrder {
- if statusStats[s] > 0 {
- return s
- }
+ }
+
+ // Otherwise we return the current status with the highest priority.
+ for _, s := range statusOrder {
+ if statusStats[s] > 0 {
+ return s
}
- panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats))
}
- return c.status
+ panic(fmt.Sprintf("internal error: cannot process change status: %v", statusStats))
+}
+
+func (c *Change) notifyStatusChange(new Status) {
+ if c.lastObservedStatus == new {
+ return
+ }
+ c.state.notifyChangeStatusChangedHandlers(c, c.lastObservedStatus, new)
+ c.lastObservedStatus = new
}
// SetStatus sets the change status, overriding the default behavior (see Status method).
@@ -304,6 +438,7 @@ func (c *Change) SetStatus(s Status) {
if s.Ready() {
c.markReady()
}
+ c.notifyStatusChange(c.Status())
}
func (c *Change) markReady() {
@@ -322,15 +457,10 @@ func (c *Change) Ready() <-chan struct{} {
return c.ready
}
-// taskStatusChanged is called by tasks when their status is changed,
-// to give the opportunity for the change to close its ready channel.
-func (c *Change) taskStatusChanged(t *Task, old, new Status) {
- if old.Ready() == new.Ready() {
- return
- }
+func (c *Change) detectChangeReady(excludeTask *Task) {
for _, tid := range c.taskIDs {
task := c.state.tasks[tid]
- if task != t && !task.status.Ready() {
+ if task != excludeTask && !task.status.Ready() {
return
}
}
@@ -343,6 +473,21 @@ func (c *Change) taskStatusChanged(t *Task, old, new Status) {
c.markReady()
}
+// taskStatusChanged is called by tasks when their status is changed,
+// to give the opportunity for the change to close its ready channel, and
+// notify observers of Change changes.
+func (c *Change) taskStatusChanged(t *Task, old, new Status) {
+ cs := c.Status()
+ // If the task changes from ready => unready or unready => ready,
+ // update the ready status for the change.
+ if old.Ready() == new.Ready() {
+ c.notifyStatusChange(cs)
+ return
+ }
+ c.detectChangeReady(t)
+ c.notifyStatusChange(cs)
+}
+
// IsClean returns whether all tasks in the change have been cleaned. See SetClean.
func (c *Change) IsClean() bool {
c.state.reading()
@@ -519,6 +664,44 @@ func (c *Change) AbortLanes(lanes []int) {
c.abortLanes(lanes, make(map[int]bool), make(map[string]bool))
}
+// AbortUnreadyLanes aborts the tasks from lanes that aren't fully ready, where
+// a ready lane is one in which all tasks are ready.
+func (c *Change) AbortUnreadyLanes() {
+ c.state.writing()
+ c.abortUnreadyLanes()
+}
+
+func (c *Change) abortUnreadyLanes() {
+ lanesWithLiveTasks := map[int]bool{}
+
+ for _, tid := range c.taskIDs {
+ t := c.state.tasks[tid]
+ if !t.Status().Ready() {
+ for _, tlane := range t.Lanes() {
+ lanesWithLiveTasks[tlane] = true
+ }
+ }
+ }
+
+ abortLanes := []int{}
+ for lane := range lanesWithLiveTasks {
+ abortLanes = append(abortLanes, lane)
+ }
+ c.abortLanes(abortLanes, make(map[int]bool), make(map[string]bool))
+}
+
+// taskEffectiveStatus returns the 'effective' status. This means it accounts
+// for tasks being in WaitStatus, and instead of returning the WaitStatus we
+// return the actual status. (The status after the wait).
+func taskEffectiveStatus(t *Task) Status {
+ status := t.Status()
+ if status == WaitStatus {
+ // If the task is waiting, then use the effective status instead.
+ status = t.WaitedStatus()
+ }
+ return status
+}
+
func (c *Change) abortLanes(lanes []int, abortedLanes map[int]bool, seenTasks map[string]bool) {
var hasLive = make(map[int]bool)
var hasDead = make(map[int]bool)
@@ -528,7 +711,7 @@ NextChangeTask:
t := c.state.tasks[tid]
var live bool
- switch t.Status() {
+ switch taskEffectiveStatus(t) {
case DoStatus, DoingStatus, DoneStatus:
live = true
}
@@ -579,7 +762,7 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks
continue
}
seenTasks[t.id] = true
- switch t.Status() {
+ switch taskEffectiveStatus(t) {
case DoStatus:
// Still pending so don't even start.
t.SetStatus(HoldStatus)
@@ -607,3 +790,87 @@ func (c *Change) abortTasks(tasks []*Task, abortedLanes map[int]bool, seenTasks
c.abortLanes(lanes, abortedLanes, seenTasks)
}
}
+
+type TaskDependencyCycleError struct {
+ IDs []string
+ msg string
+}
+
+func (e *TaskDependencyCycleError) Error() string { return e.msg }
+
+func (e *TaskDependencyCycleError) Is(err error) bool {
+ _, ok := err.(*TaskDependencyCycleError)
+ return ok
+}
+
+// CheckTaskDependencies checks the tasks in the change for cyclic dependencies
+// and returns an error in such case.
+func (c *Change) CheckTaskDependencies() error {
+ tasks := c.Tasks()
+ // count how many tasks any given non-independent task waits for
+ predecessors := make(map[string]int, len(tasks))
+
+ taskByID := map[string]*Task{}
+ for _, t := range tasks {
+ taskByID[t.id] = t
+ if l := len(t.waitTasks); l > 0 {
+ // only add an entry if the task is not independent
+ predecessors[t.id] = l
+ }
+ }
+
+ // Kahn topological sort: make our way starting with tasks that are
+ // independent (their predecessors count is 0), then visit their direct
+ // successors (halt tasks), and for each reduce their predecessors
+ // count; once the count drops to 0, all direct dependencies of a given
+ // task have been accounted for and the task becomes independent.
+
+ // queue of tasks to check
+ queue := make([]string, 0, len(tasks))
+ // identify all independent tasks
+ for _, t := range tasks {
+ if predecessors[t.id] == 0 {
+ queue = append(queue, t.id)
+ }
+ }
+
+ for len(queue) > 0 {
+ // take the first independent task
+ id := queue[0]
+ queue = queue[1:]
+ // reduce the incoming edge of its successors
+ for _, successor := range taskByID[id].haltTasks {
+ predecessors[successor]--
+ if predecessors[successor] == 0 {
+ // a task that was a successor has become
+ // independent
+ delete(predecessors, successor)
+ queue = append(queue, successor)
+ }
+ }
+ }
+
+ if len(predecessors) != 0 {
+ // tasks that are left cannot have their dependencies satisfied
+ var unsatisfiedTasks []string
+ for id := range predecessors {
+ unsatisfiedTasks = append(unsatisfiedTasks, id)
+ }
+ sort.Strings(unsatisfiedTasks)
+ msg := strings.Builder{}
+ msg.WriteString("dependency cycle involving tasks [")
+ for i, id := range unsatisfiedTasks {
+ t := taskByID[id]
+ msg.WriteString(fmt.Sprintf("%v:%v", t.id, t.kind))
+ if i < len(unsatisfiedTasks)-1 {
+ msg.WriteRune(' ')
+ }
+ }
+ msg.WriteRune(']')
+ return &TaskDependencyCycleError{
+ IDs: unsatisfiedTasks,
+ msg: msg.String(),
+ }
+ }
+ return nil
+}
diff --git a/internals/overlord/state/change_test.go b/internals/overlord/state/change_test.go
index 52b77bfc..e25c4f55 100644
--- a/internals/overlord/state/change_test.go
+++ b/internals/overlord/state/change_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -20,15 +20,16 @@
package state_test
import (
+ "errors"
"fmt"
"sort"
"strconv"
"strings"
"time"
- "github.com/canonical/pebble/internals/overlord/state"
-
. "gopkg.in/check.v1"
+
+ "github.com/canonical/pebble/internals/overlord/state"
)
type changeSuite struct{}
@@ -68,7 +69,7 @@ func (cs *changeSuite) TestReadyTime(c *C) {
}
func (cs *changeSuite) TestStatusString(c *C) {
- for s := state.Status(0); s < state.ErrorStatus+1; s++ {
+ for s := state.Status(0); s < state.WaitStatus+1; s++ {
c.Assert(s.String(), Matches, ".+")
}
}
@@ -88,6 +89,21 @@ func (cs *changeSuite) TestGetSet(c *C) {
c.Check(v, Equals, 1)
}
+func (cs *changeSuite) TestHas(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("install", "...")
+ c.Check(chg.Has("a"), Equals, false)
+
+ chg.Set("a", 1)
+ c.Check(chg.Has("a"), Equals, true)
+
+ chg.Set("a", nil)
+ c.Check(chg.Has("a"), Equals, false)
+}
+
// TODO Better testing of full change roundtripping via JSON.
func (cs *changeSuite) TestNewTaskAddTaskAndTasks(c *C) {
@@ -227,9 +243,13 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
tasks := make(map[state.Status]*state.Task)
- for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ {
+ for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ {
t := st.NewTask("download", s.String())
- t.SetStatus(s)
+ if s == state.WaitStatus {
+ t.SetToWait(state.DoneStatus)
+ } else {
+ t.SetStatus(s)
+ }
chg.AddTask(t)
tasks[s] = t
}
@@ -240,6 +260,7 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
state.UndoStatus,
state.DoingStatus,
state.DoStatus,
+ state.WaitStatus,
state.ErrorStatus,
state.UndoneStatus,
state.DoneStatus,
@@ -252,7 +273,11 @@ func (cs *changeSuite) TestStatusDerivedFromTasks(c *C) {
if s == s2 {
break
}
- tasks[s2].SetStatus(s)
+ if s == state.WaitStatus {
+ tasks[s2].SetToWait(state.DoneStatus)
+ } else {
+ tasks[s2].SetStatus(s)
+ }
}
c.Assert(chg.Status(), Equals, s)
}
@@ -432,9 +457,13 @@ func (cs *changeSuite) TestAbort(c *C) {
chg := st.NewChange("install", "...")
- for s := state.DefaultStatus + 1; s < state.ErrorStatus+1; s++ {
+ for s := state.DefaultStatus + 1; s < state.WaitStatus+1; s++ {
t := st.NewTask("download", s.String())
- t.SetStatus(s)
+ if s == state.WaitStatus {
+ t.SetToWait(state.DoneStatus)
+ } else {
+ t.SetStatus(s)
+ }
t.Set("old-status", s)
chg.AddTask(t)
}
@@ -451,7 +480,7 @@ func (cs *changeSuite) TestAbort(c *C) {
switch s {
case state.DoStatus:
c.Assert(t.Status(), Equals, state.HoldStatus)
- case state.DoneStatus:
+ case state.DoneStatus, state.WaitStatus:
c.Assert(t.Status(), Equals, state.UndoStatus)
case state.DoingStatus:
c.Assert(t.Status(), Equals, state.AbortStatus)
@@ -526,11 +555,13 @@ func (cs *changeSuite) TestAbortKⁿ(c *C) {
// Task wait order:
//
-// => t21 => t22
-// / \
-// t11 => t12 => t41 => t42
-// \ /
-// => t31 => t32
+// => t21 => t22
+// / \
+//
+// t11 => t12 => t41 => t42
+//
+// \ /
+// => t31 => t32
//
// setup and result lines are :[:,...]
//
@@ -577,6 +608,10 @@ var abortLanesTests = []struct {
setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
abort: []int{2},
result: "t21:hold t22:hold t41:hold t42:hold *:do",
+ }, {
+ setup: "t11:done:1 t12:wait:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ abort: []int{2},
+ result: "t21:hold t22:hold t41:hold t42:hold t11:done t12:wait *:do",
}, {
setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
abort: []int{3},
@@ -660,7 +695,7 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
c.Logf("Testing setup: %s", test.setup)
statuses := make(map[string]state.Status)
- for s := state.DefaultStatus; s <= state.ErrorStatus; s++ {
+ for s := state.DefaultStatus; s <= state.WaitStatus; s++ {
statuses[strings.ToLower(s.String())] = s
}
@@ -680,7 +715,11 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
}
seen[parts[0]] = true
task := tasks[parts[0]]
- task.SetStatus(statuses[parts[1]])
+ if statuses[parts[1]] == state.WaitStatus {
+ task.SetToWait(state.DoneStatus)
+ } else {
+ task.SetStatus(statuses[parts[1]])
+ }
if len(parts) > 2 {
lanes := strings.Split(parts[2], ",")
for _, lane := range lanes {
@@ -723,3 +762,691 @@ func (ts *taskRunnerSuite) TestAbortLanes(c *C) {
c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup))
}
}
+
+// setup and result lines are :[:,...]
+// order is -> (implies task2 waits for task 1)
+// "*" as task name means "all remaining".
+var abortUnreadyLanesTests = []struct {
+ setup string
+ order string
+ result string
+}{
+
+ // Some basics.
+ {
+ setup: "*:do",
+ result: "*:hold",
+ }, {
+ setup: "*:wait",
+ result: "*:undo",
+ }, {
+ setup: "*:done",
+ result: "*:done",
+ }, {
+ setup: "*:error",
+ result: "*:error",
+ },
+
+ // t11 (1) => t12 (1) => t21 (1) => t22 (1)
+ // t31 (2) => t32 (2) => t41 (2) => t42 (2)
+ {
+ setup: "t11:do:1 t12:do:1 t21:do:1 t22:do:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:do:2 t32:do:2 t41:do:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:hold t32:hold t41:hold t42:hold",
+ }, {
+ setup: "t11:done:1 t12:done:1 t21:done:1 t22:done:1 t31:done:2 t32:done:2 t41:done:2 t42:do:2",
+ order: "t11->t12 t12->t21 t21->t22 t31->t32 t32->t41 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:undo t32:undo t41:undo t42:hold",
+ },
+ // => t21 (2) => t22 (2)
+ // / \
+ // t11 (2,3) => t12 (2,3) => t41 (4) => t42 (4)
+ // \ /
+ // => t31 (3) => t32 (3)
+ {
+ setup: "t11:do:2,3 t12:do:2,3 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ // lane 2 is fully complete so it does not get aborted
+ result: "t11:done t12:done t21:done t22:done t31:abort t32:hold t41:hold t42:hold *:undo",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:done:2 t22:done:2 t31:wait:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ // lane 2 is fully complete so it does not get aborted
+ result: "t11:done t12:done t21:done t22:done t31:undo t32:hold t41:hold t42:hold *:undo",
+ }, {
+ setup: "t11:done:2,3 t12:done:2,3 t21:doing:2 t22:do:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ result: "t21:abort t22:hold t31:abort t32:hold t41:hold t42:hold *:undo",
+ },
+
+ // t11 (1) => t12 (1)
+ // t21 (2) => t22 (2)
+ // t31 (3) => t32 (3)
+ // t41 (4) => t42 (4)
+ {
+ setup: "t11:do:1 t12:do:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ result: "*:hold",
+ }, {
+ setup: "t11:do:1 t12:do:1 t21:doing:2 t22:do:2 t31:done:3 t32:doing:3 t41:undone:4 t42:error:4",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ result: "t11:hold t12:hold t21:abort t22:hold t31:undo t32:abort t41:undone t42:error",
+ },
+ // auto refresh like arrangement
+ //
+ // (apps)
+ // => t31 (3) => t32 (3)
+ // (snapd) (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11:done:1 t12:done:1 t21:done:2 t22:done:2 t31:doing:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ result: "t11:done t12:done t21:done t22:done t31:abort *:hold",
+ }, {
+ //
+ setup: "t11:done:1 t12:done:1 t21:done:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ result: "t11:done t12:done t21:undo *:hold",
+ },
+ // arrangement with a cyclic dependency between tasks
+ //
+ // /-----------------------------------------\
+ // | |
+ // | => t31 (3) => t32 (3) /
+ // (snapd) v (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11:done:1 t12:done:1 t21:do:2 t22:do:2 t31:do:3 t32:do:3 t41:do:4 t42:do:4",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21",
+ result: "t11:done t12:done *:hold",
+ },
+}
+
+func (ts *taskRunnerSuite) TestAbortUnreadyLanes(c *C) {
+
+ names := strings.Fields("t11 t12 t21 t22 t31 t32 t41 t42")
+
+ for i, test := range abortUnreadyLanesTests {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+
+ c.Logf("----- %v", i)
+ c.Logf("Testing setup: %s", test.setup)
+
+ for _, wp := range strings.Fields(test.order) {
+ pair := strings.Split(wp, "->")
+ c.Assert(pair, HasLen, 2)
+ // task 2 waits for task 1 is denoted as:
+ // task1->task2
+ tasks[pair[1]].WaitFor(tasks[pair[0]])
+ }
+
+ statuses := make(map[string]state.Status)
+ for s := state.DefaultStatus; s <= state.WaitStatus; s++ {
+ statuses[strings.ToLower(s.String())] = s
+ }
+
+ items := strings.Fields(test.setup)
+ seen := make(map[string]bool)
+ for i := 0; i < len(items); i++ {
+ item := items[i]
+ parts := strings.Split(item, ":")
+ if parts[0] == "*" {
+ c.Assert(i, Equals, len(items)-1, Commentf("*: can only be used as the last entry"))
+ for _, name := range names {
+ if !seen[name] {
+ parts[0] = name
+ items = append(items, strings.Join(parts, ":"))
+ }
+ }
+ continue
+ }
+ seen[parts[0]] = true
+ task := tasks[parts[0]]
+ if statuses[parts[1]] == state.WaitStatus {
+ task.SetToWait(state.DoneStatus)
+ } else {
+ task.SetStatus(statuses[parts[1]])
+ }
+ if len(parts) > 2 {
+ lanes := strings.Split(parts[2], ",")
+ for _, lane := range lanes {
+ n, err := strconv.Atoi(lane)
+ c.Assert(err, IsNil)
+ task.JoinLane(n)
+ }
+ }
+ }
+
+ c.Logf("Aborting")
+
+ chg.AbortUnreadyLanes()
+
+ c.Logf("Expected result: %s", test.result)
+
+ seen = make(map[string]bool)
+ var expected = strings.Fields(test.result)
+ var obtained []string
+ for i := 0; i < len(expected); i++ {
+ item := expected[i]
+ parts := strings.Split(item, ":")
+ if parts[0] == "*" {
+ c.Assert(i, Equals, len(expected)-1, Commentf("*: can only be used as the last entry"))
+ var expanded []string
+ for _, name := range names {
+ if !seen[name] {
+ parts[0] = name
+ expanded = append(expanded, strings.Join(parts, ":"))
+ }
+ }
+ expected = append(expected[:i], append(expanded, expected[i+1:]...)...)
+ i--
+ continue
+ }
+ name := parts[0]
+ seen[parts[0]] = true
+ obtained = append(obtained, name+":"+strings.ToLower(tasks[name].Status().String()))
+ }
+
+ c.Assert(strings.Join(obtained, " "), Equals, strings.Join(expected, " "), Commentf("setup: %s", test.setup))
+ }
+}
+
+// setup is a list of tasks " ", order is ->
+// (implies task2 waits for task 1)
+var cyclicDependencyTests = []struct {
+ setup string
+ order string
+ err string
+ errIDs []string
+}{
+
+ // Some basics.
+ {
+ setup: "t1",
+ }, {
+ setup: "",
+ }, {
+ // independent tasks
+ setup: "t1 t2 t3",
+ }, {
+ // some independent and some ordered tasks
+ setup: "t1 t2 t3 t4",
+ order: "t2->t3",
+ },
+ // some independent, dependencies as if added by WaitAll()
+ // t1 => t2
+ // t1,t2 => t3
+ // t1,t2,t3 => t4
+ {
+ setup: "t1 t2 t3 t4",
+ order: "t1->t2 t1->t3 t2->t3 t1->t4 t2->t4 t3->t4",
+ }, {
+ // simple loop
+ setup: "t1 t2",
+ order: "t1->t2 t2->t1",
+ err: `dependency cycle involving tasks \[1:t1 2:t2\]`,
+ errIDs: []string{"1", "2"},
+ },
+
+ // t1 => t2 => t3 => t4
+ // t5 => t6 => t7 => t8
+ {
+ setup: "t1 t2 t3 t4 t5 t6 t7 t8",
+ order: "t1->t2 t2->t3 t3->t4 t5->t6 t6->t7 t7->t8",
+ },
+ // => t21 => t22
+ // / \
+ // t11 => t12 => t41 => t42
+ // \ /
+ // => t31 => t32
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t12->t31 t21->t22 t31->t32 t22->t41 t32->t41 t41->t42",
+ },
+ // t11 (1) => t12 (1)
+ // t21 (2) => t22 (2)
+ // t31 (3) => t32 (3)
+ // t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t21->t22 t31->t32 t41->t42",
+ },
+ // auto refresh like arrangement
+ //
+ // (apps)
+ // => t31 (3) => t32 (3)
+ // (snapd) (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42",
+ },
+ // arrangement with a cyclic dependency between tasks
+ //
+ // /-----------------------------------------\
+ // | |
+ // | => t31 (3) => t32 (3) /
+ // (snapd) v (base) /
+ // t11 (1) => t12 (1) => t21 (2) => t22 (2)
+ // \
+ // => t41 (4) => t42 (4)
+ {
+ setup: "t11 t12 t21 t22 t31 t32 t41 t42",
+ order: "t11->t12 t12->t21 t21->t22 t22->t31 t22->t41 t31->t32 t41->t42 t32->t21",
+ err: `dependency cycle involving tasks \[3:t21 4:t22 5:t31 6:t32 7:t41 8:t42\]`,
+ errIDs: []string{"3", "4", "5", "6", "7", "8"},
+ },
+ // t1 => t2 => t3 => t4 --> t6
+ // t5 => t6 => t7 => t8 --> t2
+ {
+ setup: "t1 t2 t3 t4 t5 t6 t7 t8",
+ order: "t1->t2 t2->t3 t3->t4 t4->t6 t5->t6 t6->t7 t7->t8 t8->t2",
+ err: `dependency cycle involving tasks \[2:t2 3:t3 4:t4 6:t6 7:t7 8:t8\]`,
+ errIDs: []string{"2", "3", "4", "6", "7", "8"},
+ },
+}
+
+func (ts *taskRunnerSuite) TestCheckTaskDependencies(c *C) {
+
+ for i, test := range cyclicDependencyTests {
+ names := strings.Fields(test.setup)
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask(name, name)
+ chg.AddTask(tasks[name])
+ }
+
+ c.Logf("----- %v", i)
+ c.Logf("Testing setup: %s", test.setup)
+
+ for _, wp := range strings.Fields(test.order) {
+ pair := strings.Split(wp, "->")
+ c.Assert(pair, HasLen, 2)
+ // task 2 waits for task 1 is denoted as:
+ // task1->task2
+ tasks[pair[1]].WaitFor(tasks[pair[0]])
+ }
+
+ err := chg.CheckTaskDependencies()
+
+ if test.err != "" {
+ c.Assert(err, ErrorMatches, test.err)
+ c.Assert(errors.Is(err, &state.TaskDependencyCycleError{}), Equals, true)
+ errTasksDepCycle := err.(*state.TaskDependencyCycleError)
+ c.Assert(errTasksDepCycle.IDs, DeepEquals, test.errIDs)
+ } else {
+ c.Assert(err, IsNil)
+ }
+ }
+}
+
+func (cs *changeSuite) TestIsWaitingStatusOrderWithWaits(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+ t1.WaitFor(t2)
+ t1.WaitFor(t3)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // Set the wait-task into WaitStatus, to ensure we trigger the isWaiting
+ // logic and that it doesn't return WaitStatus for statuses which are in
+ // higher order
+ t4.SetToWait(state.DoneStatus)
+
+ // Test the following sequences:
+ // task1 (do) => task2 (done) => task3 (doing)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.DoingStatus)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+
+ // task1 (done) => task2 (done) => task3 (undoing)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoingStatus)
+ c.Check(chg.Status(), Equals, state.UndoingStatus)
+
+ // task1 (done) => task2 (done) => task3 (abort)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.AbortStatus)
+ c.Check(chg.Status(), Equals, state.AbortStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingSingle(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+
+ chg.AddTask(t1)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ t1.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingTwoTasks(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("wait-task", "...")
+ t2.WaitFor(t1)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+
+ // Put t3 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t3.SetToWait(state.DoneStatus)
+
+ // task1 (do) => task2 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) => task2 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (wait) => task2 (do) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) => task2 (wait) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingCircularDependency(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+
+ // Setup circular dependency between t1,t2 and t3, they should
+ // still act normally.
+ t2.WaitFor(t1)
+ t3.WaitFor(t2)
+ t1.WaitFor(t3)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // To trigger the cyclic dependency check, we must trigger the isWaiting logic
+ // and we do this by putting t4 into WaitStatus.
+ t4.SetToWait(state.DoneStatus)
+
+ // task1 (do) => task2 (do) => task3 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) => task2 (do) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoingStatus)
+ c.Check(chg.Status(), Equals, state.DoingStatus)
+
+ // task1 (wait) => task2 (do) => task3 (do) means need a reboot
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) => task2 (wait) => task3 (do) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingMultipleDependencies(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("wait-task", "...")
+ t3.WaitFor(t1)
+ t3.WaitFor(t2)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+
+ // Put t4 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t4.SetToWait(state.DoneStatus)
+
+ // task1 (do) + task2 (do) => task3 (do) no reboot
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) + task2 (done) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // task1 (done) + task2 (do) => task3 (do) no reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoStatus)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // For the next two cases we are testing that a task with dependencies
+ // which have completed, but in a non-successful way is handled correctly.
+ // task1 (error) + task2 (wait) => task3 (do) means need reboot
+ // to finalize task2
+ t1.SetStatus(state.ErrorStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (error) => task3 (do) means need reboot
+ // to finalize task1
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.ErrorStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) + task2 (wait) => task3 (do) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) => task3 (do) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (done) + task2 (done) => task3 (wait) means need a reboot
+ t1.SetStatus(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ t3.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (abort) => task3 (do)
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.AbortStatus)
+ t3.SetStatus(state.DoStatus)
+ c.Check(chg.Status(), Equals, state.AbortStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingUndoTwoTasks(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("wait-task", "...")
+ t2.WaitFor(t1)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+
+ // Put t3 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t3.SetToWait(state.DoneStatus)
+
+ // we use <=| to denote the reverse dependence relationship
+ // followed by undo logic
+
+ // task1 (undo) <=| task2 (undo) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) <=| task2 (undone) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) <=| task2 (wait) means need a reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetToWait(state.DoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) <=| task2 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
+
+func (cs *changeSuite) TestIsWaitingUndoMultipleDependencies(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+
+ t1 := st.NewTask("task1", "...")
+ t2 := st.NewTask("task2", "...")
+ t3 := st.NewTask("task3", "...")
+ t4 := st.NewTask("task4", "...")
+ t5 := st.NewTask("wait-task", "...")
+ t3.WaitFor(t1)
+ t3.WaitFor(t2)
+ t4.WaitFor(t1)
+ t4.WaitFor(t2)
+
+ chg.AddTask(t1)
+ chg.AddTask(t2)
+ chg.AddTask(t3)
+ chg.AddTask(t4)
+ chg.AddTask(t5)
+
+ // Put t5 into wait-status to trigger the isWaiting logic each time
+ // for the change.
+ t5.SetToWait(state.DoneStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (undo) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ t3.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (undone) no reboot
+ t1.SetStatus(state.UndoStatus)
+ t2.SetStatus(state.UndoStatus)
+ t3.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (undo) + task2 (undo) <=| task3 (wait) + task4 (error) means
+ // need reboot to continue undoing 1 and 2
+ t3.SetStatus(state.ErrorStatus)
+ t4.SetToWait(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (undo) + task2 (undo) => task3 (error) + task4 (wait) means
+ // need reboot to continue undoing 1 and 2
+ t3.SetToWait(state.UndoneStatus)
+ t4.SetStatus(state.ErrorStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undo) no reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoStatus)
+ c.Check(chg.Status(), Equals, state.UndoStatus)
+
+ // task1 (wait) + task2 (done) <=| task3 (undone) + task4 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetStatus(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+
+ // task1 (wait) + task2 (wait) <=| task3 (undone) + task4 (undone) means need a reboot
+ t1.SetToWait(state.DoneStatus)
+ t2.SetToWait(state.DoneStatus)
+ t3.SetStatus(state.UndoneStatus)
+ t4.SetStatus(state.UndoneStatus)
+ c.Check(chg.Status(), Equals, state.WaitStatus)
+}
diff --git a/internals/overlord/state/copy.go b/internals/overlord/state/copy.go
new file mode 100644
index 00000000..f87c2e85
--- /dev/null
+++ b/internals/overlord/state/copy.go
@@ -0,0 +1,141 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2020 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package state
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/canonical/pebble/internals/jsonutil"
+ "github.com/canonical/pebble/internals/osutil"
+)
+
+type checkpointOnlyBackend struct {
+ path string
+}
+
+func (b *checkpointOnlyBackend) Checkpoint(data []byte) error {
+ if err := os.MkdirAll(filepath.Dir(b.path), 0755); err != nil {
+ return err
+ }
+ return osutil.AtomicWriteFile(b.path, data, 0600, 0)
+}
+
+func (b *checkpointOnlyBackend) EnsureBefore(d time.Duration) {
+ panic("cannot use EnsureBefore in checkpointOnlyBackend")
+}
+
+// copyData will copy the given subkeys specifier from srcData to dstData.
+//
+// The subkeys is constructed from a dotted path like "user.auth". This copy
+// helper is recursive and the pos parameter tells the function the current
+// position of the copy.
+func copyData(subkeys []string, pos int, srcData map[string]*json.RawMessage, dstData map[string]interface{}) error {
+ if pos < 0 || pos > len(subkeys) {
+ return fmt.Errorf("internal error: copyData used with an out-of-bounds position: %v not in [0:%v]", pos, len(subkeys))
+ }
+ raw, ok := srcData[subkeys[pos]]
+ if !ok {
+ return ErrNoState
+ }
+
+ if pos+1 == len(subkeys) {
+ dstData[subkeys[pos]] = raw
+ return nil
+ }
+
+ var srcDatam map[string]*json.RawMessage
+ if err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &srcDatam); err != nil {
+ return fmt.Errorf("cannot unmarshal state entry %q with value %q as a map while trying to copy over %q", strings.Join(subkeys[:pos+1], "."), *raw, strings.Join(subkeys, "."))
+ }
+
+ // no subkey entry -> create one
+ if _, ok := dstData[subkeys[pos]]; !ok {
+ dstData[subkeys[pos]] = make(map[string]interface{})
+ }
+ // and use existing data
+ var dstDatam map[string]interface{}
+ switch dstDataEntry := dstData[subkeys[pos]].(type) {
+ case map[string]interface{}:
+ dstDatam = dstDataEntry
+ case *json.RawMessage:
+ dstDatam = make(map[string]interface{})
+ if err := jsonutil.DecodeWithNumber(bytes.NewReader(*dstDataEntry), &dstDatam); err != nil {
+ return fmt.Errorf("internal error: cannot decode subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstDataEntry)
+ }
+ default:
+ return fmt.Errorf("internal error: cannot create subkey %s (%q) for %v (%T)", subkeys[pos], strings.Join(subkeys, "."), dstData, dstData[subkeys[pos]])
+ }
+
+ return copyData(subkeys, pos+1, srcDatam, dstDatam)
+}
+
+// CopyState takes a state from the srcStatePath and copies all
+// dataEntries to the dstPath. Note that srcStatePath should never
+// point to a state that is in use.
+func CopyState(srcStatePath, dstStatePath string, dataEntries []string) error {
+ if osutil.FileExists(dstStatePath) {
+ // XXX: TOCTOU - look into moving this check into
+ // checkpointOnlyBackend. The issue is right now State
+ // will simply panic if Commit() returns an error
+ return fmt.Errorf("cannot copy state: %q already exists", dstStatePath)
+ }
+ if len(dataEntries) == 0 {
+ return fmt.Errorf("cannot copy state: must provide at least one data entry to copy")
+ }
+
+ f, err := os.Open(srcStatePath)
+ if err != nil {
+ return fmt.Errorf("cannot open state: %s", err)
+ }
+ defer f.Close()
+
+ // No need to lock/unlock the state here, srcState should not be
+ // in use at all.
+ srcState, err := ReadState(nil, f)
+ if err != nil {
+ return err
+ }
+
+ // copy relevant data
+ dstData := make(map[string]interface{})
+ for _, dataEntry := range dataEntries {
+ subkeys := strings.Split(dataEntry, ".")
+ if err := copyData(subkeys, 0, srcState.data, dstData); err != nil && !errors.Is(err, ErrNoState) {
+ return err
+ }
+ }
+
+ // write it out
+ dstState := New(&checkpointOnlyBackend{path: dstStatePath})
+ dstState.Lock()
+ defer dstState.Unlock()
+ for k, v := range dstData {
+ dstState.Set(k, v)
+ }
+
+ return nil
+}
diff --git a/internals/overlord/state/copy_test.go b/internals/overlord/state/copy_test.go
new file mode 100644
index 00000000..e7130e88
--- /dev/null
+++ b/internals/overlord/state/copy_test.go
@@ -0,0 +1,146 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2016-2020 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package state_test
+
+import (
+ "io/ioutil"
+ "path/filepath"
+
+ . "gopkg.in/check.v1"
+
+ "github.com/canonical/pebble/internals/overlord/state"
+)
+
+func (ss *stateSuite) TestCopyStateAlreadyExists(c *C) {
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+ err := ioutil.WriteFile(dstStateFile, nil, 0644)
+ c.Assert(err, IsNil)
+
+ err = state.CopyState(srcStateFile, dstStateFile, []string{"some-data"})
+ c.Assert(err, ErrorMatches, `cannot copy state: "/.*/dst-state.json" already exists`)
+}
+func (ss *stateSuite) TestCopyStateNoDataEntriesToCopy(c *C) {
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+
+ err := state.CopyState(srcStateFile, dstStateFile, nil)
+ c.Assert(err, ErrorMatches, `cannot copy state: must provide at least one data entry to copy`)
+}
+
+var srcStateContent = []byte(`
+{
+ "data": {
+ "api-download-tokens-secret": "123",
+ "api-download-tokens-secret-time": "2020-02-21T10:32:37.916147296Z",
+ "auth": {
+ "last-id": 1,
+ "users": [
+ {
+ "id": 1,
+ "email": "some@user.com",
+ "macaroon": "1234",
+ "store-macaroon": "5678",
+ "store-discharges": [
+ "9012345"
+ ]
+ }
+ ],
+ "device": {
+ "brand": "generic",
+ "model": "generic-classic",
+ "serial": "xxxxx-yyyyy-",
+ "key-id": "xxxxxx",
+ "session-macaroon": "xxxx"
+ },
+ "macaroon-key": "xxxx="
+ },
+ "config": {
+ }
+ }
+}
+`)
+
+const stateSuffix = `,"changes":{},"tasks":{},"last-change-id":0,"last-task-id":0,"last-lane-id":0}`
+
+func (ss *stateSuite) TestCopyStateIntegration(c *C) {
+ // create a mock srcState
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+ err := ioutil.WriteFile(srcStateFile, srcStateContent, 0644)
+ c.Assert(err, IsNil)
+
+ // copy
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+ err = state.CopyState(srcStateFile, dstStateFile, []string{"auth.users", "no-existing-does-not-error", "auth.last-id"})
+ c.Assert(err, IsNil)
+
+ // and check that the right bits got copied
+ dstContent, err := ioutil.ReadFile(dstStateFile)
+ c.Assert(err, IsNil)
+ c.Check(string(dstContent), Equals, `{"data":{"auth":{"last-id":1,"users":[{"id":1,"email":"some@user.com","macaroon":"1234","store-macaroon":"5678","store-discharges":["9012345"]}]}}`+stateSuffix)
+}
+
+var srcStateContent1 = []byte(`{
+ "data": {
+ "A": {"B": [{"C": 1}, {"D": 2}]},
+ "E": {"F": 2, "G": 3},
+ "H": 4,
+ "I": null
+ }
+}`)
+
+func (ss *stateSuite) TestCopyState(c *C) {
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+ err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644)
+ c.Assert(err, IsNil)
+
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+ err = state.CopyState(srcStateFile, dstStateFile, []string{"A.B", "no-existing-does-not-error", "E.F", "E", "I", "E.non-existing"})
+ c.Assert(err, IsNil)
+
+ dstContent, err := ioutil.ReadFile(dstStateFile)
+ c.Assert(err, IsNil)
+ c.Check(string(dstContent), Equals, `{"data":{"A":{"B":[{"C":1},{"D":2}]},"E":{"F":2,"G":3},"I":null}`+stateSuffix)
+}
+
+func (ss *stateSuite) TestCopyStateUnmarshalNotMap(c *C) {
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+ err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644)
+ c.Assert(err, IsNil)
+
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+ err = state.CopyState(srcStateFile, dstStateFile, []string{"E.F.subkey-not-in-a-map"})
+ c.Assert(err, ErrorMatches, `cannot unmarshal state entry "E.F" with value "2" as a map while trying to copy over "E.F.subkey-not-in-a-map"`)
+}
+
+func (ss *stateSuite) TestCopyStateDuplicatesInDataEntriesAreFine(c *C) {
+ srcStateFile := filepath.Join(c.MkDir(), "src-state.json")
+ err := ioutil.WriteFile(srcStateFile, srcStateContent1, 0644)
+ c.Assert(err, IsNil)
+
+ dstStateFile := filepath.Join(c.MkDir(), "dst-state.json")
+ err = state.CopyState(srcStateFile, dstStateFile, []string{"E", "E"})
+ c.Assert(err, IsNil)
+
+ dstContent, err := ioutil.ReadFile(dstStateFile)
+ c.Assert(err, IsNil)
+ c.Check(string(dstContent), Equals, `{"data":{"E":{"F":2,"G":3}}`+stateSuffix)
+}
diff --git a/internals/overlord/state/export_test.go b/internals/overlord/state/export_test.go
index adbcc5d5..3bf6f8ef 100644
--- a/internals/overlord/state/export_test.go
+++ b/internals/overlord/state/export_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -23,8 +23,8 @@ import (
"time"
)
-// FakeCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime.
-func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) {
+// MockCheckpointRetryDelay changes unlockCheckpointRetryInterval and unlockCheckpointRetryMaxTime.
+func MockCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restore func()) {
oldInterval := unlockCheckpointRetryInterval
oldMaxTime := unlockCheckpointRetryMaxTime
unlockCheckpointRetryInterval = retryInterval
@@ -35,12 +35,12 @@ func FakeCheckpointRetryDelay(retryInterval, retryMaxTime time.Duration) (restor
}
}
-func FakeChangeTimes(chg *Change, spawnTime, readyTime time.Time) {
+func MockChangeTimes(chg *Change, spawnTime, readyTime time.Time) {
chg.spawnTime = spawnTime
chg.readyTime = readyTime
}
-func FakeTaskTimes(t *Task, spawnTime, readyTime time.Time) {
+func MockTaskTimes(t *Task, spawnTime, readyTime time.Time) {
t.spawnTime = spawnTime
t.readyTime = readyTime
}
diff --git a/internals/overlord/state/state.go b/internals/overlord/state/state.go
index 660a8a64..86733024 100644
--- a/internals/overlord/state/state.go
+++ b/internals/overlord/state/state.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -46,7 +46,7 @@ type customData map[string]*json.RawMessage
func (data customData) get(key string, value interface{}) error {
entryJSON := data[key]
if entryJSON == nil {
- return ErrNoState
+ return &NoStateError{Key: key}
}
err := json.Unmarshal(*entryJSON, value)
if err != nil {
@@ -87,6 +87,9 @@ type State struct {
lastTaskId int
lastChangeId int
lastLaneId int
+ // lastHandlerId is not serialized, it's only used during runtime
+ // for registering runtime callbacks
+ lastHandlerId int
backend Backend
data customData
@@ -97,18 +100,27 @@ type State struct {
modified bool
cache map[interface{}]interface{}
+
+ pendingChangeByAttr map[string]func(*Change) bool
+
+ // task/changes observing
+ taskHandlers map[int]func(t *Task, old, new Status)
+ changeHandlers map[int]func(chg *Change, old, new Status)
}
// New returns a new empty state.
func New(backend Backend) *State {
return &State{
- backend: backend,
- data: make(customData),
- changes: make(map[string]*Change),
- tasks: make(map[string]*Task),
- warnings: make(map[string]*Warning),
- modified: true,
- cache: make(map[interface{}]interface{}),
+ backend: backend,
+ data: make(customData),
+ changes: make(map[string]*Change),
+ tasks: make(map[string]*Task),
+ warnings: make(map[string]*Warning),
+ modified: true,
+ cache: make(map[interface{}]interface{}),
+ pendingChangeByAttr: make(map[string]func(*Change) bool),
+ taskHandlers: make(map[int]func(t *Task, old Status, new Status)),
+ changeHandlers: make(map[int]func(chg *Change, old Status, new Status)),
}
}
@@ -241,6 +253,28 @@ func (s *State) EnsureBefore(d time.Duration) {
// ErrNoState represents the case of no state entry for a given key.
var ErrNoState = errors.New("no state entry for key")
+// NoStateError represents the case where no state could be found for a given key.
+type NoStateError struct {
+ // Key is the key for which no state could be found.
+ Key string
+}
+
+func (e *NoStateError) Error() string {
+ var keyMsg string
+ if e.Key != "" {
+ keyMsg = fmt.Sprintf(" %q", e.Key)
+ }
+
+ return fmt.Sprintf("no state entry for key%s", keyMsg)
+}
+
+// Is returns true if the error is of type *NoStateError or equal to ErrNoState.
+// NoStateError's key isn't compared between errors.
+func (e *NoStateError) Is(err error) bool {
+ _, ok := err.(*NoStateError)
+ return ok || errors.Is(err, ErrNoState)
+}
+
// Get unmarshals the stored value associated with the provided key
// into the value parameter.
// It returns ErrNoState if there is no entry for key.
@@ -249,6 +283,12 @@ func (s *State) Get(key string, value interface{}) error {
return s.data.get(key, value)
}
+// Has returns whether the provided key has an associated value.
+func (s *State) Has(key string) bool {
+ s.reading()
+ return s.data.has(key)
+}
+
// Set associates value with key for future consulting by managers.
// The provided value must properly marshal and unmarshal with encoding/json.
func (s *State) Set(key string, value interface{}) {
@@ -357,15 +397,25 @@ func (s *State) tasksIn(tids []string) []*Task {
return res
}
+// RegisterPendingChangeByAttr registers predicates that will be invoked by
+// Prune on changes with the specified attribute set to check whether even if
+// they meet the time criteria they must not be aborted yet.
+func (s *State) RegisterPendingChangeByAttr(attr string, f func(*Change) bool) {
+ s.pendingChangeByAttr[attr] = f
+}
+
// Prune does several cleanup tasks to the in-memory state:
//
// - it removes changes that became ready for more than pruneWait and aborts
-// tasks spawned for more than abortWait.
+// tasks spawned for more than abortWait unless prevented by predicates
+// registered with RegisterPendingChangeByAttr.
+//
// - it removes tasks unlinked to changes after pruneWait. When there are more
// changes than the limit set via "maxReadyChanges" those changes in ready
// state will also removed even if they are below the pruneWait duration.
+//
// - it removes expired warnings.
-func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
+func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) {
now := time.Now()
pruneLimit := now.Add(-pruneWait)
abortLimit := now.Add(-abortWait)
@@ -392,15 +442,24 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
}
}
+NextChange:
for _, chg := range changes {
- spawnTime := chg.SpawnTime()
readyTime := chg.ReadyTime()
+ spawnTime := chg.SpawnTime()
+ if spawnTime.Before(startOfOperation) {
+ spawnTime = startOfOperation
+ }
if readyTime.IsZero() {
if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 {
chg.Abort()
delete(s.changes, chg.ID())
} else if spawnTime.Before(abortLimit) {
- chg.Abort()
+ for attr, pending := range s.pendingChangeByAttr {
+ if chg.Has(attr) && pending(chg) {
+ continue NextChange
+ }
+ }
+ chg.AbortUnreadyLanes()
}
continue
}
@@ -424,6 +483,75 @@ func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
}
}
+// GetMaybeTimings implements snapcore/snapd/timings.GetSaver
+func (s *State) GetMaybeTimings(timings interface{}) error {
+ if err := s.Get("timings", timings); err != nil && !errors.Is(err, ErrNoState) {
+ return err
+ }
+ return nil
+}
+
+// AddTaskStatusChangedHandler adds a callback function that will be invoked
+// whenever tasks change status.
+// NOTE: Callbacks registered this way may be invoked in the context
+// of the taskrunner, so the callbacks should be as simple as possible, and return
+// as quickly as possible, and should avoid the use of i/o code or blocking, as this
+// will stop the entire task system.
+func (s *State) AddTaskStatusChangedHandler(f func(t *Task, old, new Status)) (id int) {
+ // We are reading here as we want to ensure access to the state is serialized,
+ // and not writing as we are not changing the part of state that goes on the disk.
+ s.reading()
+ id = s.lastHandlerId
+ s.lastHandlerId++
+ s.taskHandlers[id] = f
+ return id
+}
+
+func (s *State) RemoveTaskStatusChangedHandler(id int) {
+ s.reading()
+ delete(s.taskHandlers, id)
+}
+
+func (s *State) notifyTaskStatusChangedHandlers(t *Task, old, new Status) {
+ s.reading()
+ for _, f := range s.taskHandlers {
+ f(t, old, new)
+ }
+}
+
+// AddChangeStatusChangedHandler adds a callback function that will be invoked
+// whenever a Change changes status.
+// NOTE: Callbacks registered this way may be invoked in the context
+// of the taskrunner, so the callbacks should be as simple as possible, and return
+// as quickly as possible, and should avoid the use of i/o code or blocking, as this
+// will stop the entire task system.
+func (s *State) AddChangeStatusChangedHandler(f func(chg *Change, old, new Status)) (id int) {
+ // We are reading here as we want to ensure access to the state is serialized,
+ // and not writing as we are not changing the part of state that goes on the disk.
+ s.reading()
+ id = s.lastHandlerId
+ s.lastHandlerId++
+ s.changeHandlers[id] = f
+ return id
+}
+
+func (s *State) RemoveChangeStatusChangedHandler(id int) {
+ s.reading()
+ delete(s.changeHandlers, id)
+}
+
+func (s *State) notifyChangeStatusChangedHandlers(chg *Change, old, new Status) {
+ s.reading()
+ for _, f := range s.changeHandlers {
+ f(chg, old, new)
+ }
+}
+
+// SaveTimings implements snapcore/snapd/timings.GetSaver
+func (s *State) SaveTimings(timings interface{}) {
+ s.Set("timings", timings)
+}
+
// ReadState returns the state deserialized from r.
func ReadState(backend Backend, r io.Reader) (*State, error) {
s := new(State)
@@ -437,5 +565,6 @@ func ReadState(backend Backend, r io.Reader) (*State, error) {
s.backend = backend
s.modified = false
s.cache = make(map[interface{}]interface{})
+ s.pendingChangeByAttr = make(map[string]func(*Change) bool)
return s, err
}
diff --git a/internals/overlord/state/state_test.go b/internals/overlord/state/state_test.go
index e1e8a707..2ab46026 100644
--- a/internals/overlord/state/state_test.go
+++ b/internals/overlord/state/state_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -29,6 +29,7 @@ import (
. "gopkg.in/check.v1"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
func TestState(t *testing.T) { TestingT(t) }
@@ -76,6 +77,37 @@ func (ss *stateSuite) TestGetAndSet(c *C) {
c.Check(&mSt2B, DeepEquals, mSt2)
}
+func (ss *stateSuite) TestHas(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(st.Has("a"), Equals, false)
+
+ st.Set("a", 1)
+ c.Check(st.Has("a"), Equals, true)
+
+ st.Set("a", nil)
+ c.Check(st.Has("a"), Equals, false)
+}
+
+func (ss *stateSuite) TestStrayTaskWithNoChange(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ chg := st.NewChange("change", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+ _ = st.NewTask("bar", "...")
+
+ // only the task with associate change is returned
+ c.Assert(st.Tasks(), HasLen, 1)
+ c.Assert(st.Tasks()[0].ID(), Equals, t1.ID())
+ // but count includes all tasks
+ c.Assert(st.TaskCount(), Equals, 2)
+}
+
func (ss *stateSuite) TestSetPanic(c *C) {
st := state.New(nil)
st.Lock()
@@ -94,7 +126,7 @@ func (ss *stateSuite) TestGetNoState(c *C) {
var mSt1B mgrState1
err := st.Get("mgr9", &mSt1B)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
}
func (ss *stateSuite) TestSetToNilDeletes(c *C) {
@@ -112,7 +144,7 @@ func (ss *stateSuite) TestSetToNilDeletes(c *C) {
var v1 map[string]int
err = st.Get("a", &v1)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
c.Check(v1, HasLen, 0)
}
@@ -126,7 +158,7 @@ func (ss *stateSuite) TestNullMeansNoState(c *C) {
var v1 map[string]int
err = st.Get("a", &v1)
- c.Check(err, Equals, state.ErrNoState)
+ c.Check(err, testutil.ErrorIs, state.ErrNoState)
c.Check(v1, HasLen, 0)
}
@@ -227,7 +259,7 @@ func (ss *stateSuite) TestImplicitCheckpointAndRead(c *C) {
}
func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) {
- restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second)
+ restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second)
defer restore()
retries := 0
@@ -250,7 +282,7 @@ func (ss *stateSuite) TestImplicitCheckpointRetry(c *C) {
}
func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) {
- restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond)
+ restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 80*time.Millisecond)
defer restore()
boom := errors.New("boom")
@@ -272,7 +304,7 @@ func (ss *stateSuite) TestImplicitCheckpointPanicsAfterFailedRetries(c *C) {
}
func (ss *stateSuite) TestImplicitCheckpointModifiedOnly(c *C) {
- restore := state.FakeCheckpointRetryDelay(2*time.Millisecond, 1*time.Second)
+ restore := state.MockCheckpointRetryDelay(2*time.Millisecond, 1*time.Second)
defer restore()
b := &fakeStateBackend{}
@@ -421,7 +453,7 @@ func (ss *stateSuite) TestNewTaskAndCheckpoint(c *C) {
t1ID := t1.ID()
t1.Set("a", 1)
t1.SetStatus(state.DoneStatus)
- t1.SetProgress("foo", 5, 10)
+ t1.SetProgress("snap", 5, 10)
t1.JoinLane(42)
t1.JoinLane(43)
@@ -702,7 +734,7 @@ func (ss *stateSuite) TestMethodEntrance(c *C) {
func() { st.Tasks() },
func() { st.Task("foo") },
func() { st.MarshalJSON() },
- func() { st.Prune(time.Hour, time.Hour, 100) },
+ func() { st.Prune(time.Now(), time.Hour, time.Hour, 100) },
func() { st.TaskCount() },
func() { st.AllWarnings() },
func() { st.PendingWarnings() },
@@ -744,31 +776,32 @@ func (ss *stateSuite) TestPrune(c *C) {
chg1 := st.NewChange("abort", "...")
chg1.AddTask(t1)
- state.FakeChangeTimes(chg1, now.Add(-abortWait), unset)
+ state.MockChangeTimes(chg1, now.Add(-abortWait), unset)
chg2 := st.NewChange("prune", "...")
chg2.AddTask(t2)
c.Assert(chg2.Status(), Equals, state.DoStatus)
- state.FakeChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait))
+ state.MockChangeTimes(chg2, now.Add(-pruneWait), now.Add(-pruneWait))
chg3 := st.NewChange("ready-but-recent", "...")
chg3.AddTask(t3)
- state.FakeChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2))
+ state.MockChangeTimes(chg3, now.Add(-pruneWait), now.Add(-pruneWait/2))
chg4 := st.NewChange("old-but-not-ready", "...")
chg4.AddTask(t4)
- state.FakeChangeTimes(chg4, now.Add(-pruneWait/2), unset)
+ state.MockChangeTimes(chg4, now.Add(-pruneWait/2), unset)
// unlinked task
t5 := st.NewTask("unliked", "...")
c.Check(st.Task(t5.ID()), IsNil)
- state.FakeTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait))
+ state.MockTaskTimes(t5, now.Add(-pruneWait), now.Add(-pruneWait))
// two warnings, one expired
st.AddWarning("hello", now, never, time.Nanosecond, state.DefaultRepeatAfter)
st.Warnf("hello again")
- st.Prune(pruneWait, abortWait, 100)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
c.Assert(st.Change(chg1.ID()), Equals, chg1)
c.Assert(st.Change(chg2.ID()), IsNil)
@@ -793,6 +826,55 @@ func (ss *stateSuite) TestPrune(c *C) {
c.Check(st.AllWarnings(), HasLen, 1)
}
+func (ss *stateSuite) TestRegisterPendingChangeByAttr(c *C) {
+ st := state.New(&fakeStateBackend{})
+ st.Lock()
+ defer st.Unlock()
+
+ now := time.Now()
+ pruneWait := 1 * time.Hour
+ abortWait := 3 * time.Hour
+
+ unset := time.Time{}
+
+ t1 := st.NewTask("foo", "...")
+ t2 := st.NewTask("foo", "...")
+ t3 := st.NewTask("foo", "...")
+ t4 := st.NewTask("foo", "...")
+
+ chg1 := st.NewChange("abort", "...")
+ chg1.AddTask(t1)
+ chg1.AddTask(t2)
+ state.MockChangeTimes(chg1, now.Add(-abortWait), unset)
+
+ chg2 := st.NewChange("pending", "...")
+ chg2.AddTask(t3)
+ chg2.AddTask(t4)
+ state.MockChangeTimes(chg2, now.Add(-abortWait), unset)
+ chg2.Set("pending-flag", true)
+ t3.SetStatus(state.HoldStatus)
+
+ st.RegisterPendingChangeByAttr("pending-flag", func(chg *state.Change) bool {
+ c.Check(chg.ID(), Equals, chg2.ID())
+ return true
+ })
+
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
+
+ c.Assert(st.Change(chg1.ID()), Equals, chg1)
+ c.Assert(st.Change(chg2.ID()), Equals, chg2)
+ c.Assert(st.Task(t1.ID()), Equals, t1)
+ c.Assert(st.Task(t2.ID()), Equals, t2)
+ c.Assert(st.Task(t3.ID()), Equals, t3)
+ c.Assert(st.Task(t4.ID()), Equals, t4)
+
+ c.Assert(t1.Status(), Equals, state.HoldStatus)
+ c.Assert(t2.Status(), Equals, state.HoldStatus)
+ c.Assert(t3.Status(), Equals, state.HoldStatus)
+ c.Assert(t4.Status(), Equals, state.DoStatus)
+}
+
func (ss *stateSuite) TestPruneEmptyChange(c *C) {
// Empty changes are a bit special because they start out on Hold
// which is a Ready status, but the change itself is not considered Ready
@@ -807,9 +889,10 @@ func (ss *stateSuite) TestPruneEmptyChange(c *C) {
abortWait := 3 * time.Hour
chg := st.NewChange("abort", "...")
- state.FakeChangeTimes(chg, now.Add(-pruneWait), time.Time{})
+ state.MockChangeTimes(chg, now.Add(-pruneWait), time.Time{})
- st.Prune(pruneWait, abortWait, 100)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, pruneWait, abortWait, 100)
c.Assert(st.Change(chg.ID()), IsNil)
}
@@ -831,7 +914,7 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) {
t.SetStatus(state.DoneStatus)
when := time.Duration(i) * time.Second
- state.FakeChangeTimes(chg, now.Add(-when), now.Add(-when))
+ state.MockChangeTimes(chg, now.Add(-when), now.Add(-when))
}
c.Assert(st.Changes(), HasLen, 10)
@@ -844,13 +927,14 @@ func (ss *stateSuite) TestPruneMaxChangesHappy(c *C) {
// test that nothing is done when we are within pruneWait and
// maxReadyChanges
+ past := time.Now().AddDate(-1, 0, 0)
maxReadyChanges := 100
- st.Prune(pruneWait, abortWait, maxReadyChanges)
+ st.Prune(past, pruneWait, abortWait, maxReadyChanges)
c.Assert(st.Changes(), HasLen, 15)
// but with maxReadyChanges we remove the ready ones
maxReadyChanges = 5
- st.Prune(pruneWait, abortWait, maxReadyChanges)
+ st.Prune(past, pruneWait, abortWait, maxReadyChanges)
c.Assert(st.Changes(), HasLen, 10)
remaining := map[string]bool{}
for _, chg := range st.Changes() {
@@ -886,8 +970,9 @@ func (ss *stateSuite) TestPruneMaxChangesSomeNotReady(c *C) {
c.Assert(st.Changes(), HasLen, 10)
// nothing can be pruned
+ past := time.Now().AddDate(-1, 0, 0)
maxChanges := 5
- st.Prune(1*time.Hour, 3*time.Hour, maxChanges)
+ st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges)
c.Assert(st.Changes(), HasLen, 10)
}
@@ -905,10 +990,10 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) {
c.Assert(st.Changes(), HasLen, 10)
// one extra change that just now entered ready state
- chg := st.NewChange(fmt.Sprintf("chg99"), "so-ready")
+ chg := st.NewChange("chg99", "so-ready")
t := st.NewTask("foo", "so-ready")
when := 1 * time.Second
- state.FakeChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when))
+ state.MockChangeTimes(chg, time.Now().Add(-when), time.Now().Add(-when))
t.SetStatus(state.DoneStatus)
chg.AddTask(t)
@@ -916,11 +1001,45 @@ func (ss *stateSuite) TestPruneMaxChangesHonored(c *C) {
//
// this test we do not purge the freshly ready change
maxChanges := 10
- st.Prune(1*time.Hour, 3*time.Hour, maxChanges)
+ past := time.Now().AddDate(-1, 0, 0)
+ st.Prune(past, 1*time.Hour, 3*time.Hour, maxChanges)
c.Assert(st.Changes(), HasLen, 11)
}
-func (ss *stateSuite) TestReadStateInitsCache(c *C) {
+func (ss *stateSuite) TestPruneHonorsStartOperationTime(c *C) {
+ st := state.New(&fakeStateBackend{})
+ st.Lock()
+ defer st.Unlock()
+
+ now := time.Now()
+
+ startTime := 2 * time.Hour
+ spawnTime := 10 * time.Hour
+ pruneWait := 1 * time.Hour
+ abortWait := 3 * time.Hour
+
+ chg := st.NewChange("change", "...")
+ t := st.NewTask("foo", "")
+ chg.AddTask(t)
+ // change spawned 10h ago
+ state.MockChangeTimes(chg, now.Add(-spawnTime), time.Time{})
+
+ // start operation time is 2h ago, change is not aborted because
+ // it's less than abortWait limit.
+ opTime := now.Add(-startTime)
+ st.Prune(opTime, pruneWait, abortWait, 100)
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.DoStatus)
+
+ // start operation time is 9h ago, change is aborted.
+ startTime = 9 * time.Hour
+ opTime = time.Now().Add(-startTime)
+ st.Prune(opTime, pruneWait, abortWait, 100)
+ c.Assert(st.Changes(), HasLen, 1)
+ c.Check(chg.Status(), Equals, state.HoldStatus)
+}
+
+func (ss *stateSuite) TestReadStateInitsTransientMapFields(c *C) {
st, err := state.ReadState(nil, bytes.NewBufferString("{}"))
c.Assert(err, IsNil)
st.Lock()
@@ -928,4 +1047,195 @@ func (ss *stateSuite) TestReadStateInitsCache(c *C) {
st.Cache("key", "value")
c.Assert(st.Cached("key"), Equals, "value")
+ st.RegisterPendingChangeByAttr("attr", func(*state.Change) bool { return false })
+}
+
+func (ss *stateSuite) TestTimingsSupport(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var tims []int
+
+ err := st.GetMaybeTimings(&tims)
+ c.Assert(err, IsNil)
+ c.Check(tims, IsNil)
+
+ st.SaveTimings([]int{1, 2, 3})
+
+ err = st.GetMaybeTimings(&tims)
+ c.Assert(err, IsNil)
+ c.Check(tims, DeepEquals, []int{1, 2, 3})
+}
+
+func (ss *stateSuite) TestNoStateErrorIs(c *C) {
+ err := &state.NoStateError{Key: "foo"}
+ c.Assert(err, testutil.ErrorIs, &state.NoStateError{})
+ c.Assert(err, testutil.ErrorIs, &state.NoStateError{Key: "bar"})
+ c.Assert(err, testutil.ErrorIs, state.ErrNoState)
+}
+
+func (ss *stateSuite) TestNoStateErrorString(c *C) {
+ err := &state.NoStateError{}
+ c.Assert(err.Error(), Equals, `no state entry for key`)
+ err.Key = "foo"
+ c.Assert(err.Error(), Equals, `no state entry for key "foo"`)
+}
+
+type taskAndStatus struct {
+ t *state.Task
+ old, new state.Status
+}
+
+func (ss *stateSuite) TestTaskChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var taskObservedChanges []taskAndStatus
+ oId := st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) {
+ taskObservedChanges = append(taskObservedChanges, taskAndStatus{
+ t: t,
+ old: old,
+ new: new,
+ })
+ })
+
+ t1 := st.NewTask("foo", "...")
+
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task status to identical status, we don't want
+ // task events when task don't actually change status.
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task to done.
+ t1.SetStatus(state.DoneStatus)
+
+ // Unregister us, and make sure we do not receive more events.
+ st.RemoveTaskStatusChangedHandler(oId)
+
+ // must not appear in list.
+ t1.SetStatus(state.DoingStatus)
+
+ c.Check(taskObservedChanges, DeepEquals, []taskAndStatus{
+ {
+ t: t1,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ t: t1,
+ old: state.DoingStatus,
+ new: state.DoneStatus,
+ },
+ })
+}
+
+type changeAndStatus struct {
+ chg *state.Change
+ old, new state.Status
+}
+
+func (ss *stateSuite) TestChangeChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var observedChanges []changeAndStatus
+ oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) {
+ observedChanges = append(observedChanges, changeAndStatus{
+ chg: chg,
+ old: old,
+ new: new,
+ })
+ })
+
+ chg := st.NewChange("test-chg", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task status to identical status, we don't want
+ // change events when changes don't actually change status.
+ t1.SetStatus(state.DoingStatus)
+
+ // Set task to waiting
+ t1.SetToWait(state.DoneStatus)
+
+ // Unregister us, and make sure we do not receive more events.
+ st.RemoveChangeStatusChangedHandler(oId)
+
+ // must not appear in list.
+ t1.SetStatus(state.DoneStatus)
+
+ c.Check(observedChanges, DeepEquals, []changeAndStatus{
+ {
+ chg: chg,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ chg: chg,
+ old: state.DoingStatus,
+ new: state.WaitStatus,
+ },
+ })
+}
+
+func (ss *stateSuite) TestChangeSetStatusChangedHandler(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ var observedChanges []changeAndStatus
+ oId := st.AddChangeStatusChangedHandler(func(chg *state.Change, old, new state.Status) {
+ observedChanges = append(observedChanges, changeAndStatus{
+ chg: chg,
+ old: old,
+ new: new,
+ })
+ })
+
+ chg := st.NewChange("test-chg", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+
+ t1.SetStatus(state.DoingStatus)
+
+ // We have a single task in Doing, now we manipulate the status
+ // of the change to ensure we are receiving correct events
+ chg.SetStatus(state.WaitStatus)
+
+ // Change to a new status
+ chg.SetStatus(state.ErrorStatus)
+
+ // Now return the status back to Default, which should result
+ // in the change reporting Doing
+ chg.SetStatus(state.DefaultStatus)
+ st.RemoveChangeStatusChangedHandler(oId)
+
+ c.Check(observedChanges, DeepEquals, []changeAndStatus{
+ {
+ chg: chg,
+ old: state.DefaultStatus,
+ new: state.DoingStatus,
+ },
+ {
+ chg: chg,
+ old: state.DoingStatus,
+ new: state.WaitStatus,
+ },
+ {
+ chg: chg,
+ old: state.WaitStatus,
+ new: state.ErrorStatus,
+ },
+ {
+ chg: chg,
+ old: state.ErrorStatus,
+ new: state.DoingStatus,
+ },
+ })
}
diff --git a/internals/overlord/state/task.go b/internals/overlord/state/task.go
index a3e2a486..fb5ec826 100644
--- a/internals/overlord/state/task.go
+++ b/internals/overlord/state/task.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -38,19 +38,22 @@ type progress struct {
//
// See Change for more details.
type Task struct {
- state *State
- id string
- kind string
- summary string
- status Status
- clean bool
- progress *progress
- data customData
- waitTasks []string
- haltTasks []string
- lanes []int
- log []string
- change string
+ state *State
+ id string
+ kind string
+ summary string
+ status Status
+ // waitedStatus is the Status that should be used instead of
+ // WaitStatus once the wait is complete (i.e post reboot).
+ waitedStatus Status
+ clean bool
+ progress *progress
+ data customData
+ waitTasks []string
+ haltTasks []string
+ lanes []int
+ log []string
+ change string
spawnTime time.Time
readyTime time.Time
@@ -77,18 +80,19 @@ func newTask(state *State, id, kind, summary string) *Task {
}
type marshalledTask struct {
- ID string `json:"id"`
- Kind string `json:"kind"`
- Summary string `json:"summary"`
- Status Status `json:"status"`
- Clean bool `json:"clean,omitempty"`
- Progress *progress `json:"progress,omitempty"`
- Data map[string]*json.RawMessage `json:"data,omitempty"`
- WaitTasks []string `json:"wait-tasks,omitempty"`
- HaltTasks []string `json:"halt-tasks,omitempty"`
- Lanes []int `json:"lanes,omitempty"`
- Log []string `json:"log,omitempty"`
- Change string `json:"change"`
+ ID string `json:"id"`
+ Kind string `json:"kind"`
+ Summary string `json:"summary"`
+ Status Status `json:"status"`
+ WaitedStatus Status `json:"waited-status"`
+ Clean bool `json:"clean,omitempty"`
+ Progress *progress `json:"progress,omitempty"`
+ Data map[string]*json.RawMessage `json:"data,omitempty"`
+ WaitTasks []string `json:"wait-tasks,omitempty"`
+ HaltTasks []string `json:"halt-tasks,omitempty"`
+ Lanes []int `json:"lanes,omitempty"`
+ Log []string `json:"log,omitempty"`
+ Change string `json:"change"`
SpawnTime time.Time `json:"spawn-time"`
ReadyTime *time.Time `json:"ready-time,omitempty"`
@@ -111,18 +115,19 @@ func (t *Task) MarshalJSON() ([]byte, error) {
atTime = &t.atTime
}
return json.Marshal(marshalledTask{
- ID: t.id,
- Kind: t.kind,
- Summary: t.summary,
- Status: t.status,
- Clean: t.clean,
- Progress: t.progress,
- Data: t.data,
- WaitTasks: t.waitTasks,
- HaltTasks: t.haltTasks,
- Lanes: t.lanes,
- Log: t.log,
- Change: t.change,
+ ID: t.id,
+ Kind: t.kind,
+ Summary: t.summary,
+ Status: t.status,
+ WaitedStatus: t.waitedStatus,
+ Clean: t.clean,
+ Progress: t.progress,
+ Data: t.data,
+ WaitTasks: t.waitTasks,
+ HaltTasks: t.haltTasks,
+ Lanes: t.lanes,
+ Log: t.log,
+ Change: t.change,
SpawnTime: t.spawnTime,
ReadyTime: readyTime,
@@ -148,6 +153,13 @@ func (t *Task) UnmarshalJSON(data []byte) error {
t.kind = unmarshalled.Kind
t.summary = unmarshalled.Summary
t.status = unmarshalled.Status
+ t.waitedStatus = unmarshalled.WaitedStatus
+ if t.waitedStatus == DefaultStatus {
+ // For backwards-compatibility, default the waitStatus, which is
+ // the result status after a wait, to DoneStatus to keep any previous
+ // behaviour before any upgrade.
+ t.waitedStatus = DoneStatus
+ }
t.clean = unmarshalled.Clean
t.progress = unmarshalled.Progress
custData := unmarshalled.Data
@@ -188,6 +200,40 @@ func (t *Task) Summary() string {
}
// Status returns the current task status.
+//
+// Possible state transitions:
+//
+// /----aborting lane--Do
+// | |
+// V V
+// Hold Doing-->Wait
+// ^ / | \
+// | abort / V V
+// no undo / Done Error
+// | V |
+// \----------Abort aborting lane
+// / | |
+// | finished or |
+// running not running |
+// V \------->|
+// kill goroutine |
+// | V
+// / \ ----->Undo
+// / no error / |
+// | from goroutine |
+// error |
+// from goroutine |
+// | V
+// | Undoing-->Wait
+// V | \
+// Error V V
+// Undone Error
+//
+// Do -> Doing -> Done is the direct succcess scenario.
+//
+// Wait can transition to its waited status,
+// usually Done|Undone or back to Doing.
+// See Wait struct, SetToWait and WaitedStatus.
func (t *Task) Status() Status {
t.state.reading()
if t.status == DefaultStatus {
@@ -196,10 +242,10 @@ func (t *Task) Status() Status {
return t.status
}
-// SetStatus sets the task status, overriding the default behavior (see Status method).
-func (t *Task) SetStatus(new Status) {
- t.state.writing()
- old := t.status
+func (t *Task) changeStatus(old, new Status) {
+ if old == new {
+ return
+ }
t.status = new
if !old.Ready() && new.Ready() {
t.readyTime = timeNow()
@@ -208,6 +254,55 @@ func (t *Task) SetStatus(new Status) {
if chg != nil {
chg.taskStatusChanged(t, old, new)
}
+ t.state.notifyTaskStatusChangedHandlers(t, old, new)
+}
+
+// SetStatus sets the task status, overriding the default behavior (see Status method).
+func (t *Task) SetStatus(new Status) {
+ if new == WaitStatus {
+ panic("Task.SetStatus() called with WaitStatus, which is not allowed. Use SetToWait() instead")
+ }
+
+ t.state.writing()
+ old := t.status
+ if new == DoneStatus && old == AbortStatus {
+ // if the task is in AbortStatus (because some other task ran
+ // in parallel and had an error so the change is aborted) and
+ // DoneStatus was requested (which can happen if the
+ // task handler sets its status explicitly) then keep it at
+ // aborted so it can transition to Undo.
+ return
+ }
+ t.changeStatus(old, new)
+}
+
+// SetToWait puts the task into WaitStatus, and sets the status the task should be restored
+// to after the SetToWait.
+func (t *Task) SetToWait(resultStatus Status) {
+ switch resultStatus {
+ case DefaultStatus, WaitStatus:
+ panic("Task.SetToWait() cannot be invoked with either of DefaultStatus or WaitStatus")
+ }
+
+ t.state.writing()
+ old := t.status
+ if old == AbortStatus {
+ // if the task is in AbortStatus (because some other task ran
+ // in parallel and had an error so the change is aborted) and
+ // WaitStatus was requested (which can happen if the
+ // task handler sets its status explicitly) then keep it at
+ // aborted so it can transition to Undo.
+ return
+ }
+ t.waitedStatus = resultStatus
+ t.changeStatus(old, WaitStatus)
+}
+
+// WaitedStatus returns the status the Task should return to once the current WaitStatus
+// has been resolved.
+func (t *Task) WaitedStatus() Status {
+ t.state.reading()
+ return t.waitedStatus
}
// IsClean returns whether the task has been cleaned. See SetClean.
@@ -320,7 +415,7 @@ const (
var timeNow = time.Now
-func FakeTime(now time.Time) (restore func()) {
+func MockTime(now time.Time) (restore func()) {
timeNow = func() time.Time { return now }
return func() { timeNow = time.Now }
}
@@ -332,7 +427,7 @@ func (t *Task) addLog(kind, format string, args []interface{}) {
}
tstr := timeNow().Format(time.RFC3339)
- msg := fmt.Sprintf(tstr+" "+kind+" "+format, args...)
+ msg := tstr + " " + kind + " " + fmt.Sprintf(format, args...)
t.log = append(t.log, msg)
logger.Debugf(msg)
}
@@ -468,7 +563,7 @@ func (t *Task) At(when time.Time) {
// TaskSetEdge designates tasks inside a TaskSet for outside reference.
//
// This is useful to give tasks inside TaskSets a special meaning. It
-// is used to mark e.g. the last task in a task set.
+// is used to mark e.g. the last task used for downloading a snap.
type TaskSetEdge string
// A TaskSet holds a set of tasks.
@@ -484,10 +579,16 @@ func NewTaskSet(tasks ...*Task) *TaskSet {
return &TaskSet{tasks, nil}
}
-// Edge returns the task marked with the given edge name.
+// MaybeEdge returns the task marked with the given edge name or nil if no such
+// task exists.
+func (ts TaskSet) MaybeEdge(e TaskSetEdge) *Task {
+ return ts.edges[e]
+}
+
+// Edge returns the task marked with the given edge name or an error.
func (ts TaskSet) Edge(e TaskSetEdge) (*Task, error) {
- t, ok := ts.edges[e]
- if !ok {
+ t := ts.MaybeEdge(e)
+ if t == nil {
return nil, fmt.Errorf("internal error: missing %q edge in task set", e)
}
return t, nil
@@ -509,7 +610,7 @@ func (ts *TaskSet) WaitAll(anotherTs *TaskSet) {
}
}
-// AddTask adds the the task to the task set.
+// AddTask adds the task to the task set.
func (ts *TaskSet) AddTask(task *Task) {
for _, t := range ts.tasks {
if t == task {
@@ -522,6 +623,9 @@ func (ts *TaskSet) AddTask(task *Task) {
// MarkEdge marks the given task as a specific edge. Any pre-existing
// edge mark will be overridden.
func (ts *TaskSet) MarkEdge(task *Task, edge TaskSetEdge) {
+ if task == nil {
+ panic(fmt.Sprintf("cannot set edge %q with nil task", edge))
+ }
if ts.edges == nil {
ts.edges = make(map[TaskSetEdge]*Task)
}
diff --git a/internals/overlord/state/task_test.go b/internals/overlord/state/task_test.go
index 9c073d80..d1e26e54 100644
--- a/internals/overlord/state/task_test.go
+++ b/internals/overlord/state/task_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -127,7 +127,7 @@ func (ts *taskSuite) TestClear(c *C) {
t.Clear("a")
- c.Check(t.Get("a", &v), Equals, state.ErrNoState)
+ c.Check(t.Get("a", &v), testutil.ErrorIs, state.ErrNoState)
}
func (ts *taskSuite) TestStatusAndSetStatus(c *C) {
@@ -144,6 +144,60 @@ func (ts *taskSuite) TestStatusAndSetStatus(c *C) {
c.Check(t.Status(), Equals, state.DoneStatus)
}
+func (ts *taskSuite) TestSetDoneAfterAbortNoop(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetStatus(state.AbortStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ t.SetStatus(state.DoneStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+}
+
+func (ts *taskSuite) TestSetWaitAfterAbortNoop(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetStatus(state.AbortStatus)
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ t.SetToWait(state.DoneStatus) // noop
+ c.Check(t.Status(), Equals, state.AbortStatus)
+ c.Check(t.WaitedStatus(), Equals, state.DefaultStatus)
+}
+
+func (ts *taskSuite) TestSetWait(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t := st.NewTask("download", "1...")
+ t.SetToWait(state.DoneStatus)
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, state.DoneStatus)
+ t.SetToWait(state.UndoStatus)
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, state.UndoStatus)
+}
+
+func (ts *taskSuite) TestTaskMarshalsWaitStatus(c *C) {
+ st := state.New(nil)
+ st.Lock()
+ defer st.Unlock()
+
+ t1 := st.NewTask("download", "1...")
+ t1.SetToWait(state.UndoStatus)
+
+ d, err := t1.MarshalJSON()
+ c.Assert(err, IsNil)
+
+ needle := fmt.Sprintf(`"waited-status":%d`, t1.WaitedStatus())
+ c.Assert(string(d), testutil.Contains, needle)
+}
+
func (ts *taskSuite) TestIsCleanAndSetClean(c *C) {
st := state.New(nil)
st.Lock()
@@ -174,9 +228,9 @@ func (ts *taskSuite) TestProgressAndSetProgress(c *C) {
t := st.NewTask("download", "1...")
- t.SetProgress("foo", 2, 99)
+ t.SetProgress("snap", 2, 99)
label, cur, tot := t.Progress()
- c.Check(label, Equals, "foo")
+ c.Check(label, Equals, "snap")
c.Check(cur, Equals, 2)
c.Check(tot, Equals, 99)
@@ -305,7 +359,7 @@ func (ts *taskSuite) TestAt(c *C) {
t := st.NewTask("download", "1...")
now := time.Now()
- restore := state.FakeTime(now)
+ restore := state.MockTime(now)
defer restore()
when := now.Add(10 * time.Second)
t.At(when)
@@ -561,6 +615,10 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) {
// edges are just typed strings
edge1 := state.TaskSetEdge("on-edge")
edge2 := state.TaskSetEdge("eddie")
+ edge3 := state.TaskSetEdge("not-found")
+
+ // nil task causes panic
+ c.Check(func() { ts.MarkEdge(nil, edge1) }, PanicMatches, `cannot set edge "on-edge" with nil task`)
// no edge marked yet
t, err := ts.Edge(edge1)
@@ -590,6 +648,12 @@ func (cs *taskSuite) TestTaskSetEdge(c *C) {
t, err = ts.Edge(edge1)
c.Assert(t, Equals, t3)
c.Assert(err, IsNil)
+
+ // it is possible to check if edge exists without failing
+ t = ts.MaybeEdge(edge1)
+ c.Assert(t, Equals, t3)
+ t = ts.MaybeEdge(edge3)
+ c.Assert(t, IsNil)
}
func (cs *taskSuite) TestTaskAddAllWithEdges(c *C) {
diff --git a/internals/overlord/state/taskrunner.go b/internals/overlord/state/taskrunner.go
index 5d9ff244..c6a2b835 100644
--- a/internals/overlord/state/taskrunner.go
+++ b/internals/overlord/state/taskrunner.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -46,6 +46,23 @@ func (r *Retry) Error() string {
return "task should be retried"
}
+// Wait is returned from a handler to signal that the task cannot
+// proceed at the moment maybe because some manual action from the
+// user required at this point or because of errors. The task
+// will be set to WaitStatus, and it's wait complete status will be
+// set to WaitedStatus.
+type Wait struct {
+ Reason string
+ // If not explicitly set, then WaitedStatus will default to
+ // DoneStatus, meaning that the task will be set to DoneStatus
+ // after the wait has resolved.
+ WaitedStatus Status
+}
+
+func (r *Wait) Error() string {
+ return "task set to wait, manual action required"
+}
+
type blockedFunc func(t *Task, running []*Task) bool
// TaskRunner controls the running of goroutines to execute known task kinds.
@@ -62,6 +79,9 @@ type TaskRunner struct {
blocked []blockedFunc
someBlocked bool
+ // optional callback executed on task errors
+ taskErrorCallback func(err error)
+
// go-routines lifecycle
tombs map[string]*tomb.Tomb
}
@@ -85,6 +105,11 @@ func NewTaskRunner(s *State) *TaskRunner {
}
}
+// OnTaskError sets an error callback executed when any task errors out.
+func (r *TaskRunner) OnTaskError(f func(err error)) {
+ r.taskErrorCallback = f
+}
+
// AddHandler registers the functions to concurrently call for doing and
// undoing tasks of the given kind. The undo handler may be nil.
func (r *TaskRunner) AddHandler(kind string, do, undo HandlerFunc) {
@@ -214,7 +239,7 @@ func (r *TaskRunner) run(t *Task) {
switch err.(type) {
case nil:
// we are ok
- case *Retry:
+ case *Retry, *Wait:
// preserve
default:
if r.stopped {
@@ -227,13 +252,24 @@ func (r *TaskRunner) run(t *Task) {
switch x := err.(type) {
case *Retry:
// Handler asked to be called again later.
- // TODO Allow postponing retries past the next Ensure.
if t.Status() == AbortStatus {
// Would work without it but might take two ensures.
r.tryUndo(t)
} else if x.After != 0 {
t.At(timeNow().Add(x.After))
}
+ case *Wait:
+ if t.Status() == AbortStatus {
+ // Would work without it but might take two ensures.
+ r.tryUndo(t)
+ } else {
+ // Default to DoneStatus if no status is set in Wait
+ waitedStatus := x.WaitedStatus
+ if waitedStatus == DefaultStatus {
+ waitedStatus = DoneStatus
+ }
+ t.SetToWait(waitedStatus)
+ }
case nil:
var next []*Task
switch t.Status() {
@@ -259,6 +295,11 @@ func (r *TaskRunner) run(t *Task) {
r.abortLanes(t.Change(), t.Lanes())
t.SetStatus(ErrorStatus)
t.Errorf("%s", err)
+ // ensure the error is available in the global log too
+ logger.Noticef("[change %s %q task] failed: %v", t.Change().ID(), t.Summary(), err)
+ if r.taskErrorCallback != nil {
+ r.taskErrorCallback(err)
+ }
}
return nil
@@ -388,6 +429,10 @@ ConsiderTasks:
}
continue
}
+ if status == WaitStatus {
+ // nothing more to run
+ continue
+ }
if mustWait(t) {
// Dependencies still unhandled.
diff --git a/internals/overlord/state/taskrunner_test.go b/internals/overlord/state/taskrunner_test.go
index 38c5685a..2b9de82a 100644
--- a/internals/overlord/state/taskrunner_test.go
+++ b/internals/overlord/state/taskrunner_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2016 Canonical Ltd
+ * Copyright (C) 2016-2022 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -31,11 +31,14 @@ import (
. "gopkg.in/check.v1"
"gopkg.in/tomb.v2"
- "github.com/canonical/pebble/internals/overlord/restart"
+ "github.com/canonical/pebble/internals/logger"
"github.com/canonical/pebble/internals/overlord/state"
+ "github.com/canonical/pebble/internals/testutil"
)
-type taskRunnerSuite struct{}
+type taskRunnerSuite struct {
+ testutil.BaseTest
+}
var _ = Suite(&taskRunnerSuite{})
@@ -58,8 +61,6 @@ func (b *stateBackend) EnsureBefore(d time.Duration) {
}
}
-func (b *stateBackend) RequestRestart(t restart.RestartType) {}
-
func ensureChange(c *C, r *state.TaskRunner, sb *stateBackend, chg *state.Change) {
for i := 0; i < 20; i++ {
sb.ensureBefore = time.Hour
@@ -142,6 +143,10 @@ var sequenceTests = []struct{ setup, result string }{{
result: "t31:undo t32:do t32:do-error t21:undo",
}}
+func (ts *taskRunnerSuite) SetUpTest(c *C) {
+ ts.BaseTest.SetUpTest(c)
+}
+
func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -185,11 +190,12 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
r.AddHandler("do", fn("do"), nil)
r.AddHandler("do-undo", fn("do"), fn("undo"))
+ past := time.Now().AddDate(-1, 0, 0)
for _, test := range sequenceTests {
st.Lock()
// Delete previous changes.
- st.Prune(1, 1, 1)
+ st.Prune(past, 1, 1, 1)
chg := st.NewChange("install", "...")
tasks := make(map[string]*state.Task)
@@ -342,6 +348,206 @@ func (ts *taskRunnerSuite) TestSequenceTests(c *C) {
}
}
+func (ts *taskRunnerSuite) TestAbortAcrossLanesDescendantTask(c *C) {
+
+ // ()
+ // t11(1) -> t12(1) => t15(1)
+ // \ /
+ // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2)
+ // / \
+ // t21(2) -> t22(2) => t25(2)
+ //
+ names := strings.Fields("t11 t12 t13 t14 t15 t21 t22 t23 t24 t25")
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+ tasks["t12"].WaitFor(tasks["t11"])
+ tasks["t13"].WaitFor(tasks["t12"])
+ tasks["t14"].WaitFor(tasks["t13"])
+ tasks["t15"].WaitFor(tasks["t14"])
+ for lane, names := range map[int][]string{
+ 1: {"t11", "t12", "t13", "t14", "t15", "t23", "t24"},
+ 2: {"t21", "t22", "t23", "t24", "t25", "t13", "t14"},
+ } {
+ for _, name := range names {
+ tasks[name].JoinLane(lane)
+ }
+ }
+
+ tasks["t22"].WaitFor(tasks["t21"])
+ tasks["t23"].WaitFor(tasks["t22"])
+ tasks["t24"].WaitFor(tasks["t23"])
+ tasks["t25"].WaitFor(tasks["t24"])
+
+ tasks["t13"].WaitFor(tasks["t22"])
+ tasks["t15"].WaitFor(tasks["t24"])
+ tasks["t23"].WaitFor(tasks["t14"])
+
+ ch := make(chan string, 256)
+ do := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("do %q", task.Summary())
+ label := task.Summary()
+ if label == "t15" {
+ ch <- "t15:error"
+ return fmt.Errorf("mock error")
+ }
+ ch <- fmt.Sprintf("%s:do", label)
+ return nil
+ }
+ undo := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("undo %q", task.Summary())
+ label := task.Summary()
+ ch <- fmt.Sprintf("%s:undo", label)
+ return nil
+ }
+ r.AddHandler("do", do, undo)
+
+ c.Logf("-----")
+
+ st.Unlock()
+ ensureChange(c, r, sb, chg)
+ st.Lock()
+ close(ch)
+ var sequence []string
+ for event := range ch {
+ sequence = append(sequence, event)
+ }
+ for _, name := range names {
+ task := tasks[name]
+ c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status())
+ }
+ c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{
+ "t11:do", "t12:do",
+ "t21:do", "t22:do",
+ })
+ c.Assert(sequence[4:8], DeepEquals, []string{
+ "t13:do", "t14:do", "t23:do", "t24:do",
+ })
+ c.Assert(sequence[8:10], testutil.DeepUnsortedMatches, []string{
+ "t25:do",
+ "t15:error",
+ })
+ c.Assert(sequence[10:11], testutil.DeepUnsortedMatches, []string{
+ "t25:undo",
+ })
+ c.Assert(sequence[11:15], DeepEquals, []string{
+ "t24:undo", "t23:undo", "t14:undo", "t13:undo",
+ })
+ c.Assert(sequence[15:19], testutil.DeepUnsortedMatches, []string{
+ "t21:undo", "t22:undo",
+ "t12:undo", "t11:undo",
+ })
+}
+
+func (ts *taskRunnerSuite) TestAbortAcrossLanesStriclyOrderedTasks(c *C) {
+
+ // ()
+ // t11(1) -> t12(1)
+ // \
+ // => t13(1,2) => t14(1,2) => t23(1,2) => t24(1,2)
+ // /
+ // t21(2) -> t22(2)
+ //
+ names := strings.Fields("t11 t12 t13 t14 t21 t22 t23 t24")
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Assert(len(st.Tasks()), Equals, 0)
+
+ chg := st.NewChange("install", "...")
+ tasks := make(map[string]*state.Task)
+ for _, name := range names {
+ tasks[name] = st.NewTask("do", name)
+ chg.AddTask(tasks[name])
+ }
+ tasks["t12"].WaitFor(tasks["t11"])
+ tasks["t13"].WaitFor(tasks["t12"])
+ tasks["t14"].WaitFor(tasks["t13"])
+ for lane, names := range map[int][]string{
+ 1: {"t11", "t12", "t13", "t14", "t23", "t24"},
+ 2: {"t21", "t22", "t23", "t24", "t13", "t14"},
+ } {
+ for _, name := range names {
+ tasks[name].JoinLane(lane)
+ }
+ }
+
+ tasks["t22"].WaitFor(tasks["t21"])
+ tasks["t23"].WaitFor(tasks["t22"])
+ tasks["t24"].WaitFor(tasks["t23"])
+
+ tasks["t13"].WaitFor(tasks["t22"])
+ tasks["t23"].WaitFor(tasks["t14"])
+
+ ch := make(chan string, 256)
+ do := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("do %q", task.Summary())
+ label := task.Summary()
+ if label == "t24" {
+ ch <- "t24:error"
+ return fmt.Errorf("mock error")
+ }
+ ch <- fmt.Sprintf("%s:do", label)
+ return nil
+ }
+ undo := func(task *state.Task, tomb *tomb.Tomb) error {
+ c.Logf("undo %q", task.Summary())
+ label := task.Summary()
+ ch <- fmt.Sprintf("%s:undo", label)
+ return nil
+ }
+ r.AddHandler("do", do, undo)
+
+ c.Logf("-----")
+
+ st.Unlock()
+ ensureChange(c, r, sb, chg)
+ st.Lock()
+ close(ch)
+ var sequence []string
+ for event := range ch {
+ sequence = append(sequence, event)
+ }
+ for _, name := range names {
+ task := tasks[name]
+ c.Logf("%5s %5s lanes: %v status: %v", task.ID(), task.Summary(), task.Lanes(), task.Status())
+ }
+ c.Assert(sequence[:4], testutil.DeepUnsortedMatches, []string{
+ "t11:do", "t12:do",
+ "t21:do", "t22:do",
+ })
+ c.Assert(sequence[4:8], DeepEquals, []string{
+ "t13:do", "t14:do", "t23:do", "t24:error",
+ })
+ c.Assert(sequence[8:11], DeepEquals, []string{
+ "t23:undo", "t14:undo", "t13:undo",
+ })
+ c.Assert(sequence[11:], testutil.DeepUnsortedMatches, []string{
+ "t21:undo", "t22:undo",
+ "t12:undo", "t11:undo",
+ })
+}
+
func (ts *taskRunnerSuite) TestExternalAbort(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -372,6 +578,101 @@ func (ts *taskRunnerSuite) TestExternalAbort(c *C) {
ensureChange(c, r, sb, chg)
}
+func (ts *taskRunnerSuite) TestUndoSingleLane(c *C) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ r.AddHandler("noop", func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ }, func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ })
+
+ r.AddHandler("noop-slow", func(t *state.Task, tb *tomb.Tomb) error {
+ time.Sleep(10 * time.Millisecond)
+ t.State().Lock()
+ defer t.State().Unlock()
+ // critical
+ t.SetStatus(state.DoneStatus)
+ return nil
+ }, func(t *state.Task, tb *tomb.Tomb) error {
+ return nil
+ })
+
+ r.AddHandler("fail", func(t *state.Task, tb *tomb.Tomb) error {
+ return fmt.Errorf("fail")
+ }, nil)
+
+ st.Lock()
+
+ lane := st.NewLane()
+ chg := st.NewChange("install", "...")
+
+ // first taskset
+ var prev *state.Task
+ for i := 0; i < 10; i++ {
+ t := st.NewTask("noop-slow", "...")
+ if prev != nil {
+ t.WaitFor(prev)
+ }
+ chg.AddTask(t)
+ t.JoinLane(lane)
+
+ prev = t
+ }
+
+ // second taskset with a failing task that triggers undo of the change
+ prev = nil
+ for i := 0; i < 10; i++ {
+ t := st.NewTask("noop", "...")
+ if prev != nil {
+ t.WaitFor(prev)
+ }
+ chg.AddTask(t)
+ t.JoinLane(lane)
+ prev = t
+ }
+
+ // error trigger
+ t := st.NewTask("fail", "...")
+ t.WaitFor(prev)
+ chg.AddTask(t)
+ t.JoinLane(lane)
+
+ st.Unlock()
+
+ var done bool
+ for !done {
+ c.Assert(r.Ensure(), Equals, nil)
+ st.Lock()
+ done = chg.IsReady() && chg.IsClean()
+ st.Unlock()
+ }
+
+ st.Lock()
+ defer st.Unlock()
+
+ // make sure all tasks are either undone or on hold (except for "fail" task which
+ // is in error).
+ for _, t := range st.Tasks() {
+ switch t.Kind() {
+ case "fail":
+ c.Assert(t.Status(), Equals, state.ErrorStatus)
+ case "noop", "noop-slow":
+ if t.Status() != state.UndoneStatus && t.Status() != state.HoldStatus {
+ for _, tsk := range st.Tasks() {
+ fmt.Printf("%s -> %s\n", tsk.Kind(), tsk.Status())
+ }
+ c.Fatalf("unexpected status: %s", t.Status())
+ }
+ default:
+ c.Fatalf("unexpected kind: %s", t.Kind())
+ }
+ }
+}
+
func (ts *taskRunnerSuite) TestStopHandlerJustFinishing(c *C) {
sb := &stateBackend{}
st := state.New(sb)
@@ -508,6 +809,55 @@ func (ts *taskRunnerSuite) TestStopAskForRetry(c *C) {
c.Check(t.AtTime().IsZero(), Equals, false)
}
+func (ts *taskRunnerSuite) testTaskReturningWait(c *C, waitedStatus, expectedStatus state.Status) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+ defer r.Stop()
+
+ r.AddHandler("ask-for-wait", func(t *state.Task, tb *tomb.Tomb) error {
+ // ask for wait
+ return &state.Wait{WaitedStatus: waitedStatus}
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "...")
+ t := st.NewTask("ask-for-wait", "...")
+ chg.AddTask(t)
+ st.Unlock()
+
+ r.Ensure()
+ // wait for handler to finish
+ r.Wait()
+
+ st.Lock()
+ defer st.Unlock()
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(t.WaitedStatus(), Equals, expectedStatus)
+ c.Check(chg.Status().Ready(), Equals, false)
+
+ st.Unlock()
+ defer st.Lock()
+ // does nothing
+ r.Ensure()
+
+ // state is unchanged
+ st.Lock()
+ defer st.Unlock()
+ c.Check(t.Status(), Equals, state.WaitStatus)
+ c.Check(chg.Status().Ready(), Equals, false)
+}
+
+func (ts *taskRunnerSuite) TestTaskReturningWaitNormal(c *C) {
+ ts.testTaskReturningWait(c, state.UndoneStatus, state.UndoneStatus)
+}
+
+func (ts *taskRunnerSuite) TestTaskReturningWaitDefaultStatus(c *C) {
+ // If no state was set (DefaultStatus), then it should default to
+ // DoneStatus instead.
+ ts.testTaskReturningWait(c, state.DefaultStatus, state.DoneStatus)
+}
+
func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
ensureBeforeTick := make(chan bool, 1)
sb := &stateBackend{
@@ -536,7 +886,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
st.Unlock()
tock := time.Now()
- restore := state.FakeTime(tock)
+ restore := state.MockTime(tock)
defer restore()
r.Ensure() // will run and be rescheduled in a minute
select {
@@ -554,7 +904,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
schedule := t.AtTime()
c.Check(schedule.IsZero(), Equals, false)
- state.FakeTime(tock.Add(5 * time.Second))
+ state.MockTime(tock.Add(5 * time.Second))
sb.ensureBefore = time.Hour
st.Unlock()
r.Ensure() // too soon
@@ -565,7 +915,7 @@ func (ts *taskRunnerSuite) TestRetryAfterDuration(c *C) {
c.Check(sb.ensureBefore, Equals, 55*time.Second)
c.Check(t.AtTime().Equal(schedule), Equals, true)
- state.FakeTime(schedule)
+ state.MockTime(schedule)
sb.ensureBefore = time.Hour
st.Unlock()
r.Ensure() // time to run again
@@ -823,7 +1173,7 @@ func (ts *taskRunnerSuite) TestUndoSequence(c *C) {
terr.WaitFor(prev)
chg.AddTask(terr)
- c.Check(chg.Tasks(), HasLen, 9) // sanity check
+ c.Check(chg.Tasks(), HasLen, 9) // validity check
st.Unlock()
@@ -910,3 +1260,71 @@ func (ts *taskRunnerSuite) TestCleanup(c *C) {
c.Assert(chgIsClean(), Equals, true)
c.Assert(called, Equals, 2)
}
+
+func (ts *taskRunnerSuite) TestErrorCallbackCalledOnError(c *C) {
+ logbuf, restore := logger.MockLogger("PREFIX: ")
+ defer restore()
+
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+
+ var called bool
+ r.OnTaskError(func(err error) {
+ called = true
+ })
+
+ r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error {
+ return fmt.Errorf("handler error for %q", t.Kind())
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "change summary")
+ t1 := st.NewTask("foo", "task summary")
+ chg.AddTask(t1)
+ st.Unlock()
+
+ // Mark tasks as done.
+ ensureChange(c, r, sb, chg)
+ r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(t1.Status(), Equals, state.ErrorStatus)
+ c.Check(strings.Join(t1.Log(), ""), Matches, `.*handler error for "foo"`)
+ c.Check(called, Equals, true)
+
+ c.Check(logbuf.String(), Matches, `(?m).*: \[change 1 "task summary" task\] failed: handler error for "foo".*`)
+}
+
+func (ts *taskRunnerSuite) TestErrorCallbackNotCalled(c *C) {
+ sb := &stateBackend{}
+ st := state.New(sb)
+ r := state.NewTaskRunner(st)
+
+ var called bool
+ r.OnTaskError(func(err error) {
+ called = true
+ })
+
+ r.AddHandler("foo", func(t *state.Task, tomb *tomb.Tomb) error {
+ return nil
+ }, nil)
+
+ st.Lock()
+ chg := st.NewChange("install", "...")
+ t1 := st.NewTask("foo", "...")
+ chg.AddTask(t1)
+ st.Unlock()
+
+ // Mark tasks as done.
+ ensureChange(c, r, sb, chg)
+ r.Stop()
+
+ st.Lock()
+ defer st.Unlock()
+
+ c.Check(t1.Status(), Equals, state.DoneStatus)
+ c.Check(called, Equals, false)
+}
diff --git a/internals/overlord/state/warning.go b/internals/overlord/state/warning.go
index 69dac7c8..6211bf37 100644
--- a/internals/overlord/state/warning.go
+++ b/internals/overlord/state/warning.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2018 Canonical Ltd
+ * Copyright (C) 2018 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
diff --git a/internals/overlord/state/warning_test.go b/internals/overlord/state/warning_test.go
index bf30f69e..26479694 100644
--- a/internals/overlord/state/warning_test.go
+++ b/internals/overlord/state/warning_test.go
@@ -1,7 +1,7 @@
// -*- Mode: Go; indent-tabs-mode: t -*-
/*
- * Copyright (c) 2018 Canonical Ltd
+ * Copyright (C) 2018 Canonical Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 3 as
@@ -99,7 +99,7 @@ func (stateSuite) TestUnmarshalErrors(c *check.C) {
}
for _, t := range []T1{
- // sanity check
+ // validity check
{`{"message": "x", "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, nil},
// remove one field at a time:
{`{ "first-added": "2006-01-02T15:04:05Z", "expire-after": "1h", "repeat-after": "1h"}`, state.ErrNoWarningMessage},
diff --git a/internals/overlord/stateengine.go b/internals/overlord/stateengine.go
index 25710333..b468aea5 100644
--- a/internals/overlord/stateengine.go
+++ b/internals/overlord/stateengine.go
@@ -1,16 +1,21 @@
-// Copyright (c) 2014-2020 Canonical Ltd
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License version 3 as
-// published by the Free Software Foundation.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program. If not, see .
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2016 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
package overlord
@@ -30,10 +35,18 @@ type StateManager interface {
Ensure() error
}
+// StateStarterUp is optionally implemented by StateManager that have expensive
+// initialization to perform before the main Overlord loop.
+type StateStarterUp interface {
+ // StartUp asks manager to perform any expensive initialization.
+ StartUp() error
+}
+
// StateWaiter is optionally implemented by StateManagers that have running
// activities that can be waited.
type StateWaiter interface {
- // Wait asks manager to wait for all running activities to finish.
+ // Wait asks manager to wait for all running activities to
+ // finish.
Wait()
}
@@ -53,8 +66,9 @@ type StateStopper interface {
// cope with Ensure calls in any order, coordinating among themselves
// solely via the state.
type StateEngine struct {
- state *state.State
- stopped bool
+ state *state.State
+ stopped bool
+ startedUp bool
// managers in use
mgrLock sync.Mutex
managers []StateManager
@@ -72,6 +86,37 @@ func (se *StateEngine) State() *state.State {
return se.state
}
+type startupError struct {
+ errs []error
+}
+
+func (e *startupError) Error() string {
+ return fmt.Sprintf("state startup errors: %v", e.errs)
+}
+
+// StartUp asks all managers to perform any expensive initialization. It is a noop after the first invocation.
+func (se *StateEngine) StartUp() error {
+ se.mgrLock.Lock()
+ defer se.mgrLock.Unlock()
+ if se.startedUp {
+ return nil
+ }
+ se.startedUp = true
+ var errs []error
+ for _, m := range se.managers {
+ if starterUp, ok := m.(StateStarterUp); ok {
+ err := starterUp.StartUp()
+ if err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) != 0 {
+ return &startupError{errs}
+ }
+ return nil
+}
+
type ensureError struct {
errs []error
}
@@ -91,6 +136,9 @@ func (e *ensureError) Error() string {
func (se *StateEngine) Ensure() error {
se.mgrLock.Lock()
defer se.mgrLock.Unlock()
+ if !se.startedUp {
+ return fmt.Errorf("state engine skipped startup")
+ }
if se.stopped {
return fmt.Errorf("state engine already stopped")
}
diff --git a/internals/overlord/stateengine_test.go b/internals/overlord/stateengine_test.go
index f12ed98b..a427e661 100644
--- a/internals/overlord/stateengine_test.go
+++ b/internals/overlord/stateengine_test.go
@@ -1,16 +1,21 @@
-// Copyright (c) 2014-2020 Canonical Ltd
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License version 3 as
-// published by the Free Software Foundation.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program. If not, see .
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2016 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
package overlord_test
@@ -35,9 +40,14 @@ func (ses *stateEngineSuite) TestNewAndState(c *C) {
}
type fakeManager struct {
- name string
- calls *[]string
- ensureError, stopError error
+ name string
+ calls *[]string
+ ensureError, startupError error
+}
+
+func (fm *fakeManager) StartUp() error {
+ *fm.calls = append(*fm.calls, "startup:"+fm.name)
+ return fm.startupError
}
func (fm *fakeManager) Ensure() error {
@@ -55,6 +65,48 @@ func (fm *fakeManager) Wait() {
var _ overlord.StateManager = (*fakeManager)(nil)
+func (ses *stateEngineSuite) TestStartUp(c *C) {
+ s := state.New(nil)
+ se := overlord.NewStateEngine(s)
+
+ calls := []string{}
+
+ mgr1 := &fakeManager{name: "mgr1", calls: &calls}
+ mgr2 := &fakeManager{name: "mgr2", calls: &calls}
+
+ se.AddManager(mgr1)
+ se.AddManager(mgr2)
+
+ err := se.StartUp()
+ c.Assert(err, IsNil)
+ c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"})
+
+ // noop
+ err = se.StartUp()
+ c.Assert(err, IsNil)
+ c.Check(calls, HasLen, 2)
+}
+
+func (ses *stateEngineSuite) TestStartUpError(c *C) {
+ s := state.New(nil)
+ se := overlord.NewStateEngine(s)
+
+ calls := []string{}
+
+ err1 := errors.New("boom1")
+ err2 := errors.New("boom2")
+
+ mgr1 := &fakeManager{name: "mgr1", calls: &calls, startupError: err1}
+ mgr2 := &fakeManager{name: "mgr2", calls: &calls, startupError: err2}
+
+ se.AddManager(mgr1)
+ se.AddManager(mgr2)
+
+ err := se.StartUp()
+ c.Check(err.Error(), DeepEquals, "state startup errors: [boom1 boom2]")
+ c.Check(calls, DeepEquals, []string{"startup:mgr1", "startup:mgr2"})
+}
+
func (ses *stateEngineSuite) TestEnsure(c *C) {
s := state.New(nil)
se := overlord.NewStateEngine(s)
@@ -68,6 +120,11 @@ func (ses *stateEngineSuite) TestEnsure(c *C) {
se.AddManager(mgr2)
err := se.Ensure()
+ c.Check(err, ErrorMatches, "state engine skipped startup")
+ c.Assert(se.StartUp(), IsNil)
+ calls = []string{}
+
+ err = se.Ensure()
c.Assert(err, IsNil)
c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"})
@@ -91,6 +148,9 @@ func (ses *stateEngineSuite) TestEnsureError(c *C) {
se.AddManager(mgr1)
se.AddManager(mgr2)
+ c.Assert(se.StartUp(), IsNil)
+ calls = []string{}
+
err := se.Ensure()
c.Check(err.Error(), DeepEquals, "state ensure errors: [boom1 boom2]")
c.Check(calls, DeepEquals, []string{"ensure:mgr1", "ensure:mgr2"})
@@ -108,6 +168,9 @@ func (ses *stateEngineSuite) TestStop(c *C) {
se.AddManager(mgr1)
se.AddManager(mgr2)
+ c.Assert(se.StartUp(), IsNil)
+ calls = []string{}
+
se.Stop()
c.Check(calls, DeepEquals, []string{"stop:mgr1", "stop:mgr2"})
se.Stop()
diff --git a/internals/systemd/sdnotify.go b/internals/systemd/sdnotify.go
index 5f1279cf..ed16aceb 100644
--- a/internals/systemd/sdnotify.go
+++ b/internals/systemd/sdnotify.go
@@ -27,7 +27,7 @@ var osGetenv = os.Getenv
func SocketAvailable() bool {
notifySocket := osGetenv("NOTIFY_SOCKET")
- return notifySocket != "" && osutil.CanStat(notifySocket)
+ return notifySocket != "" && osutil.FileExists(notifySocket)
}
// SdNotify sends the given state string notification to systemd.
diff --git a/internals/systemd/systemd.go b/internals/systemd/systemd.go
index c5e11114..0811fbee 100644
--- a/internals/systemd/systemd.go
+++ b/internals/systemd/systemd.go
@@ -713,7 +713,7 @@ func (s *systemd) RemoveMountUnitFile(mountedDir string) error {
}
unit := MountUnitPath(unitNamePath)
- if !osutil.CanStat(unit) {
+ if !osutil.FileExists(unit) {
return nil
}
diff --git a/internals/systemd/systemd_test.go b/internals/systemd/systemd_test.go
index f32c363d..20967271 100644
--- a/internals/systemd/systemd_test.go
+++ b/internals/systemd/systemd_test.go
@@ -567,7 +567,7 @@ WantedBy=multi-user.target
}
func (s *SystemdTestSuite) TestFuseInContainer(c *C) {
- if !osutil.CanStat("/dev/fuse") {
+ if !osutil.FileExists("/dev/fuse") {
c.Skip("No /dev/fuse on the system")
}
@@ -713,7 +713,7 @@ func (s *SystemdTestSuite) TestRemoveMountUnit(c *C) {
c.Assert(err, IsNil)
// the file is gone
- c.Check(osutil.CanStat(mountUnit), Equals, false)
+ c.Check(osutil.FileExists(mountUnit), Equals, false)
// and the unit is disabled and the daemon reloaded
c.Check(s.argses, DeepEquals, [][]string{
{"--root", s.rootDir, "disable", "snap-foo-42.mount"},
diff --git a/internals/testutil/containschecker.go b/internals/testutil/containschecker.go
index 82ff2ee5..644ef158 100644
--- a/internals/testutil/containschecker.go
+++ b/internals/testutil/containschecker.go
@@ -1,16 +1,21 @@
-// Copyright (c) 2014-2020 Canonical Ltd
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License version 3 as
-// published by the Free Software Foundation.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program. If not, see .
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2015-2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
package testutil
@@ -81,7 +86,7 @@ func (c *containsChecker) Check(params []interface{}, names []string) (result bo
var container interface{} = params[0]
var elem interface{} = params[1]
if commonEquals(container, elem, &result, &error) {
- return
+ return result, error
}
// Do the actual test using ==
switch containerV := reflect.ValueOf(container); containerV.Kind() {
@@ -111,7 +116,7 @@ type deepContainsChecker struct {
}
// DeepContains is a Checker that looks for a elem in a container using
-// DeepEqual. The elem can be any object. The container can be an array, slice
+// DeepEqual. The elem can be any object. The container can be an array, slice
// or string.
var DeepContains check.Checker = &deepContainsChecker{
&check.CheckerInfo{Name: "DeepContains", Params: []string{"container", "elem"}},
@@ -121,7 +126,7 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul
var container interface{} = params[0]
var elem interface{} = params[1]
if commonEquals(container, elem, &result, &error) {
- return
+ return result, error
}
// Do the actual test using reflect.DeepEqual
switch containerV := reflect.ValueOf(container); containerV.Kind() {
@@ -145,3 +150,126 @@ func (c *deepContainsChecker) Check(params []interface{}, names []string) (resul
return false, fmt.Sprintf("%T is not a supported container", container)
}
}
+
+type deepUnsortedMatchChecker struct {
+ *check.CheckerInfo
+}
+
+// DeepUnsortedMatches checks if two containers contain the same elements in
+// the same number (but possibly different order) using DeepEqual. The container
+// can be an array, a slice or a map.
+var DeepUnsortedMatches check.Checker = &deepUnsortedMatchChecker{
+ &check.CheckerInfo{Name: "DeepUnsortedMatches", Params: []string{"container1", "container2"}},
+}
+
+func (c *deepUnsortedMatchChecker) Check(params []interface{}, _ []string) (bool, string) {
+ container1 := reflect.ValueOf(params[0])
+ container2 := reflect.ValueOf(params[1])
+
+ // if both args are nil, return true
+ if container1.Kind() == reflect.Invalid && container2.Kind() == reflect.Invalid {
+ return true, ""
+ }
+
+ if container1.Kind() == reflect.Invalid || container2.Kind() == reflect.Invalid {
+ return false, "only one container was nil"
+ }
+
+ if container1.Kind() != container2.Kind() {
+ return false, fmt.Sprintf("containers are of different types: %s != %s", container1.Kind(), container2.Kind())
+ }
+
+ switch container1.Kind() {
+ case reflect.Array, reflect.Slice:
+ return deepSequenceMatch(container1, container2)
+ case reflect.Map:
+ return deepMapMatch(container1, container2)
+ default:
+ return false, fmt.Sprintf("'%s' is not a supported type: must be slice, array or map", container1.Kind().String())
+ }
+}
+
+func deepMapMatch(container1, container2 reflect.Value) (bool, string) {
+ if valid, output := validateContainerTypesAndLengths(container1, container2); !valid {
+ return false, output
+ }
+
+ switch container1.Type().Elem().Kind() {
+ case reflect.Slice, reflect.Array, reflect.Map:
+ // only run the unsorted match if the map values are containers
+ default:
+ if !reflect.DeepEqual(container1.Interface(), container2.Interface()) {
+ return false, "maps don't match"
+ }
+ return true, ""
+ }
+
+ for _, key := range container1.MapKeys() {
+ el1 := container1.MapIndex(key)
+ el2 := container2.MapIndex(key)
+
+ absent := el2 == reflect.Value{}
+ if absent {
+ return false, fmt.Sprintf("key %q from one map is absent from the other map", key)
+ }
+
+ var ok bool
+ var msg string
+ switch el1.Kind() {
+ case reflect.Array, reflect.Slice:
+ ok, msg = deepSequenceMatch(el1, el2)
+ case reflect.Map:
+ ok, msg = deepMapMatch(el1, el2)
+ }
+
+ if !ok {
+ return false, msg
+ }
+ }
+
+ return true, ""
+}
+
+func deepSequenceMatch(container1, container2 reflect.Value) (bool, string) {
+ if valid, output := validateContainerTypesAndLengths(container1, container2); !valid {
+ return false, output
+ }
+
+ matched := make([]bool, container1.Len())
+out:
+ for i := 0; i < container1.Len(); i++ {
+ el1 := container1.Index(i).Interface()
+
+ for e := 0; e < container2.Len(); e++ {
+ el2 := container2.Index(e).Interface()
+
+ if !matched[e] && reflect.DeepEqual(el1, el2) {
+ // mark already matched elements, so that duplicate elements in
+ // one container can't be matched to the same element in the other.
+ matched[e] = true
+ continue out
+ }
+ }
+
+ return false, fmt.Sprintf("element [%d]=%s was unmatched in the second container", i, el1)
+ }
+
+ return true, ""
+}
+
+func validateContainerTypesAndLengths(container1, container2 reflect.Value) (bool, string) {
+ if container1.Len() != container2.Len() {
+ return false, fmt.Sprintf("containers have different lengths: %d != %d", container1.Len(), container2.Len())
+ } else if container1.Type().Elem() != container2.Type().Elem() {
+ return false, fmt.Sprintf("containers have different element types: %s != %s", container1.Type().Elem(), container2.Type().Elem())
+ }
+
+ if container1.Kind() == reflect.Map && container2.Kind() == reflect.Map {
+ keyType1, keyType2 := container1.Type().Key(), container2.Type().Key()
+ if keyType1 != keyType2 {
+ return false, fmt.Sprintf("maps have different key types: %s != %s", keyType1, keyType2)
+ }
+ }
+
+ return true, ""
+}
diff --git a/internals/testutil/containschecker_test.go b/internals/testutil/containschecker_test.go
index 86c5c0cc..8a94c897 100644
--- a/internals/testutil/containschecker_test.go
+++ b/internals/testutil/containschecker_test.go
@@ -1,16 +1,21 @@
-// Copyright (c) 2014-2020 Canonical Ltd
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU General Public License version 3 as
-// published by the Free Software Foundation.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU General Public License for more details.
-//
-// You should have received a copy of the GNU General Public License
-// along with this program. If not, see .
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2015-2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
package testutil_test
@@ -216,3 +221,152 @@ func (*containsCheckerSuite) TestDeepContainsUncomparableType(c *check.C) {
testCheck(c, DeepContains, true, "", containerSlice, elem)
testCheck(c, DeepContains, true, "", containerMap, elem)
}
+
+type example struct {
+ a string
+ b map[string]int
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceSuccess(c *check.C) {
+ slice1 := []example{
+ {a: "one", b: map[string]int{"a": 1}},
+ {a: "two", b: map[string]int{"b": 2}},
+ }
+ slice2 := []example{
+ {a: "two", b: map[string]int{"b": 2}},
+ {a: "one", b: map[string]int{"a": 1}},
+ }
+
+ c.Check(slice1, DeepUnsortedMatches, slice2)
+ c.Check(slice2, DeepUnsortedMatches, slice1)
+ c.Check([]string{"a", "a"}, DeepUnsortedMatches, []string{"a", "a"})
+ c.Check([]string{"a", "b", "a"}, DeepUnsortedMatches, []string{"b", "a", "a"})
+ slice := [1]int{1}
+ c.Check(slice, DeepUnsortedMatches, slice)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesSliceFailure(c *check.C) {
+ slice1 := []string{"a", "a", "b"}
+ slice2 := []string{"b", "a", "c"}
+
+ testCheck(c, DeepUnsortedMatches, false, "element [1]=a was unmatched in the second container", slice1, slice2)
+ testCheck(c, DeepUnsortedMatches, false, "element [2]=c was unmatched in the second container", slice2, slice1)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapSuccess(c *check.C) {
+ map1 := map[string]example{
+ "a": {a: "a", b: map[string]int{"a": 1, "b": 2}},
+ "c": {a: "c", b: map[string]int{"c": 3, "d": 4}},
+ }
+ map2 := map[string]example{
+ "c": {a: "c", b: map[string]int{"c": 3, "d": 4}},
+ "a": {a: "a", b: map[string]int{"a": 1, "b": 2}},
+ }
+
+ c.Check(map1, DeepUnsortedMatches, map2)
+ c.Check(map2, DeepUnsortedMatches, map1)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapStructFail(c *check.C) {
+ map1 := map[string]example{
+ "a": {a: "a", b: map[string]int{"a": 2, "b": 1}},
+ }
+ map2 := map[string]example{
+ "a": {a: "a", b: map[string]int{"a": 1, "b": 2}},
+ }
+
+ testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedKeyFailure(c *check.C) {
+ map1 := map[string]int{"a": 1, "c": 2}
+ map2 := map[string]int{"a": 1, "b": 2}
+
+ testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2)
+ testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapUnmatchedValueFailure(c *check.C) {
+ map1 := map[string]int{"a": 1, "b": 2}
+ map2 := map[string]int{"a": 1, "b": 3}
+
+ testCheck(c, DeepUnsortedMatches, false, "maps don't match", map1, map2)
+ testCheck(c, DeepUnsortedMatches, false, "maps don't match", map2, map1)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentTypeFailure(c *check.C) {
+ testCheck(c, DeepUnsortedMatches, false, "containers are of different types: slice != array", []int{}, [1]int{})
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentElementType(c *check.C) {
+ testCheck(c, DeepUnsortedMatches, false, "containers have different element types: int != string", []int{1}, []string{"a"})
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesDifferentLengthFailure(c *check.C) {
+ testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 1 != 2", []int{1}, []int{1, 1})
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesNilArgFailure(c *check.C) {
+ testCheck(c, DeepUnsortedMatches, false, "only one container was nil", nil, []int{1})
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesBothNilArgSuccess(c *check.C) {
+ c.Check(nil, DeepUnsortedMatches, nil)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesNonContainerValues(c *check.C) {
+ testCheck(c, DeepUnsortedMatches, false, "'string' is not a supported type: must be slice, array or map", "a", "a")
+ testCheck(c, DeepUnsortedMatches, false, "'int' is not a supported type: must be slice, array or map", 1, 2)
+ testCheck(c, DeepUnsortedMatches, false, "'bool' is not a supported type: must be slice, array or map", true, false)
+ testCheck(c, DeepUnsortedMatches, false, "'ptr' is not a supported type: must be slice, array or map", &[]string{"a", "b"}, &[]string{"a", "b"})
+ testCheck(c, DeepUnsortedMatches, false, "'func' is not a supported type: must be slice, array or map", func() {}, func() {})
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsOfSlices(c *check.C) {
+ map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}}
+ map2 := map[string][]string{"a": {"bar", "foo"}, "b": {"bar", "foo"}}
+
+ c.Check(map1, DeepUnsortedMatches, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentKeyTypes(c *check.C) {
+ map1 := map[string][]string{"a": {"foo", "bar"}}
+ map2 := map[int][]string{1: {"bar", "foo"}}
+
+ testCheck(c, DeepUnsortedMatches, false, "maps have different key types: string != int", map1, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentValueTypes(c *check.C) {
+ map1 := map[string][]string{"a": {"foo", "bar"}}
+ map2 := map[string][2]string{"a": {"foo", "bar"}}
+
+ testCheck(c, DeepUnsortedMatches, false, "containers have different element types: []string != [2]string", map1, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsDifferentLengths(c *check.C) {
+ map1 := map[string][]string{"a": {"foo", "bar"}, "b": {"foo", "bar"}}
+ map2 := map[string][]string{"a": {"bar", "foo"}}
+
+ testCheck(c, DeepUnsortedMatches, false, "containers have different lengths: 2 != 1", map1, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesMapsMissingKey(c *check.C) {
+ map1 := map[string][]string{"a": {"foo", "bar"}}
+ map2 := map[string][]string{"b": {"bar", "foo"}}
+
+ testCheck(c, DeepUnsortedMatches, false, "key \"a\" from one map is absent from the other map", map1, map2)
+}
+
+func (*containsCheckerSuite) TestDeepUnsortedMatchesNestedMaps(c *check.C) {
+ map1 := map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}}
+ map2 := map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}}
+ c.Check(map1, DeepUnsortedMatches, map2)
+
+ map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}}
+ map2 = map[string]map[string][]string{"a": {"c": []string{"bar", "foo"}}}
+ testCheck(c, DeepUnsortedMatches, false, "key \"b\" from one map is absent from the other map", map1, map2)
+
+ map1 = map[string]map[string][]string{"a": {"b": []string{"foo", "bar"}}, "c": {"b": []string{"foo"}}}
+ map2 = map[string]map[string][]string{"a": {"b": []string{"bar", "foo"}}, "c": {"b": []string{"bar"}}}
+ testCheck(c, DeepUnsortedMatches, false, "element [0]=foo was unmatched in the second container", map1, map2)
+}
diff --git a/internals/testutil/errorischecker.go b/internals/testutil/errorischecker.go
new file mode 100644
index 00000000..e22030c6
--- /dev/null
+++ b/internals/testutil/errorischecker.go
@@ -0,0 +1,53 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2022 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package testutil
+
+import (
+ "errors"
+
+ "gopkg.in/check.v1"
+)
+
+// ErrorIs calls errors.Is with the provided arguments.
+var ErrorIs = &errorIsChecker{
+ &check.CheckerInfo{Name: "ErrorIs", Params: []string{"error", "target"}},
+}
+
+type errorIsChecker struct {
+ *check.CheckerInfo
+}
+
+func (*errorIsChecker) Check(params []interface{}, names []string) (result bool, errMsg string) {
+ if params[0] == nil {
+ return params[1] == nil, ""
+ }
+
+ err, ok := params[0].(error)
+ if !ok {
+ return false, "first argument must be an error"
+ }
+
+ target, ok := params[1].(error)
+ if !ok {
+ return false, "second argument must be an error"
+ }
+
+ return errors.Is(err, target), ""
+}
diff --git a/internals/testutil/errorischecker_test.go b/internals/testutil/errorischecker_test.go
new file mode 100644
index 00000000..5f2fc28c
--- /dev/null
+++ b/internals/testutil/errorischecker_test.go
@@ -0,0 +1,72 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2022 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package testutil_test
+
+import (
+ "errors"
+
+ . "github.com/canonical/pebble/internals/testutil"
+
+ . "gopkg.in/check.v1"
+)
+
+type errorIsCheckerSuite struct{}
+
+var _ = Suite(&errorIsCheckerSuite{})
+
+type baseError struct{}
+
+func (baseError) Error() string { return "" }
+
+func (baseError) Is(err error) bool {
+ _, ok := err.(baseError)
+ return ok
+}
+
+type wrapperError struct {
+ err error
+}
+
+func (*wrapperError) Error() string { return "" }
+
+func (e *wrapperError) Unwrap() error { return e.err }
+
+func (*errorIsCheckerSuite) TestErrorIsCheckSucceeds(c *C) {
+ testInfo(c, ErrorIs, "ErrorIs", []string{"error", "target"})
+
+ c.Assert(baseError{}, ErrorIs, baseError{})
+ err := &wrapperError{err: baseError{}}
+ c.Assert(err, ErrorIs, baseError{})
+}
+
+func (*errorIsCheckerSuite) TestErrorIsCheckFails(c *C) {
+ c.Assert(nil, Not(ErrorIs), baseError{})
+ c.Assert(errors.New(""), Not(ErrorIs), baseError{})
+}
+
+func (*errorIsCheckerSuite) TestErrorIsWithInvalidArguments(c *C) {
+ res, errMsg := ErrorIs.Check([]interface{}{"", errors.New("")}, []string{"error", "target"})
+ c.Assert(res, Equals, false)
+ c.Assert(errMsg, Equals, "first argument must be an error")
+
+ res, errMsg = ErrorIs.Check([]interface{}{errors.New(""), ""}, []string{"error", "target"})
+ c.Assert(res, Equals, false)
+ c.Assert(errMsg, Equals, "second argument must be an error")
+}