Skip to content

Commit

Permalink
Merge pull request #683 from smira/545-download-contxt
Browse files Browse the repository at this point in the history
Use Go context to abort gracefully mirror updates
  • Loading branch information
smira authored Nov 30, 2017
2 parents d836334 + b7490fe commit 9cb2a30
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 80 deletions.
5 changes: 3 additions & 2 deletions aptly/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package aptly

import (
"context"
"io"
"os"

Expand Down Expand Up @@ -124,9 +125,9 @@ type Progress interface {
// Downloader is parallel HTTP fetcher
type Downloader interface {
// Download starts new download task
Download(url string, destination string) error
Download(ctx context.Context, url string, destination string) error
// DownloadWithChecksum starts new download task with checksum verification
DownloadWithChecksum(url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error
DownloadWithChecksum(ctx context.Context, url string, destination string, expected *utils.ChecksumInfo, ignoreMismatch bool, maxTries int) error
// GetProgress returns Progress object
GetProgress() Progress
}
Expand Down
55 changes: 22 additions & 33 deletions cmd/mirror_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package cmd

import (
"fmt"
"os"
"os/signal"
"strings"
"sync"

Expand Down Expand Up @@ -113,17 +111,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
return fmt.Errorf("unable to update: %s", err)
}

// Catch ^C
sigch := make(chan os.Signal)
signal.Notify(sigch, os.Interrupt)
defer signal.Stop(sigch)

abort := make(chan struct{})
go func() {
<-sigch
signal.Stop(sigch)
close(abort)
}()
context.GoContextHandleSignals()

count := len(queue)
context.Progress().Printf("Download queue: %d items (%s)\n", count, utils.HumanBytes(downloadSize))
Expand All @@ -148,7 +136,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
for idx := range queue {
select {
case downloadQueue <- idx:
case <-abort:
case <-context.Done():
return
}
}
Expand Down Expand Up @@ -181,6 +169,7 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {

// download file...
e = context.Downloader().DownloadWithChecksum(
context,
repo.PackageURL(task.File.DownloadURL()).String(),
task.TempDownPath,
&task.File.Checksums,
Expand All @@ -190,28 +179,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
pushError(e)
continue
}
case <-abort:

task.Done = true
case <-context.Done():
return
}
}
}()
}

// Wait for all downloads to finish
// Wait for all download goroutines to finish
wg.Wait()

select {
case <-abort:
return fmt.Errorf("unable to update: interrupted")
default:
}

context.Progress().ShutdownBar()

if len(errors) > 0 {
return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n "))
}

