Skip to content

Commit

Permalink
Merge pull request #20 from mrf345/testing
Browse files Browse the repository at this point in the history
Replace `go-observable` with `StatusObservable`, and optimize
  • Loading branch information
mrf345 authored Sep 18, 2024
2 parents 8ac1145 + 988db4d commit 233d988
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 73 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module github.com/mrf345/safelock-cli
go 1.22

require (
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00
github.com/inancgumus/screen v0.0.0-20190314163918-06e984b86ed3
github.com/mholt/archiver/v4 v4.0.0-alpha.8
github.com/spf13/cobra v1.8.1
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0Zeo
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00 h1:4wp5bMTx9eV6in+ZKiUsyeOqYdp9ooqpw1YWXjwVHJo=
github.com/GianlucaGuarini/go-observable v0.0.0-20171228155646-e39e699e0a00/go.mod h1:2pqNiwoZ8Fj1HBGWyPTXW/iPD332sJzTp3Iy0dIcFMc=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/bodgit/plumbing v1.2.0 h1:gg4haxoKphLjml+tgnecR4yLBV5zo4HAZGCtAh3xCzM=
Expand Down
43 changes: 11 additions & 32 deletions safelock/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package safelock

import (
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"

Expand All @@ -30,8 +28,7 @@ func newAeadWriter(pwd string, w io.Writer, config EncryptionConfig, errs chan e
errs: errs,
aeadDone: make(chan bool, 2),
}
aw.writeSalt(w)
go aw.loadAead()
go aw.writeSaltAndLoad(w)
return aw
}

Expand All @@ -55,20 +52,15 @@ func (aw *aeadWrapper) getAead() cipher.AEAD {
return aw.aead
}

func (aw *aeadWrapper) writeSalt(w io.Writer) {
var err error

aw.salt = make([]byte, aw.config.SaltLength)

if _, err = io.ReadFull(rand.Reader, aw.salt); err != nil {
aw.errs <- fmt.Errorf("failed to create random salt > %w", err)
return
}
func (aw *aeadWrapper) writeSaltAndLoad(w io.Writer) {
aw.salt = (<-aw.config.random)[:aw.config.SaltLength]

if _, err = w.Write(aw.salt); err != nil {
if _, err := w.Write(aw.salt); err != nil {
aw.errs <- fmt.Errorf("failed to write salt > %w", err)
return
}

aw.loadAead()
}

func (aw *aeadWrapper) readSalt(r InputReader) {
Expand All @@ -94,11 +86,6 @@ func (aw *aeadWrapper) readSalt(r InputReader) {
func (aw *aeadWrapper) loadAead() {
var err error

if aw.config.SaltLength > len(aw.salt) {
aw.errs <- errors.New("missing salt, most probably race condition")
return
}

key := argon2.IDKey(
aw.pwd,
aw.salt,
Expand All @@ -116,27 +103,19 @@ func (aw *aeadWrapper) loadAead() {
aw.aeadDone <- true
}

func (aw *aeadWrapper) encrypt(chunk []byte) (encrypted []byte, err error) {
aead := aw.getAead()
func (aw *aeadWrapper) encrypt(chunk []byte) []byte {
idx := []byte(fmt.Sprintf("%d", aw.counter))
nonce := make([]byte, aead.NonceSize())

if _, err = rand.Read(nonce); err != nil {
aw.errs <- fmt.Errorf("failed to generate nonce > %w", err)
return
}

encrypted = append(nonce, aead.Seal(nil, nonce, chunk, idx)...)
aead := aw.getAead()
nonce := (<-aw.config.random)[:aead.NonceSize()]
aw.counter += 1

return
return append(nonce, aead.Seal(nil, nonce, chunk, idx)...)
}

func (aw *aeadWrapper) decrypt(chunk []byte) (output []byte, err error) {
aead := aw.getAead()

if aead.NonceSize() > len(chunk) {
err = &slErrs.ErrFailedToAuthenticate{Msg: "chunk size size"}
err = &slErrs.ErrFailedToAuthenticate{Msg: "invalid chunk size"}
aw.errs <- err
return
}
Expand Down
15 changes: 6 additions & 9 deletions safelock/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,18 @@ import (
// and then outputs the content into `outputPath` which must be a valid path to an existing directory
//
// NOTE: `ctx` context is optional you can pass `nil` and the method will handle it
func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) {
func (sl *Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, password string) (err error) {
errs := make(chan error)
signals, closeSignals := utils.GetExitSignals()
unSubStatus := sl.StatusObs.Subscribe(sl.logStatus)

if ctx == nil {
ctx = context.Background()
}

sl.StatusObs.
On(StatusUpdate.Str(), sl.logStatus).
Trigger(StatusStart.Str())

defer sl.StatusObs.
Off(StatusUpdate.Str(), sl.logStatus).
Trigger(StatusEnd.Str())
sl.StatusObs.next(StatusItem{Event: StatusStart})
defer sl.StatusObs.next(StatusItem{Event: StatusEnd})
defer unSubStatus()

go func() {
if err = sl.validateDecryptionPaths(outputPath); err != nil {
Expand Down Expand Up @@ -68,7 +65,7 @@ func (sl Safelock) Decrypt(ctx context.Context, input InputReader, outputPath, p
err = context.DeadlineExceeded
return
case err = <-errs:
sl.StatusObs.Trigger(StatusError.Str(), err)
sl.StatusObs.next(StatusItem{Event: StatusError, Err: err})
return
case <-signals:
return
Expand Down
18 changes: 8 additions & 10 deletions safelock/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@ import (
// outputs into an object `output` that implements [io.Writer] such as [io.File]
//
// NOTE: `ctx` context is optional you can pass `nil` and the method will handle it
func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) {
func (sl *Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.Writer, password string) (err error) {
errs := make(chan error)
go sl.loadRandom(errs)
aead := newAeadWriter(password, output, sl.EncryptionConfig, errs)
signals, closeSignals := utils.GetExitSignals()
unSubStatus := sl.StatusObs.Subscribe(sl.logStatus)

if ctx == nil {
ctx = context.Background()
}

sl.StatusObs.
On(StatusUpdate.Str(), sl.logStatus).
Trigger(StatusStart.Str())

defer sl.StatusObs.
Off(StatusUpdate.Str(), sl.logStatus).
Trigger(StatusEnd.Str())
sl.StatusObs.next(StatusItem{Event: StatusStart})
defer sl.StatusObs.next(StatusItem{Event: StatusEnd})
defer unSubStatus()

go func() {
if err = sl.validateEncryptionInputs(inputPaths, password); err != nil {
Expand All @@ -39,7 +38,6 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W
}

ctx, cancel := context.WithCancel(ctx)
aead := newAeadWriter(password, output, sl.EncryptionConfig, errs)
writer := newWriter(password, output, 20.0, cancel, aead)

if err = sl.encryptFiles(ctx, inputPaths, writer); err != nil {
Expand All @@ -63,7 +61,7 @@ func (sl Safelock) Encrypt(ctx context.Context, inputPaths []string, output io.W
err = context.DeadlineExceeded
return
case err = <-errs:
sl.StatusObs.Trigger(StatusError.Str(), err)
sl.StatusObs.next(StatusItem{Event: StatusError, Err: err})
return
case <-signals:
return
Expand Down
73 changes: 68 additions & 5 deletions safelock/events.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,84 @@
package safelock

import "sync"

// [safelock.Safelock.StatusObs] streaming event keys type
type StatusEvent string

// [safelock.Safelock.StatusObs] streaming event keys
const (
StatusStart StatusEvent = "start_status" // encryption/decryption has started (no args)
StatusEnd StatusEvent = "end_status" // encryption/decryption has ended (no args)
StatusUpdate StatusEvent = "update_status" // new status update (status string, percent float64)
StatusError StatusEvent = "error_status" // encryption/decryption failed (error)
StatusStart StatusEvent = "start_status" // encryption/decryption has started
StatusEnd StatusEvent = "end_status" // encryption/decryption has ended
StatusUpdate StatusEvent = "update_status" // new status update
StatusError StatusEvent = "error_status" // encryption/decryption failed
)

// return event key value as string
func (se StatusEvent) Str() string {
return string(se)
}

// item used to communicate status changes
type StatusItem struct {
// status change event key
Event StatusEvent
// completion percent
Percent float64
// optional status change text
Msg string
// optional status change error
Err error
}

// observable like data structure used to stream status changes
type StatusObservable struct {
mu sync.RWMutex
subs map[int]func(StatusItem)
counter int
}

// creates a new [safelock.StatusObservable] instance
func NewStatusObs() *StatusObservable {
return &StatusObservable{
subs: make(map[int]func(StatusItem)),
}
}

// adds a new status change subscriber, and returns the unsubscribe function
func (obs *StatusObservable) Subscribe(callback func(StatusItem)) func() {
obs.mu.Lock()
id := obs.counter
obs.subs[id] = callback
obs.counter += 1
obs.mu.Unlock()

// returns unsubscribe function
return func() {
obs.mu.Lock()
delete(obs.subs, id)
obs.mu.Unlock()
}
}

// clears all subscriptions
func (obs *StatusObservable) Unsubscribe() {
obs.mu.Lock()
clear(obs.subs)
obs.mu.Unlock()
}

func (obs *StatusObservable) next(value StatusItem) {
obs.mu.RLock()
for _, callback := range obs.subs {
go callback(value)
}
obs.mu.RUnlock()
}

func (sl *Safelock) updateStatus(status string, percent float64) {
sl.StatusObs.Trigger(StatusUpdate.Str(), status, percent)
sl.StatusObs.next(StatusItem{
Event: StatusUpdate,
Msg: status,
Percent: percent,
})
}
6 changes: 4 additions & 2 deletions safelock/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ func (sl *Safelock) log(msg string, params ...any) {
}
}

func (sl *Safelock) logStatus(status string, percent float64) {
sl.log("%s (%.2f%%)\n", status, percent)
func (sl *Safelock) logStatus(status StatusItem) {
if status.Event == StatusUpdate {
sl.log("%s (%.2f%%)\n", status.Msg, status.Percent)
}
}
27 changes: 22 additions & 5 deletions safelock/safelock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package safelock

import (
"context"
"crypto/rand"
"fmt"
"io"
"runtime"

"github.com/GianlucaGuarini/go-observable"
"github.com/klauspost/compress/zstd"
"github.com/mholt/archiver/v4"
)
Expand All @@ -14,7 +15,7 @@ import (
type EncryptionConfig struct {
// encryption key length (default: 32)
KeyLength uint32
// encryption salt length (default: 12)
// encryption salt length (default: 16)
SaltLength int
// number of argon2 hashing iterations (default: 3)
IterationCount uint32
Expand All @@ -26,6 +27,21 @@ type EncryptionConfig struct {
MinPasswordLength int
// ratio to create file header size based on (default: 1024 * 4)
HeaderRatio int

random chan []byte
}

func (ec *EncryptionConfig) loadRandom(errs chan error) {
for {
nonce := make([]byte, 50)

if _, err := rand.Read(nonce); err != nil {
errs <- fmt.Errorf("failed to generate random bytes > %w", err)
return
}

ec.random <- nonce
}
}

// archiving and compression configuration settings
Expand All @@ -49,7 +65,7 @@ type Safelock struct {
// disable all output and logs (default: false)
Quiet bool
// observable instance that allows us to stream the status to multiple listeners
StatusObs *observable.Observable
StatusObs *StatusObservable
}

// creates a new [safelock.Safelock] instance with the default recommended options
Expand All @@ -66,12 +82,13 @@ func New() *Safelock {
EncryptionConfig: EncryptionConfig{
IterationCount: 3,
KeyLength: 32,
SaltLength: 12,
SaltLength: 16,
MinPasswordLength: 8,
HeaderRatio: 1024 * 4,
MemSize: 64 * 1024,
Threads: uint8(runtime.NumCPU()),
random: make(chan []byte, 500),
},
StatusObs: observable.New(),
StatusObs: NewStatusObs(),
}
}
8 changes: 1 addition & 7 deletions safelock/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ func newWriter(
}

func (sw *safelockWriter) Write(chunk []byte) (written int, err error) {
var encrypted []byte

if encrypted, err = sw.aead.encrypt(chunk); err != nil {
return 0, sw.handleErr(err)
}

if written, err = sw.writer.Write(encrypted); err != nil {
if written, err = sw.writer.Write(sw.aead.encrypt(chunk)); err != nil {
err = fmt.Errorf("can't write encrypted chunk > %w", err)
return written, sw.handleErr(err)
}
Expand Down

0 comments on commit 233d988

Please sign in to comment.