diff --git a/dirs/dirs.go b/dirs/dirs.go
index 4e309e46a3e..bd748c2cfdc 100644
--- a/dirs/dirs.go
+++ b/dirs/dirs.go
@@ -103,6 +103,7 @@ var (
SystemFontconfigCacheDir string
FreezerCgroupDir string
+ SnapshotsDir string
)
const (
@@ -259,4 +260,5 @@ func SetRootDir(rootdir string) {
SystemFontconfigCacheDir = filepath.Join(rootdir, "/var/cache/fontconfig")
FreezerCgroupDir = filepath.Join(rootdir, "/sys/fs/cgroup/freezer/")
+ SnapshotsDir = filepath.Join(rootdir, snappyDir, "snapshots")
}
diff --git a/overlord/snapshotstate/backend/backend.go b/overlord/snapshotstate/backend/backend.go
new file mode 100644
index 00000000000..7b2a9baa7af
--- /dev/null
+++ b/overlord/snapshotstate/backend/backend.go
@@ -0,0 +1,266 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+import (
+ "archive/zip"
+ "crypto"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sort"
+ "time"
+
+ "golang.org/x/net/context"
+
+ "github.com/snapcore/snapd/client"
+ "github.com/snapcore/snapd/dirs"
+ "github.com/snapcore/snapd/logger"
+ "github.com/snapcore/snapd/osutil"
+ "github.com/snapcore/snapd/snap"
+ "github.com/snapcore/snapd/strutil"
+)
+
+const (
+ archiveName = "archive.tgz"
+ metadataName = "meta.json"
+ metaHashName = "meta.sha3_384"
+
+ userArchivePrefix = "user/"
+ userArchiveSuffix = ".tgz"
+)
+
+var (
+ // Stop is used to ask Iter to stop iteration, without it being an error.
+ Stop = errors.New("stop iteration")
+
+ osOpen = os.Open
+ dirNames = (*os.File).Readdirnames
+ backendOpen = Open
+)
+
+// Iter loops over all snapshots in the snapshots directory, applying the given
+// function to each. The snapshot will be closed after the function returns. If
+// the function returns an error, iteration is stopped (and if the error isn't
+// Stop, it's returned as the error of the iterator).
+func Iter(ctx context.Context, f func(*Reader) error) error {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+
+ dir, err := osOpen(dirs.SnapshotsDir)
+ if err != nil {
+ if osutil.IsDirNotExist(err) {
+ // no dir -> no snapshots
+ return nil
+ }
+ return fmt.Errorf("cannot open snapshots directory: %v", err)
+ }
+ defer dir.Close()
+
+ var names []string
+ var readErr error
+ for readErr == nil && err == nil {
+ names, readErr = dirNames(dir, 100)
+ // note os.Readdirnames can return a non-empty names and a non-nil err
+ for _, name := range names {
+ if err = ctx.Err(); err != nil {
+ break
+ }
+
+ filename := filepath.Join(dirs.SnapshotsDir, name)
+ reader, openError := backendOpen(filename)
+ // reader can be non-nil even when openError is not nil (in
+ // which case reader.Broken will have a reason). f can
+ // check and either ignore or return an error when
+ // finding a broken snapshot.
+ if reader != nil {
+ err = f(reader)
+ } else {
+ // TODO: use warnings instead
+ logger.Noticef("Cannot open snapshot %q: %v.", name, openError)
+ }
+ if openError == nil {
+ // if openError was nil the snapshot was opened and needs closing
+ if closeError := reader.Close(); err == nil {
+ err = closeError
+ }
+ }
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if readErr != nil && readErr != io.EOF {
+ return readErr
+ }
+
+ if err == Stop {
+ err = nil
+ }
+
+ return err
+}
+
+// List valid snapshots sets.
+func List(ctx context.Context, setID uint64, snapNames []string) ([]client.SnapshotSet, error) {
+ setshots := map[uint64][]*client.Snapshot{}
+ err := Iter(ctx, func(reader *Reader) error {
+ if setID == 0 || reader.SetID == setID {
+ if len(snapNames) == 0 || strutil.ListContains(snapNames, reader.Snap) {
+ setshots[reader.SetID] = append(setshots[reader.SetID], &reader.Snapshot)
+ }
+ }
+ return nil
+ })
+
+ sets := make([]client.SnapshotSet, 0, len(setshots))
+ for id, shots := range setshots {
+ sort.Sort(bySnap(shots))
+ sets = append(sets, client.SnapshotSet{ID: id, Snapshots: shots})
+ }
+
+ sort.Sort(byID(sets))
+
+ return sets, err
+}
+
+// Filename of the given client.Snapshot in this backend.
+func Filename(snapshot *client.Snapshot) string {
+ // this _needs_ the snap name and version to be valid
+ return filepath.Join(dirs.SnapshotsDir, fmt.Sprintf("%d_%s_%s_%s.zip", snapshot.SetID, snapshot.Snap, snapshot.Version, snapshot.Revision))
+}
+
+// Save a snapshot
+func Save(ctx context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string) (*client.Snapshot, error) {
+ if err := os.MkdirAll(dirs.SnapshotsDir, 0700); err != nil {
+ return nil, err
+ }
+
+ snapshot := &client.Snapshot{
+ SetID: id,
+ Snap: si.Name(),
+ Revision: si.Revision,
+ Version: si.Version,
+ Time: time.Now(),
+ SHA3_384: make(map[string]string),
+ Size: 0,
+ Conf: cfg,
+ }
+
+ aw, err := osutil.NewAtomicFile(Filename(snapshot), 0600, 0, osutil.NoChown, osutil.NoChown)
+ if err != nil {
+ return nil, err
+ }
+ // if things worked, we'll commit (and Cancel becomes a NOP)
+ defer aw.Cancel()
+
+ w := zip.NewWriter(aw)
+ defer w.Close() // note this does not close the file descriptor (that's done by hand on the atomic writer, above)
+ if err := addDirToZip(ctx, snapshot, w, "root", archiveName, si.DataDir()); err != nil {
+ return nil, err
+ }
+
+ users, err := usersForUsernames(usernames)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, usr := range users {
+ if err := addDirToZip(ctx, snapshot, w, usr.Username, userArchiveName(usr), si.UserDataDir(usr.HomeDir)); err != nil {
+ return nil, err
+ }
+ }
+
+ metaWriter, err := w.Create(metadataName)
+ if err != nil {
+ return nil, err
+ }
+
+ hasher := crypto.SHA3_384.New()
+ enc := json.NewEncoder(io.MultiWriter(metaWriter, hasher))
+ if err := enc.Encode(snapshot); err != nil {
+ return nil, err
+ }
+
+ hashWriter, err := w.Create(metaHashName)
+ if err != nil {
+ return nil, err
+ }
+ fmt.Fprintf(hashWriter, "%x\n", hasher.Sum(nil))
+ if err := w.Close(); err != nil {
+ return nil, err
+ }
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ if err := aw.Commit(); err != nil {
+ return nil, err
+ }
+
+ return snapshot, nil
+}
+
+func addDirToZip(ctx context.Context, snapshot *client.Snapshot, w *zip.Writer, username string, entry, dir string) error {
+ hasher := crypto.SHA3_384.New()
+ if exists, isDir, err := osutil.DirExists(dir); !exists || !isDir || err != nil {
+ if exists && !isDir {
+ logger.Noticef("Not saving %q in snapshot #%d of %q as it is not a directory.", dir, snapshot.SetID, snapshot.Snap)
+ }
+ return err
+ }
+ parent, dir := filepath.Split(dir)
+
+ archiveWriter, err := w.CreateHeader(&zip.FileHeader{Name: entry})
+ if err != nil {
+ return err
+ }
+
+ var sz sizer
+
+ cmd := maybeRunuserCommand(username,
+ "tar",
+ "--create",
+ "--sparse", "--gzip",
+ "--directory", parent, dir, "common")
+ cmd.Env = []string{"GZIP=-9 -n"}
+ cmd.Stdout = io.MultiWriter(archiveWriter, hasher, &sz)
+ matchCounter := &strutil.MatchCounter{Regexp: regexp.MustCompile(".*"), N: 1}
+ cmd.Stderr = matchCounter
+ if err := osutil.RunWithContext(ctx, cmd); err != nil {
+ matches, count := matchCounter.Matches()
+ if count > 0 {
+ return fmt.Errorf("cannot create archive: %s (and %d more)", matches[0], count-1)
+ }
+ return fmt.Errorf("tar failed: %v", err)
+ }
+
+ snapshot.SHA3_384[entry] = fmt.Sprintf("%x", hasher.Sum(nil))
+ snapshot.Size += sz.size
+
+ return nil
+}
diff --git a/overlord/snapshotstate/backend/backend_test.go b/overlord/snapshotstate/backend/backend_test.go
new file mode 100644
index 00000000000..12678e26996
--- /dev/null
+++ b/overlord/snapshotstate/backend/backend_test.go
@@ -0,0 +1,498 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend_test
+
+import (
+ "archive/zip"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "os/user"
+ "path/filepath"
+ "sort"
+ "testing"
+
+ "golang.org/x/net/context"
+ "gopkg.in/check.v1"
+
+ "github.com/snapcore/snapd/client"
+ "github.com/snapcore/snapd/dirs"
+ "github.com/snapcore/snapd/logger"
+ "github.com/snapcore/snapd/overlord/snapshotstate/backend"
+ "github.com/snapcore/snapd/snap"
+)
+
+type snapshotSuite struct {
+ root string
+ restore []func()
+}
+
+var _ = check.Suite(&snapshotSuite{})
+
+// tie gocheck into testing
+func TestSnapshot(t *testing.T) { check.TestingT(t) }
+
+type tableT struct {
+ dir string
+ name string
+ content string
+}
+
+func table(si snap.PlaceInfo, homeDir string) []tableT {
+ return []tableT{
+ {
+ dir: si.DataDir(),
+ name: "foo",
+ content: "versioned system canary\n",
+ }, {
+ dir: si.CommonDataDir(),
+ name: "bar",
+ content: "common system canary\n",
+ }, {
+ dir: si.UserDataDir(homeDir),
+ name: "ufoo",
+ content: "versioned user canary\n",
+ }, {
+ dir: si.UserCommonDataDir(homeDir),
+ name: "ubar",
+ content: "common user canary\n",
+ },
+ }
+}
+
+func (s *snapshotSuite) SetUpTest(c *check.C) {
+ s.root = c.MkDir()
+
+ dirs.SetRootDir(s.root)
+
+ si := snap.MinimalPlaceInfo("hello-snap", snap.R(42))
+
+ for _, t := range table(si, filepath.Join(dirs.GlobalRootDir, "home/snapuser")) {
+ c.Check(os.MkdirAll(t.dir, 0755), check.IsNil)
+ c.Check(ioutil.WriteFile(filepath.Join(t.dir, t.name), []byte(t.content), 0644), check.IsNil)
+ }
+
+ cur, err := user.Current()
+ c.Assert(err, check.IsNil)
+
+ s.restore = append(s.restore, backend.MockUserLookup(func(username string) (*user.User, error) {
+ if username != "snapuser" {
+ return nil, user.UnknownUserError(username)
+ }
+ rv := *cur
+ rv.Username = username
+ rv.HomeDir = filepath.Join(dirs.GlobalRootDir, "home/snapuser")
+ return &rv, nil
+ }))
+}
+
+func (s *snapshotSuite) TearDownTest(c *check.C) {
+ dirs.SetRootDir("")
+ for _, restore := range s.restore {
+ restore()
+ }
+}
+
+func hashkeys(snapshot *client.Snapshot) (keys []string) {
+ for k := range snapshot.SHA3_384 {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+
+ return keys
+}
+
+func (s *snapshotSuite) TestIterBailsIfContextDone(c *check.C) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return nil, nil // deal with it
+ })()
+
+ err := backend.Iter(ctx, nil)
+ c.Check(err, check.Equals, context.Canceled)
+ c.Check(triedToOpenDir, check.Equals, false)
+}
+
+func (s *snapshotSuite) TestIterBailsIfContextDoneMidway(c *check.C) {
+ ctx, cancel := context.WithCancel(context.Background())
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return os.Open(os.DevNull)
+ })()
+ readNames := 0
+ defer backend.MockDirNames(func(*os.File, int) ([]string, error) {
+ readNames++
+ cancel()
+ return []string{"hello"}, nil
+ })()
+ triedToOpenSnapshot := false
+ defer backend.MockOpen(func(string) (*backend.Reader, error) {
+ triedToOpenSnapshot = true
+ return nil, nil
+ })()
+
+ err := backend.Iter(ctx, nil)
+ c.Check(err, check.Equals, context.Canceled)
+ c.Check(triedToOpenDir, check.Equals, true)
+ // bails as soon as
+ c.Check(readNames, check.Equals, 1)
+ c.Check(triedToOpenSnapshot, check.Equals, false)
+}
+
+func (s *snapshotSuite) TestIterReturnsOkIfSnapshotsDirNonexistent(c *check.C) {
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return nil, os.ErrNotExist
+ })()
+
+ err := backend.Iter(context.Background(), nil)
+ c.Check(err, check.IsNil)
+ c.Check(triedToOpenDir, check.Equals, true)
+}
+
+func (s *snapshotSuite) TestIterBailsIfSnapshotsDirFails(c *check.C) {
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return nil, os.ErrInvalid
+ })()
+
+ err := backend.Iter(context.Background(), nil)
+ c.Check(err, check.ErrorMatches, "cannot open snapshots directory: invalid argument")
+ c.Check(triedToOpenDir, check.Equals, true)
+}
+
+func (s *snapshotSuite) TestIterWarnsOnOpenErrorIfSnapshotNil(c *check.C) {
+ logbuf, restore := logger.MockLogger()
+ defer restore()
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return new(os.File), nil
+ })()
+ readNames := 0
+ defer backend.MockDirNames(func(*os.File, int) ([]string, error) {
+ readNames++
+ if readNames > 1 {
+ return nil, io.EOF
+ }
+ return []string{"hello"}, nil
+ })()
+ triedToOpenSnapshot := false
+ defer backend.MockOpen(func(string) (*backend.Reader, error) {
+ triedToOpenSnapshot = true
+ return nil, os.ErrInvalid
+ })()
+
+ calledF := false
+ f := func(snapshot *backend.Reader) error {
+ calledF = true
+ return nil
+ }
+
+ err := backend.Iter(context.Background(), f)
+ // snapshot open errors are not failures:
+ c.Check(err, check.IsNil)
+ c.Check(triedToOpenDir, check.Equals, true)
+ c.Check(readNames, check.Equals, 2)
+ c.Check(triedToOpenSnapshot, check.Equals, true)
+ c.Check(logbuf.String(), check.Matches, `(?m).* Cannot open snapshot "hello": invalid argument.`)
+ c.Check(calledF, check.Equals, false)
+}
+
+func (s *snapshotSuite) TestIterCallsFuncIfSnapshotNotNil(c *check.C) {
+ logbuf, restore := logger.MockLogger()
+ defer restore()
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return new(os.File), nil
+ })()
+ readNames := 0
+ defer backend.MockDirNames(func(*os.File, int) ([]string, error) {
+ readNames++
+ if readNames > 1 {
+ return nil, io.EOF
+ }
+ return []string{"hello"}, nil
+ })()
+ triedToOpenSnapshot := false
+ defer backend.MockOpen(func(string) (*backend.Reader, error) {
+ triedToOpenSnapshot = true
+ // NOTE non-nil reader, and error, returned
+ r := backend.Reader{}
+ r.Broken = "xyzzy"
+ return &r, os.ErrInvalid
+ })()
+
+ calledF := false
+ f := func(snapshot *backend.Reader) error {
+ c.Check(snapshot.Broken, check.Equals, "xyzzy")
+ calledF = true
+ return nil
+ }
+
+ err := backend.Iter(context.Background(), f)
+ // snapshot open errors are not failures:
+ c.Check(err, check.IsNil)
+ c.Check(triedToOpenDir, check.Equals, true)
+ c.Check(readNames, check.Equals, 2)
+ c.Check(triedToOpenSnapshot, check.Equals, true)
+ c.Check(logbuf.String(), check.Equals, "")
+ c.Check(calledF, check.Equals, true)
+}
+
+func (s *snapshotSuite) TestIterReportsCloseError(c *check.C) {
+ logbuf, restore := logger.MockLogger()
+ defer restore()
+ triedToOpenDir := false
+ defer backend.MockOsOpen(func(string) (*os.File, error) {
+ triedToOpenDir = true
+ return new(os.File), nil
+ })()
+ readNames := 0
+ defer backend.MockDirNames(func(*os.File, int) ([]string, error) {
+ readNames++
+ if readNames > 1 {
+ return nil, io.EOF
+ }
+ return []string{"hello"}, nil
+ })()
+ triedToOpenSnapshot := false
+ defer backend.MockOpen(func(string) (*backend.Reader, error) {
+ triedToOpenSnapshot = true
+ r := backend.Reader{}
+ r.SetID = 42
+ return &r, nil
+ })()
+
+ calledF := false
+ f := func(snapshot *backend.Reader) error {
+ c.Check(snapshot.SetID, check.Equals, uint64(42))
+ calledF = true
+ return nil
+ }
+
+ err := backend.Iter(context.Background(), f)
+ // snapshot close errors _are_ failures (because they're completely unexpected):
+ c.Check(err, check.Equals, os.ErrInvalid)
+ c.Check(triedToOpenDir, check.Equals, true)
+ c.Check(readNames, check.Equals, 1) // never gets to read another one
+ c.Check(triedToOpenSnapshot, check.Equals, true)
+ c.Check(logbuf.String(), check.Equals, "")
+ c.Check(calledF, check.Equals, true)
+}
+
+func (s *snapshotSuite) TestList(c *check.C) {
+ logbuf, restore := logger.MockLogger()
+ defer restore()
+ defer backend.MockOsOpen(func(string) (*os.File, error) { return new(os.File), nil })()
+ readNames := 0
+ defer backend.MockDirNames(func(*os.File, int) ([]string, error) {
+ readNames++
+ if readNames > 4 {
+ return nil, io.EOF
+ }
+ return []string{
+ fmt.Sprintf("%d_foo", readNames),
+ fmt.Sprintf("%d_bar", readNames),
+ fmt.Sprintf("%d_baz", readNames),
+ }, nil
+ })()
+ defer backend.MockOpen(func(fn string) (*backend.Reader, error) {
+ var id uint64
+ var snapname string
+ fn = filepath.Base(fn)
+ _, err := fmt.Sscanf(fn, "%d_%s", &id, &snapname)
+ c.Assert(err, check.IsNil, check.Commentf(fn))
+ f, err := os.Open(os.DevNull)
+ c.Assert(err, check.IsNil, check.Commentf(fn))
+ return &backend.Reader{
+ File: f,
+ Snapshot: client.Snapshot{
+ SetID: id,
+ Snap: snapname,
+ Version: "v1.0-" + snapname,
+ Revision: snap.R(int(id)),
+ },
+ }, nil
+ })()
+
+ type tableT struct {
+ setID uint64
+ snapnames []string
+ numSets int
+ numShots int
+ predicate func(*client.Snapshot) bool
+ }
+ table := []tableT{
+ {0, nil, 4, 12, nil},
+ {0, []string{"foo"}, 4, 4, func(snapshot *client.Snapshot) bool { return snapshot.Snap == "foo" }},
+ {1, nil, 1, 3, func(snapshot *client.Snapshot) bool { return snapshot.SetID == 1 }},
+ {2, []string{"bar"}, 1, 1, func(snapshot *client.Snapshot) bool { return snapshot.Snap == "bar" && snapshot.SetID == 2 }},
+ {0, []string{"foo", "bar"}, 4, 8, func(snapshot *client.Snapshot) bool { return snapshot.Snap == "foo" || snapshot.Snap == "bar" }},
+ }
+
+ for i, t := range table {
+ comm := check.Commentf("%d: %d/%v", i, t.setID, t.snapnames)
+ // reset
+ readNames = 0
+ logbuf.Reset()
+
+ sets, err := backend.List(context.Background(), t.setID, t.snapnames)
+ c.Check(err, check.IsNil, comm)
+ c.Check(readNames, check.Equals, 5, comm)
+ c.Check(logbuf.String(), check.Equals, "", comm)
+ c.Check(sets, check.HasLen, t.numSets, comm)
+ nShots := 0
+ fnTpl := filepath.Join(dirs.SnapshotsDir, "%d_%s_%s_%s.zip")
+ for j, ss := range sets {
+ for k, snapshot := range ss.Snapshots {
+ comm := check.Commentf("%d: %d/%v #%d/%d", i, t.setID, t.snapnames, j, k)
+ if t.predicate != nil {
+ c.Check(t.predicate(snapshot), check.Equals, true, comm)
+ }
+ nShots++
+ fn := fmt.Sprintf(fnTpl, snapshot.SetID, snapshot.Snap, snapshot.Version, snapshot.Revision)
+ c.Check(backend.Filename(snapshot), check.Equals, fn, comm)
+ }
+ }
+ c.Check(nShots, check.Equals, t.numShots)
+ }
+}
+
+func (s *snapshotSuite) TestAddDirToZipBails(c *check.C) {
+ snapshot := &client.Snapshot{SetID: 42, Snap: "a-snap"}
+ buf, restore := logger.MockLogger()
+ defer restore()
+ // note as the zip is nil this would panic if it didn't bail
+ c.Check(backend.AddDirToZip(nil, snapshot, nil, "", "an/entry", filepath.Join(s.root, "nonexistent")), check.IsNil)
+ // no log for the non-existent case
+ c.Check(buf.String(), check.Equals, "")
+ buf.Reset()
+ c.Check(backend.AddDirToZip(nil, snapshot, nil, "", "an/entry", "/etc/passwd"), check.IsNil)
+ c.Check(buf.String(), check.Matches, "(?m).* is not a directory.")
+}
+
+func (s *snapshotSuite) TestAddDirToZipTarFails(c *check.C) {
+ d := filepath.Join(s.root, "foo")
+ c.Assert(os.MkdirAll(filepath.Join(d, "bar"), 0755), check.IsNil)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ var buf bytes.Buffer
+ z := zip.NewWriter(&buf)
+ c.Assert(backend.AddDirToZip(ctx, nil, z, "", "an/entry", d), check.ErrorMatches, ".* context canceled")
+}
+
+func (s *snapshotSuite) TestAddDirToZip(c *check.C) {
+ d := filepath.Join(s.root, "foo")
+ c.Assert(os.MkdirAll(filepath.Join(d, "bar"), 0755), check.IsNil)
+ c.Assert(os.MkdirAll(filepath.Join(s.root, "common"), 0755), check.IsNil)
+ c.Assert(ioutil.WriteFile(filepath.Join(d, "bar", "baz"), []byte("hello\n"), 0644), check.IsNil)
+
+ var buf bytes.Buffer
+ z := zip.NewWriter(&buf)
+ snapshot := &client.Snapshot{
+ SHA3_384: map[string]string{},
+ }
+ c.Assert(backend.AddDirToZip(context.Background(), snapshot, z, "", "an/entry", d), check.IsNil)
+ z.Close() // write out the central directory
+
+ c.Check(snapshot.SHA3_384, check.HasLen, 1)
+ c.Check(snapshot.SHA3_384["an/entry"], check.HasLen, 96)
+ c.Check(snapshot.Size > 0, check.Equals, true) // actual size most likely system-dependent
+ br := bytes.NewReader(buf.Bytes())
+ r, err := zip.NewReader(br, int64(br.Len()))
+ c.Assert(err, check.IsNil)
+ c.Check(r.File, check.HasLen, 1)
+ c.Check(r.File[0].Name, check.Equals, "an/entry")
+}
+
+func (s *snapshotSuite) TestHappyRoundtrip(c *check.C) {
+ logger.SimpleSetup()
+
+ info := &snap.Info{SideInfo: snap.SideInfo{RealName: "hello-snap", Revision: snap.R(42)}, Version: "v1.33"}
+ cfg := map[string]interface{}{"some-setting": false}
+ shID := uint64(12)
+
+ shw, err := backend.Save(context.TODO(), shID, info, cfg, []string{"snapuser"})
+ c.Assert(err, check.IsNil)
+ c.Check(shw.SetID, check.Equals, shID)
+ c.Check(shw.Snap, check.Equals, info.Name())
+ c.Check(shw.Version, check.Equals, info.Version)
+ c.Check(shw.Revision, check.Equals, info.Revision)
+ c.Check(shw.Conf, check.DeepEquals, cfg)
+ c.Check(backend.Filename(shw), check.Equals, filepath.Join(dirs.SnapshotsDir, "12_hello-snap_v1.33_42.zip"))
+ c.Check(hashkeys(shw), check.DeepEquals, []string{"archive.tgz", "user/snapuser.tgz"})
+
+ shs, err := backend.List(context.TODO(), 0, nil)
+ c.Assert(err, check.IsNil)
+ c.Assert(shs, check.HasLen, 1)
+
+ shr, err := backend.Open(backend.Filename(shw))
+ c.Assert(err, check.IsNil)
+ defer shr.Close()
+
+ c.Check(shr.SetID, check.Equals, shID)
+ c.Check(shr.Snap, check.Equals, info.Name())
+ c.Check(shr.Version, check.Equals, info.Version)
+ c.Check(shr.Revision, check.Equals, info.Revision)
+ c.Check(shr.Conf, check.DeepEquals, cfg)
+ c.Check(shr.Name(), check.Equals, filepath.Join(dirs.SnapshotsDir, "12_hello-snap_v1.33_42.zip"))
+ c.Check(shr.SHA3_384, check.DeepEquals, shw.SHA3_384)
+
+ c.Check(shr.Check(context.TODO(), nil), check.IsNil)
+
+ newroot := c.MkDir()
+ c.Assert(os.MkdirAll(filepath.Join(newroot, "home/snapuser"), 0755), check.IsNil)
+ dirs.SetRootDir(newroot)
+
+ var diff = func() *exec.Cmd {
+ cmd := exec.Command("diff", "-urN", "-x*.zip", s.root, newroot)
+ // cmd.Stdout = os.Stdout
+ // cmd.Stderr = os.Stderr
+ return cmd
+ }
+
+ for i := 0; i < 3; i++ {
+ comm := check.Commentf("%d", i)
+ // sanity check
+ c.Check(diff().Run(), check.NotNil, comm)
+
+ // restore leaves things like they were (again and again)
+ rs, err := shr.Restore(context.TODO(), nil, logger.Debugf)
+ c.Assert(err, check.IsNil, comm)
+ rs.Cleanup()
+ c.Check(diff().Run(), check.IsNil, comm)
+
+ // dirty it -> no longer like it was
+ c.Check(ioutil.WriteFile(filepath.Join(info.DataDir(), "marker"), []byte("scribble\n"), 0644), check.IsNil, comm)
+ }
+}
diff --git a/overlord/snapshotstate/backend/export_test.go b/overlord/snapshotstate/backend/export_test.go
new file mode 100644
index 00000000000..9804922f700
--- /dev/null
+++ b/overlord/snapshotstate/backend/export_test.go
@@ -0,0 +1,69 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+import (
+ "os"
+ "os/user"
+)
+
+var (
+ AddDirToZip = addDirToZip
+)
+
+func MockUserLookup(newLookup func(string) (*user.User, error)) func() {
+ oldLookup := userLookup
+ userLookup = newLookup
+ return func() {
+ userLookup = oldLookup
+ }
+}
+
+func MockUserLookupId(newLookupId func(string) (*user.User, error)) func() {
+ oldLookupId := userLookupId
+ userLookupId = newLookupId
+ return func() {
+ userLookupId = oldLookupId
+ }
+}
+
+func MockOsOpen(newOsOpen func(string) (*os.File, error)) func() {
+ oldOsOpen := osOpen
+ osOpen = newOsOpen
+ return func() {
+ osOpen = oldOsOpen
+ }
+}
+
+func MockDirNames(newDirNames func(*os.File, int) ([]string, error)) func() {
+ oldDirNames := dirNames
+ dirNames = newDirNames
+ return func() {
+ dirNames = oldDirNames
+ }
+}
+
+func MockOpen(newOpen func(string) (*Reader, error)) func() {
+ oldOpen := backendOpen
+ backendOpen = newOpen
+ return func() {
+ backendOpen = oldOpen
+ }
+}
diff --git a/overlord/snapshotstate/backend/helpers.go b/overlord/snapshotstate/backend/helpers.go
new file mode 100644
index 00000000000..88e036d30ee
--- /dev/null
+++ b/overlord/snapshotstate/backend/helpers.go
@@ -0,0 +1,174 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+import (
+ "archive/zip"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "os/user"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "github.com/snapcore/snapd/client"
+ "github.com/snapcore/snapd/dirs"
+ "github.com/snapcore/snapd/osutil/sys"
+)
+
+func zipMember(f *os.File, member string) (r io.ReadCloser, sz int64, err error) {
+ // rewind the file
+ // (shouldn't be needed, but doesn't hurt too much)
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, -1, err
+ }
+
+ fi, err := f.Stat()
+ if err != nil {
+ return nil, -1, err
+ }
+
+ arch, err := zip.NewReader(f, fi.Size())
+ if err != nil {
+ return nil, -1, err
+ }
+
+ for _, fh := range arch.File {
+ if fh.Name == member {
+ r, err = fh.Open()
+ return r, int64(fh.UncompressedSize64), err
+ }
+ }
+
+ return nil, -1, fmt.Errorf("missing archive member %q", member)
+}
+
+func userArchiveName(usr *user.User) string {
+ return filepath.Join(userArchivePrefix, usr.Username+userArchiveSuffix)
+}
+
+func isUserArchive(entry string) bool {
+ return strings.HasPrefix(entry, userArchivePrefix) && strings.HasSuffix(entry, userArchiveSuffix)
+}
+
+func entryUsername(entry string) string {
+ // this _will_ panic if !isUserArchive(entry)
+ return entry[len(userArchivePrefix) : len(entry)-len(userArchiveSuffix)]
+}
+
+type bySnap []*client.Snapshot
+
+func (a bySnap) Len() int { return len(a) }
+func (a bySnap) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a bySnap) Less(i, j int) bool { return a[i].Snap < a[j].Snap }
+
+type byID []client.SnapshotSet
+
+func (a byID) Len() int { return len(a) }
+func (a byID) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a byID) Less(i, j int) bool { return a[i].ID < a[j].ID }
+
+var (
+ userLookup = user.Lookup
+ userLookupId = user.LookupId
+)
+
+func usersForUsernames(usernames []string) ([]*user.User, error) {
+ if len(usernames) == 0 {
+ return allUsers()
+ }
+ users := make([]*user.User, 0, len(usernames))
+ for _, username := range usernames {
+ usr, err := userLookup(username)
+ if err != nil {
+ if _, ok := err.(*user.UnknownUserError); !ok {
+ return nil, err
+ }
+ u, e := userLookupId(username)
+ if e != nil {
+ // return first error, as it's usually clearer
+ return nil, err
+ }
+ usr = u
+ }
+ users = append(users, usr)
+
+ }
+ return users, nil
+}
+
+func allUsers() ([]*user.User, error) {
+ ds, err := filepath.Glob(dirs.SnapDataHomeGlob)
+ if err != nil {
+ // can't happen?
+ return nil, err
+ }
+
+ users := make([]*user.User, 1, len(ds)+1)
+ root, err := user.LookupId("0")
+ if err != nil {
+ return nil, err
+ }
+ users[0] = root
+ seen := make(map[uint32]bool, len(ds)+1)
+ seen[0] = true
+ var st syscall.Stat_t
+ for _, d := range ds {
+ err := syscall.Stat(d, &st)
+ if err != nil {
+ continue
+ }
+ if seen[st.Uid] {
+ continue
+ }
+ seen[st.Uid] = true
+ usr, err := user.LookupId(strconv.FormatUint(uint64(st.Uid), 10))
+ if err != nil {
+ return nil, err
+ }
+ users = append(users, usr)
+ }
+
+ return users, nil
+}
+
+// maybeRunuserCommand returns an exec.Cmd that will, if the current
+// effective user id is 0 and username is not "root", call runuser(1)
+// to change to the given username before running the given command.
+//
+// If username is "root", or the effective user id is 0, the given
+// command is passed directly to exec.Command.
+//
+// TODO: maybe move this to osutil
+func maybeRunuserCommand(username string, args ...string) *exec.Cmd {
+ if username == "root" || sys.Geteuid() != 0 {
+ // runuser only works for euid 0, and doesn't make sense for root
+ return exec.Command(args[0], args[1:]...)
+ }
+ sudoArgs := make([]string, len(args)+2)
+ copy(sudoArgs[2:], args)
+ sudoArgs[0] = "-u"
+ sudoArgs[1] = username
+
+ return exec.Command("runuser", sudoArgs...)
+}
diff --git a/overlord/snapshotstate/backend/reader.go b/overlord/snapshotstate/backend/reader.go
new file mode 100644
index 00000000000..5f2865c5a8f
--- /dev/null
+++ b/overlord/snapshotstate/backend/reader.go
@@ -0,0 +1,353 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+import (
+ "bytes"
+ "crypto"
+ "errors"
+ "fmt"
+ "hash"
+ "io"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "sort"
+ "syscall"
+
+ "golang.org/x/net/context"
+
+ "github.com/snapcore/snapd/client"
+ "github.com/snapcore/snapd/jsonutil"
+ "github.com/snapcore/snapd/logger"
+ "github.com/snapcore/snapd/osutil"
+ "github.com/snapcore/snapd/osutil/sys"
+ "github.com/snapcore/snapd/snap"
+ "github.com/snapcore/snapd/strutil"
+)
+
+// A Reader is a snapshot that's been opened for reading.
+type Reader struct {
+ *os.File
+ client.Snapshot
+}
+
+// Open a Snapshot given its full filename.
+//
+// If the returned error is nil, the caller must close the reader (or
+// its file) when done with it.
+//
+// If the returned error is non-nil, the returned Reader will be nil,
+// *or* have a non-empty Broken; in the latter case its file will be
+// closed.
+func Open(fn string) (reader *Reader, e error) {
+ f, err := os.Open(fn)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if e != nil && f != nil {
+ f.Close()
+ }
+ }()
+
+ reader = &Reader{
+ File: f,
+ }
+
+ // first try to load the metadata itself
+ var sz sizer
+ hasher := crypto.SHA3_384.New()
+ metaReader, metaSize, err := zipMember(f, metadataName)
+ if err != nil {
+ // no metadata file -> nothing to do :-(
+ return nil, err
+ }
+
+ if err := jsonutil.DecodeWithNumber(io.TeeReader(metaReader, io.MultiWriter(hasher, &sz)), &reader.Snapshot); err != nil {
+ return nil, err
+ }
+
+ // OK, from here on we have a Snapshot
+
+ if !reader.IsValid() {
+ reader.Broken = "invalid snapshot"
+ return reader, errors.New(reader.Broken)
+ }
+
+ if sz.size != metaSize {
+ reader.Broken = fmt.Sprintf("declared metadata size (%d) does not match actual (%d)", metaSize, sz.size)
+ return reader, errors.New(reader.Broken)
+ }
+
+ actualMetaHash := fmt.Sprintf("%x", hasher.Sum(nil))
+
+ // grab the metadata hash
+ sz.Reset()
+ metaHashReader, metaHashSize, err := zipMember(f, metaHashName)
+ if err != nil {
+ reader.Broken = err.Error()
+ return reader, err
+ }
+ metaHashBuf, err := ioutil.ReadAll(io.TeeReader(metaHashReader, &sz))
+ if err != nil {
+ reader.Broken = err.Error()
+ return reader, err
+ }
+ if sz.size != metaHashSize {
+ reader.Broken = fmt.Sprintf("declared hash size (%d) does not match actual (%d)", metaHashSize, sz.size)
+ return reader, errors.New(reader.Broken)
+ }
+ if expectedMetaHash := string(bytes.TrimSpace(metaHashBuf)); actualMetaHash != expectedMetaHash {
+ reader.Broken = fmt.Sprintf("declared hash (%.7s…) does not match actual (%.7s…)", expectedMetaHash, actualMetaHash)
+ return reader, errors.New(reader.Broken)
+ }
+
+ return reader, nil
+}
+
+func (r *Reader) checkOne(ctx context.Context, entry string, hasher hash.Hash) error {
+ body, reportedSize, err := zipMember(r.File, entry)
+ if err != nil {
+ return err
+ }
+ defer body.Close()
+
+ expectedHash := r.SHA3_384[entry]
+ readSize, err := io.Copy(io.MultiWriter(osutil.ContextWriter(ctx), hasher), body)
+ if err != nil {
+ return err
+ }
+
+ if readSize != reportedSize {
+ return fmt.Errorf("snapshot entry %q size (%d) different from actual (%d)", entry, reportedSize, readSize)
+ }
+
+ if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
+ return fmt.Errorf("snapshot entry %q expected hash (%.7s…) does not match actual (%.7s…)", entry, expectedHash, actualHash)
+ }
+ return nil
+}
+
+// Check that the data contained in the snapshot matches its hashsums.
+func (r *Reader) Check(ctx context.Context, usernames []string) error {
+ sort.Strings(usernames)
+
+ hasher := crypto.SHA3_384.New()
+ for entry := range r.SHA3_384 {
+ if len(usernames) > 0 && isUserArchive(entry) {
+ username := entryUsername(entry)
+ if !strutil.SortedListContains(usernames, username) {
+ logger.Debugf("In checking snapshot %q, skipping entry %q by user request.", r.Name(), username)
+ continue
+ }
+ }
+
+ if err := r.checkOne(ctx, entry, hasher); err != nil {
+ return err
+ }
+ hasher.Reset()
+ }
+
+ return nil
+}
+
+// Logf is the type implemented by logging functions.
+type Logf func(format string, args ...interface{})
+
+// Restore the data from the snapshot.
+//
+// If successful this will replace the existing data (for the revision in the
+// snapshot) with that contained in the snapshot. It keeps track of the old
+// data in the task so it can be undone (or cleaned up).
+func (r *Reader) Restore(ctx context.Context, usernames []string, logf Logf) (rs *RestoreState, e error) {
+ rs = &RestoreState{}
+ defer func() {
+ if e != nil {
+ logger.Noticef("Restore of snapshot %q failed (%v); undoing.", r.Name(), e)
+ rs.Revert()
+ rs = nil
+ }
+ }()
+
+ sort.Strings(usernames)
+ isRoot := sys.Geteuid() == 0
+ si := snap.MinimalPlaceInfo(r.Snap, r.Revision)
+ hasher := crypto.SHA3_384.New()
+ var sz sizer
+
+ for entry := range r.SHA3_384 {
+ if err := ctx.Err(); err != nil {
+ return rs, err
+ }
+
+ var dest string
+ isUser := isUserArchive(entry)
+ username := "root"
+ uid := sys.UserID(osutil.NoChown)
+ gid := sys.GroupID(osutil.NoChown)
+
+ if !isUser {
+ if entry != archiveName {
+ // hmmm
+ logf("Skipping restore of unknown entry %q.", entry)
+ continue
+ }
+ dest = si.DataDir()
+ } else {
+ username = entryUsername(entry)
+ if len(usernames) > 0 && !strutil.SortedListContains(usernames, username) {
+ logger.Debugf("In restoring snapshot %q, skipping entry %q by user request.", r.Name(), username)
+ continue
+ }
+ usr, err := userLookup(username)
+ if err != nil {
+ logf("Skipping restore of user %q: %v.", username, err)
+ continue
+ }
+
+ dest = si.UserDataDir(usr.HomeDir)
+ fi, err := os.Stat(usr.HomeDir)
+ if err != nil {
+ if osutil.IsDirNotExist(err) {
+ logf("Skipping restore of %q as %q doesn't exist.", dest, usr.HomeDir)
+ } else {
+ logf("Skipping restore of %q: %v.", dest, err)
+ }
+ continue
+ }
+
+ if !fi.IsDir() {
+ logf("Skipping restore of %q as %q is not a directory.", dest, usr.HomeDir)
+ continue
+ }
+
+ if st, ok := fi.Sys().(*syscall.Stat_t); ok && isRoot {
+ // the mkdir below will use the uid/gid of usr.HomeDir
+ if st.Uid > 0 {
+ uid = sys.UserID(st.Uid)
+ }
+ if st.Gid > 0 {
+ gid = sys.GroupID(st.Gid)
+ }
+ }
+ }
+ parent, revdir := filepath.Split(dest)
+
+ exists, isDir, err := osutil.DirExists(parent)
+ if err != nil {
+ return rs, err
+ }
+ if !exists {
+ // NOTE that the chown won't happen (it'll be NoChown)
+ // for the system path, and we won't be creating the
+ // user's home (as we skip restore in that case).
+ // Also no chown happens for root/root.
+ if err := osutil.MkdirAllChown(parent, 0755, uid, gid); err != nil {
+ return rs, err
+ }
+ rs.Created = append(rs.Created, parent)
+ } else if !isDir {
+ return rs, fmt.Errorf("Cannot restore snapshot into %q: not a directory.", parent)
+ }
+
+ tempdir, err := ioutil.TempDir(parent, ".snapshot")
+ if err != nil {
+ return rs, err
+ }
+ // one way or another we want tempdir gone
+ defer func() {
+ if err := os.RemoveAll(tempdir); err != nil {
+ logf("Cannot clean up temporary directory %q: %v.", tempdir, err)
+ }
+ }()
+
+ logger.Debugf("Restoring %q from %q into %q.", entry, r.Name(), tempdir)
+
+ body, expectedSize, err := zipMember(r.File, entry)
+ if err != nil {
+ return rs, err
+ }
+
+ expectedHash := r.SHA3_384[entry]
+
+ tr := io.TeeReader(body, io.MultiWriter(hasher, &sz))
+
+ // resist the temptation of using archive/tar unless it's proven
+ // that calling out to tar has issues -- there are a lot of
+ // special cases we'd need to consider otherwise
+ cmd := maybeRunuserCommand(username,
+ "tar",
+ "--extract",
+ "--preserve-permissions", "--preserve-order", "--gunzip",
+ "--directory", tempdir)
+ cmd.Env = []string{}
+ cmd.Stdin = tr
+ cmd.Stderr = os.Stderr
+ cmd.Stdout = os.Stderr
+
+ if err = osutil.RunWithContext(ctx, cmd); err != nil {
+ return rs, err
+ }
+
+ if sz.size != expectedSize {
+ return rs, fmt.Errorf("snapshot %q entry %q expected size (%d) does not match actual (%d)",
+ r.Name(), entry, expectedSize, sz.size)
+ }
+
+ if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
+ return rs, fmt.Errorf("snapshot %q entry %q expected hash (%.7s…) does not match actual (%.7s…)",
+ r.Name(), entry, expectedHash, actualHash)
+ }
+
+ // TODO: something with Config
+
+ for _, dir := range []string{"common", revdir} {
+ source := filepath.Join(tempdir, dir)
+ if exists, _, err := osutil.DirExists(source); err != nil {
+ return rs, err
+ } else if !exists {
+ continue
+ }
+ target := filepath.Join(parent, dir)
+ exists, _, err := osutil.DirExists(target)
+ if err != nil {
+ return rs, err
+ }
+ if exists {
+ rsfn := restoreStateFilename(target)
+ if err := os.Rename(target, rsfn); err != nil {
+ return rs, err
+ }
+ rs.Moved = append(rs.Moved, rsfn)
+ }
+
+ if err := os.Rename(source, target); err != nil {
+ return rs, err
+ }
+ rs.Created = append(rs.Created, target)
+ }
+
+ sz.Reset()
+ hasher.Reset()
+ }
+
+ return rs, nil
+}
diff --git a/overlord/snapshotstate/backend/restorestate.go b/overlord/snapshotstate/backend/restorestate.go
new file mode 100644
index 00000000000..fae02752696
--- /dev/null
+++ b/overlord/snapshotstate/backend/restorestate.go
@@ -0,0 +1,96 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+import (
+ "fmt"
+ "os"
+ "regexp"
+
+ "github.com/snapcore/snapd/logger"
+ "github.com/snapcore/snapd/strutil"
+)
+
+// RestoreState stores information that can be used to cleanly revert (or finish
+// cleaning up) a snapshot Restore.
+//
+// This is useful when a Restore is part of a chain of operations, and a later
+// one failing necessitates undoing the Restore.
+type RestoreState struct {
+ Done bool `json:"done,omitempty"`
+ Created []string `json:"created,omitempty"`
+ Moved []string `json:"moved,omitempty"`
+ // Config is here for convenience; this package doesn't touch it
+ Config map[string]interface{} `json:"config,omitempty"`
+}
+
+// Cleanup the backed up data from disk.
+func (rs *RestoreState) Cleanup() {
+ if rs.Done {
+ logger.Noticef("Internal error: attempting to clean up a snapshot.RestoreState twice.")
+ return
+ }
+ rs.Done = true
+ for _, dir := range rs.Moved {
+ if err := os.RemoveAll(dir); err != nil {
+ logger.Noticef("Cannot remove directory tree rooted at %q: %v.", dir, err)
+ }
+ }
+}
+
+func restoreStateFilename(fn string) string {
+ return fmt.Sprintf("%s.~%s~", fn, strutil.MakeRandomString(9))
+}
+
+var restoreStateRx = regexp.MustCompile(`\.~[a-zA-Z0-9]{9}~$`)
+
+func restoreState2orig(fn string) string {
+ if idx := restoreStateRx.FindStringIndex(fn); len(idx) > 0 {
+ return fn[:idx[0]]
+ }
+ return ""
+}
+
+// Revert the backed up data: remove what was added, move back what was moved aside.
+func (rs *RestoreState) Revert() {
+ if rs.Done {
+ logger.Noticef("Internal error: attempting to revert a snapshot.RestoreState twice.")
+ return
+ }
+ rs.Done = true
+ for _, dir := range rs.Created {
+ logger.Debugf("Removing %q.", dir)
+ if err := os.RemoveAll(dir); err != nil {
+ logger.Noticef("While undoing changes because of a previous error: cannot remove %q: %v.", dir, err)
+ }
+ }
+ for _, dir := range rs.Moved {
+ orig := restoreState2orig(dir)
+ if orig == "" {
+ // dir is not restore state?!?
+ logger.Debugf("Skipping restore of %q: unrecognised filename.", dir)
+ continue
+ }
+ logger.Debugf("Restoring %q to %q.", dir, orig)
+ if err := os.Rename(dir, orig); err != nil {
+ logger.Noticef("While undoing changes because of a previous error: cannot restore %q to %q: %v.", dir, orig, err)
+ }
+ }
+}
diff --git a/overlord/snapshotstate/backend/sizer.go b/overlord/snapshotstate/backend/sizer.go
new file mode 100644
index 00000000000..0d0a6dfb08e
--- /dev/null
+++ b/overlord/snapshotstate/backend/sizer.go
@@ -0,0 +1,34 @@
+// -*- Mode: Go; indent-tabs-mode: t -*-
+
+/*
+ * Copyright (C) 2018 Canonical Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 3 as
+ * published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+package backend
+
+type sizer struct {
+ size int64
+}
+
+func (sz *sizer) Write(data []byte) (n int, err error) {
+ n = len(data)
+ sz.size += int64(n)
+ return
+}
+
+func (sz *sizer) Reset() {
+ sz.size = 0
+}