err = context.ReOpenDatabase()
if err != nil {
return fmt.Errorf("unable to update: %s", err)
Expand All @@ -221,11 +202,15 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
context.Progress().InitBar(int64(len(queue)), false)

for idx := range queue {

context.Progress().AddBar(1)

task := &queue[idx]

if !task.Done {
// download not finished yet
continue
}

// and import it back to the pool
task.File.PoolPath, err = context.PackagePool().Import(task.TempDownPath, task.File.Filename, &task.File.Checksums, true, context.CollectionFactory().ChecksumCollection())
if err != nil {
Expand All @@ -237,16 +222,20 @@ func aptlyMirrorUpdate(cmd *commander.Command, args []string) error {
additionalTask.File.PoolPath = task.File.PoolPath
additionalTask.File.Checksums = task.File.Checksums
}

select {
case <-abort:
return fmt.Errorf("unable to update: interrupted")
default:
}
}

context.Progress().ShutdownBar()

select {
case <-context.Done():
return fmt.Errorf("unable to update: interrupted")
default:
}

if len(errors) > 0 {
return fmt.Errorf("unable to update: download errors:\n %s", strings.Join(errors, "\n "))
}

repo.FinalizeDownload(context.CollectionFactory(), context.Progress())
err = context.CollectionFactory().RemoteRepoCollection().Update(repo)
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
package context

import (
gocontext "context"
"fmt"
"math/rand"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
Expand All @@ -30,6 +32,8 @@ import (
type AptlyContext struct {
sync.Mutex

gocontext.Context

flags, globalFlags *flag.FlagSet
configLoaded bool

Expand Down Expand Up @@ -438,6 +442,27 @@ func (context *AptlyContext) GlobalFlags() *flag.FlagSet {
return context.globalFlags
}

// GoContextHandleSignals upgrades context to handle ^C by aborting context
func (context *AptlyContext) GoContextHandleSignals() {
context.Lock()
defer context.Unlock()

// Catch ^C
sigch := make(chan os.Signal)
signal.Notify(sigch, os.Interrupt)

var cancel gocontext.CancelFunc

context.Context, cancel = gocontext.WithCancel(context.Context)

go func() {
<-sigch
signal.Stop(sigch)
context.Progress().PrintfStdErr("Aborting... press ^C once again to abort immediately\n")
cancel()
}()
}

// Shutdown shuts context down
func (context *AptlyContext) Shutdown() {
context.Lock()
Expand Down Expand Up @@ -494,6 +519,7 @@ func NewContext(flags *flag.FlagSet) (*AptlyContext, error) {
flags: flags,
globalFlags: flags,
dependencyOptions: -1,
Context: gocontext.TODO(),
publishedStorages: map[string]aptly.PublishedStorage{},
}

Expand Down
1 change: 1 addition & 0 deletions deb/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ type PackageDownloadTask struct {
File *PackageFile
Additional []PackageDownloadTask
TempDownPath string
Done bool
}

// DownloadList returns list of missing package files for download in format
Expand Down
11 changes: 6 additions & 5 deletions deb/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package deb

import (
"bytes"
gocontext "context"
"fmt"
"log"
"net/url"
Expand Down Expand Up @@ -258,13 +259,13 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error {

if verifier == nil {
// 0. Just download release file to temporary URL
release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String())
release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String())
if err != nil {
return err
}
} else {
// 1. try InRelease file
inrelease, err = http.DownloadTemp(d, repo.ReleaseURL("InRelease").String())
inrelease, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("InRelease").String())
if err != nil {
goto splitsignature
}
Expand All @@ -286,12 +287,12 @@ func (repo *RemoteRepo) Fetch(d aptly.Downloader, verifier pgp.Verifier) error {

splitsignature:
// 2. try Release + Release.gpg
release, err = http.DownloadTemp(d, repo.ReleaseURL("Release").String())
release, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release").String())
if err != nil {
return err
}

releasesig, err = http.DownloadTemp(d, repo.ReleaseURL("Release.gpg").String())
releasesig, err = http.DownloadTemp(gocontext.TODO(), d, repo.ReleaseURL("Release.gpg").String())
if err != nil {
return err
}
Expand Down Expand Up @@ -439,7 +440,7 @@ func (repo *RemoteRepo) DownloadPackageIndexes(progress aptly.Progress, d aptly.

for _, info := range packagesPaths {
path, kind := info[0], info[1]
packagesReader, packagesFile, err := http.DownloadTryCompression(d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries)
packagesReader, packagesFile, err := http.DownloadTryCompression(gocontext.TODO(), d, repo.IndexesRootURL(), path, repo.ReleaseFiles, ignoreMismatch, maxTries)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions http/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http
import (
"compress/bzip2"
"compress/gzip"
"context"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -39,7 +40,7 @@ var compressionMethods = []struct {

// DownloadTryCompression tries to download from URL .bz2, .gz and raw extension until
// it finds existing file.
func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) {
func DownloadTryCompression(ctx context.Context, downloader aptly.Downloader, baseURL *url.URL, path string, expectedChecksums map[string]utils.ChecksumInfo, ignoreMismatch bool, maxTries int) (io.Reader, *os.File, error) {
var err error

for _, method := range compressionMethods {
Expand All @@ -63,13 +64,13 @@ func DownloadTryCompression(downloader aptly.Downloader, baseURL *url.URL, path

if foundChecksum {
expected := expectedChecksums[bestSuffix]
file, err = DownloadTempWithChecksum(downloader, tryURL.String(), &expected, ignoreMismatch, maxTries)
file, err = DownloadTempWithChecksum(ctx, downloader, tryURL.String(), &expected, ignoreMismatch, maxTries)
} else {
if !ignoreMismatch {
continue
}

file, err = DownloadTemp(downloader, tryURL.String())
file, err = DownloadTemp(ctx, downloader, tryURL.String())
}

if err != nil {
Expand Down
21 changes: 12 additions & 9 deletions http/compression_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"context"
"errors"
"io"
"net/url"
Expand All @@ -12,6 +13,7 @@ import (

type CompressionSuite struct {
baseURL *url.URL
ctx context.Context
}

var _ = Suite(&CompressionSuite{})
Expand All @@ -25,6 +27,7 @@ const (

func (s *CompressionSuite) SetUpTest(c *C) {
s.baseURL, _ = url.Parse("http://example.com/")
s.ctx = context.Background()
}

func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
Expand All @@ -41,7 +44,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
buf = make([]byte, 4)
d := NewFakeDownloader()
d.ExpectResponse("http://example.com/file.bz2", bzipData)
r, file, err := DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -53,7 +56,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.gz", gzipData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -66,7 +69,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.xz", xzData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -80,7 +83,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectError("http://example.com/file.xz", &Error{Code: 404})
d.ExpectResponse("http://example.com/file", rawData)
r, file, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
r, file, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -91,7 +94,7 @@ func (s *CompressionSuite) TestDownloadTryCompression(c *C) {
d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectResponse("http://example.com/file.gz", "x")
_, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "unexpected EOF")
c.Assert(d.Empty(), Equals, true)
}
Expand All @@ -109,7 +112,7 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) {
buf = make([]byte, 4)
d := NewFakeDownloader()
d.ExpectResponse("http://example.com/subdir/file.bz2", bzipData)
r, file, err := DownloadTryCompression(d, s.baseURL, "subdir/file", expectedChecksums, false, 1)
r, file, err := DownloadTryCompression(s.ctx, d, s.baseURL, "subdir/file", expectedChecksums, false, 1)
c.Assert(err, IsNil)
defer file.Close()
io.ReadFull(r, buf)
Expand All @@ -119,15 +122,15 @@ func (s *CompressionSuite) TestDownloadTryCompressionLongestSuffix(c *C) {

func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) {
d := NewFakeDownloader()
_, _, err := DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err := DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "unexpected request.*")

d = NewFakeDownloader()
d.ExpectError("http://example.com/file.bz2", &Error{Code: 404})
d.ExpectError("http://example.com/file.gz", &Error{Code: 404})
d.ExpectError("http://example.com/file.xz", &Error{Code: 404})
d.ExpectError("http://example.com/file", errors.New("403"))
_, _, err = DownloadTryCompression(d, s.baseURL, "file", nil, true, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", nil, true, 1)
c.Assert(err, ErrorMatches, "403")

d = NewFakeDownloader()
Expand All @@ -141,6 +144,6 @@ func (s *CompressionSuite) TestDownloadTryCompressionErrors(c *C) {
"file.xz": {Size: 7},
"file": {Size: 7},
}
_, _, err = DownloadTryCompression(d, s.baseURL, "file", expectedChecksums, false, 1)
_, _, err = DownloadTryCompression(s.ctx, d, s.baseURL, "file", expectedChecksums, false, 1)
c.Assert(err, ErrorMatches, "checksums don't match.*")
}
Loading

0 comments on commit 9cb2a30

Please sign in to comment.