From 988db4d4b3f7b915b2dc99a9ecb23bdfb428febc Mon Sep 17 00:00:00 2001 From: Mohamed Feddad Date: Wed, 18 Sep 2024 15:52:00 +0400 Subject: [PATCH] chore: replace go-observable with StatusObservable, and optomize --- go.mod | 1 - go.sum | 2 -- safelock/core.go | 43 +++++++------------------- safelock/decrypt.go | 15 ++++----- safelock/encrypt.go | 18 +++++------ safelock/events.go | 73 +++++++++++++++++++++++++++++++++++++++++--- safelock/logger.go | 6 ++-- safelock/safelock.go | 27 +++++++++++++--- safelock/writer.go | 8 +---- 9 files changed, 120 insertions(+), 73 deletions(-) diff --git a/go.mod b/go.mod index b3603c7..cfb8baf 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/mrf345/safelock-cli go 1.22 require ( - github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00 github.com/inancgumus/screen v0.0.0-20190314163918-06e984b86ed3 github.com/mholt/archiver/v4 v4.0.0-alpha.8 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 69e070b..284de0d 100644 --- a/go.sum +++ b/go.sum @@ -17,8 +17,6 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00 h1:4wp5bMTx9eV6in+ZKiUsyeOqYdp9ooqpw1YWXjwVHJo= -github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00/go.mod h1:2pqNiwoZ8Fj1HBGWyPTXW/iPD332sJzTp3Iy0dIcFMc= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/bodgit/plumbing v1.2.0 h1:gg4haxoKphLjml+tgnecR4yLBV5zo4HAZGCtAh3xCzM= diff --git a/safelock/core.go b/safelock/core.go index 3d6fc01..f85a823 100644 --- a/safelock/core.go +++ b/safelock/core.go @@ -2,8 +2,6 @@ package safelock import ( "crypto/cipher" - "crypto/rand" - "errors" "fmt" "io" @@ -30,8 +28,7 @@ func newAeadWriter(pwd string, w io.Writer, config EncryptionConfig, errs chan e errs: errs, aeadDone: make(chan bool, 2), } - aw.writeSalt(w) - go aw.loadAead() + go aw.writeSaltAndLoad(w) return aw } @@ -55,20 +52,15 @@ func (aw *aeadWrapper) getAead() cipher.AEAD { return aw.aead } -func (aw *aeadWrapper) writeSalt(w io.Writer) { - var err error - - aw.salt = make([]byte, aw.config.SaltLength) - - if _, err = io.ReadFull(rand.Reader, aw.salt); err != nil { - aw.errs <- fmt.Errorf("failed to create random salt > %w", err) - return - } +func (aw *aeadWrapper) writeSaltAndLoad(w io.Writer) { + aw.salt = (<-aw.config.random)[:aw.config.SaltLength] - if _, err = w.Write(aw.salt); err != nil { + if _, err := w.Write(aw.salt); err != nil { aw.errs <- fmt.Errorf("failed to write salt > %w", err) return } + + aw.loadAead() } func (aw *aeadWrapper) readSalt(r InputReader) { @@ -94,11 +86,6 @@ func (aw *aeadWrapper) readSalt(r InputReader) { func (aw *aeadWrapper) loadAead() { var err error - if aw.config.SaltLength > len(aw.salt) { - aw.errs <- errors.New("missing salt, most probably race condition") - return - } - key := argon2.IDKey( aw.pwd, aw.salt, @@ -116,27 +103,19 @@ func (aw *aeadWrapper) loadAead() { aw.aeadDone <- true } -func (aw *aeadWrapper) encrypt(chunk []byte) (encrypted []byte, err error) { - aead := aw.getAead() +func (aw *aeadWrapper) encrypt(chunk []byte) []byte { idx := []byte(fmt.Sprintf("%d", aw.counter)) - nonce := make([]byte, aead.NonceSize()) - - if _, err = rand.Read(nonce); err != nil { - aw.errs <- fmt.Errorf("failed to generate nonce > %w", err) - return - } - - encrypted = append(nonce, aead.Seal(nil, nonce, chunk, idx)...) + aead := aw.getAead() + nonce := (<-aw.config.random)[:aead.NonceSize()] aw.counter += 1 - - return + return append(nonce, aead.Seal(nil, nonce, chunk, idx)...) } func (aw *aeadWrapper) decrypt(chunk []byte) (output []byte, err error) { aead := aw.getAead() if aead.NonceSize() > len(chunk) { - err = &slErrs.ErrFailedToAuthenticate{Msg: "chunk size size"} + err = &slErrs.ErrFailedToAuthenticate{Msg: "invalid chunk size"} aw.errs <- err return } diff --git a/safelock/decrypt.go b/safelock/decrypt.go index 71d63f9..4ba83e7 100644 --- a/safelock/decrypt.go +++ b/safelock/decrypt.go @@ -16,21 +16,18 @@ import ( // and then outputs the content into `outputPath` which must be a valid path to an existing directory // // NOTE: `ctx` context is optional you can pass `nil` and the method will handle it -func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) { +func (sl *Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) { errs := make(chan error) signals, closeSignals := utils.GetExitSignals() + unSubStatus := sl.StatusObs.Subscribe(sl.logStatus) if ctx == nil { ctx = context.Background() } - sl.StatusObs. - On(StatusUpdate.Str(), sl.logStatus). - Trigger(StatusStart.Str()) - - defer sl.StatusObs. - Off(StatusUpdate.Str(), sl.logStatus). - Trigger(StatusEnd.Str()) + sl.StatusObs.next(StatusItem{Event: StatusStart}) + defer sl.StatusObs.next(StatusItem{Event: StatusEnd}) + defer unSubStatus() go func() { if err = sl.validateDecryptionPaths(outputPath); err != nil { @@ -68,7 +65,7 @@ func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, p err = context.DeadlineExceeded return case err = <-errs: - sl.StatusObs.Trigger(StatusError.Str(), err) + sl.StatusObs.next(StatusItem{Event: StatusError, Err: err}) return case <-signals: return diff --git a/safelock/encrypt.go b/safelock/encrypt.go index c19c747..5295777 100644 --- a/safelock/encrypt.go +++ b/safelock/encrypt.go @@ -16,21 +16,20 @@ import ( // outputs into an object `output` that implements [io.Writer] such as [io.File] // // NOTE: `ctx` context is optional you can pass `nil` and the method will handle it -func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) { +func (sl *Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) { errs := make(chan error) + go sl.loadRandom(errs) + aead := newAeadWriter(password, output, sl.EncryptionConfig, errs) signals, closeSignals := utils.GetExitSignals() + unSubStatus := sl.StatusObs.Subscribe(sl.logStatus) if ctx == nil { ctx = context.Background() } - sl.StatusObs. - On(StatusUpdate.Str(), sl.logStatus). - Trigger(StatusStart.Str()) - - defer sl.StatusObs. - Off(StatusUpdate.Str(), sl.logStatus). - Trigger(StatusEnd.Str()) + sl.StatusObs.next(StatusItem{Event: StatusStart}) + defer sl.StatusObs.next(StatusItem{Event: StatusEnd}) + defer unSubStatus() go func() { if err = sl.validateEncryptionInputs(inputPaths, password); err != nil { @@ -39,7 +38,6 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W } ctx, cancel := context.WithCancel(ctx) - aead := newAeadWriter(password, output, sl.EncryptionConfig, errs) writer := newWriter(password, output, 20.0, cancel, aead) if err = sl.encryptFiles(ctx, inputPaths, writer); err != nil { @@ -63,7 +61,7 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W err = context.DeadlineExceeded return case err = <-errs: - sl.StatusObs.Trigger(StatusError.Str(), err) + sl.StatusObs.next(StatusItem{Event: StatusError, Err: err}) return case <-signals: return diff --git a/safelock/events.go b/safelock/events.go index 24fa050..ef439b0 100644 --- a/safelock/events.go +++ b/safelock/events.go @@ -1,14 +1,16 @@ package safelock +import "sync" + // [safelock.Safelock.StatusObs] streaming event keys type type StatusEvent string // [safelock.Safelock.StatusObs] streaming event keys const ( - StatusStart StatusEvent = "start_status" // encryption/decryption has started (no args) - StatusEnd StatusEvent = "end_status" // encryption/decryption has ended (no args) - StatusUpdate StatusEvent = "update_status" // new status update (status string, percent float64) - StatusError StatusEvent = "error_status" // encryption/decryption failed (error) + StatusStart StatusEvent = "start_status" // encryption/decryption has started + StatusEnd StatusEvent = "end_status" // encryption/decryption has ended + StatusUpdate StatusEvent = "update_status" // new status update + StatusError StatusEvent = "error_status" // encryption/decryption failed ) // return event key value as string @@ -16,6 +18,67 @@ func (se StatusEvent) Str() string { return string(se) } +// item used to communicate status changes +type StatusItem struct { + // status change event key + Event StatusEvent + // completion percent + Percent float64 + // optional status change text + Msg string + // optional status change error + Err error +} + +// observable like data structure used to stream status changes +type StatusObservable struct { + mu sync.RWMutex + subs map[int]func(StatusItem) + counter int +} + +// creates a new [safelock.StatusObservable] instance +func NewStatusObs() *StatusObservable { + return &StatusObservable{ + subs: make(map[int]func(StatusItem)), + } +} + +// adds a new status change subscriber, and returns the unsubscribe function +func (obs *StatusObservable) Subscribe(callback func(StatusItem)) func() { + obs.mu.Lock() + id := obs.counter + obs.subs[id] = callback + obs.counter += 1 + obs.mu.Unlock() + + // returns unsubscribe function + return func() { + obs.mu.Lock() + delete(obs.subs, id) + obs.mu.Unlock() + } +} + +// clears all subscriptions +func (obs *StatusObservable) Unsubscribe() { + obs.mu.Lock() + clear(obs.subs) + obs.mu.Unlock() +} + +func (obs *StatusObservable) next(value StatusItem) { + obs.mu.RLock() + for _, callback := range obs.subs { + go callback(value) + } + obs.mu.RUnlock() +} + func (sl *Safelock) updateStatus(status string, percent float64) { - sl.StatusObs.Trigger(StatusUpdate.Str(), status, percent) + sl.StatusObs.next(StatusItem{ + Event: StatusUpdate, + Msg: status, + Percent: percent, + }) } diff --git a/safelock/logger.go b/safelock/logger.go index 85050a8..9721462 100644 --- a/safelock/logger.go +++ b/safelock/logger.go @@ -17,6 +17,8 @@ func (sl *Safelock) log(msg string, params ...any) { } } -func (sl *Safelock) logStatus(status string, percent float64) { - sl.log("%s (%.2f%%)\n", status, percent) +func (sl *Safelock) logStatus(status StatusItem) { + if status.Event == StatusUpdate { + sl.log("%s (%.2f%%)\n", status.Msg, status.Percent) + } } diff --git a/safelock/safelock.go b/safelock/safelock.go index e116634..02a6faf 100644 --- a/safelock/safelock.go +++ b/safelock/safelock.go @@ -2,10 +2,11 @@ package safelock import ( "context" + "crypto/rand" + "fmt" "io" "runtime" - "github.com/GianlucaGuarini/go-observable" "github.com/klauspost/compress/zstd" "github.com/mholt/archiver/v4" ) @@ -14,7 +15,7 @@ import ( type EncryptionConfig struct { // encryption key length (default: 32) KeyLength uint32 - // encryption salt length (default: 12) + // encryption salt length (default: 16) SaltLength int // number of argon2 hashing iterations (default: 3) IterationCount uint32 @@ -26,6 +27,21 @@ type EncryptionConfig struct { MinPasswordLength int // ratio to create file header size based on (default: 1024 * 4) HeaderRatio int + + random chan []byte +} + +func (ec *EncryptionConfig) loadRandom(errs chan error) { + for { + nonce := make([]byte, 50) + + if _, err := rand.Read(nonce); err != nil { + errs <- fmt.Errorf("failed to generate random bytes > %w", err) + return + } + + ec.random <- nonce + } } // archiving and compression configuration settings @@ -49,7 +65,7 @@ type Safelock struct { // disable all output and logs (default: false) Quiet bool // observable instance that allows us to stream the status to multiple listeners - StatusObs *observable.Observable + StatusObs *StatusObservable } // creates a new [safelock.Safelock] instance with the default recommended options @@ -66,12 +82,13 @@ func New() *Safelock { EncryptionConfig: EncryptionConfig{ IterationCount: 3, KeyLength: 32, - SaltLength: 12, + SaltLength: 16, MinPasswordLength: 8, HeaderRatio: 1024 * 4, MemSize: 64 * 1024, Threads: uint8(runtime.NumCPU()), + random: make(chan []byte, 500), }, - StatusObs: observable.New(), + StatusObs: NewStatusObs(), } } diff --git a/safelock/writer.go b/safelock/writer.go index 47e2d63..305d744 100644 --- a/safelock/writer.go +++ b/safelock/writer.go @@ -33,13 +33,7 @@ func newWriter( } func (sw *safelockWriter) Write(chunk []byte) (written int, err error) { - var encrypted []byte - - if encrypted, err = sw.aead.encrypt(chunk); err != nil { - return 0, sw.handleErr(err) - } - - if written, err = sw.writer.Write(encrypted); err != nil { + if written, err = sw.writer.Write(sw.aead.encrypt(chunk)); err != nil { err = fmt.Errorf("can't write encrypted chunk > %w", err) return written, sw.handleErr(err) }