Skip to content

Commit

Permalink
Merge #58631
Browse files Browse the repository at this point in the history
58631: storageccl: implement on-the-fly decrypting remote sst reader r=dt a=dt

The chunked encrypted encoding means we can jump to arbitrary chunks and
start decrypting. This allows implementing ReadAt and thus a pebble SST
reader can be created that decrypts on the fly, as it reads.

Release note: none.

Co-authored-by: David Taylor <[email protected]>
  • Loading branch information
craig[bot] and dt committed Jan 25, 2021
2 parents 0b09b61 + ea81e05 commit 6e26eb5
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 61 deletions.
2 changes: 2 additions & 0 deletions pkg/ccl/storageccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ go_library(
"//pkg/util/retry",
"//pkg/util/tracing",
"@com_github_cockroachdb_errors//:errors",
"@com_github_cockroachdb_pebble//sstable",
"@com_github_cockroachdb_pebble//vfs",
"@org_golang_x_crypto//pbkdf2",
],
)
Expand Down
131 changes: 90 additions & 41 deletions pkg/ccl/storageccl/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ import (
"encoding/binary"
"io"
"io/ioutil"
"os"

"github.com/cockroachdb/errors"
"github.com/cockroachdb/pebble/sstable"
"github.com/cockroachdb/pebble/vfs"
"golang.org/x/crypto/pbkdf2"
)

Expand Down Expand Up @@ -148,28 +151,30 @@ func encryptFile(plaintext, key []byte, chunked bool) ([]byte, error) {
// and reading the IV from a prefix of the file. See comments on EncryptFile
// for intended usage, and see DecryptFile
func DecryptFile(ciphertext, key []byte) ([]byte, error) {
r, err := DecryptingReader(bytes.NewReader(ciphertext), key)
r, err := decryptingReader(bytes.NewReader(ciphertext), key)
if err != nil {
return nil, err
}
return ioutil.ReadAll(r)
return ioutil.ReadAll(r.(io.Reader))
}

type decryptReader struct {
ciphertext io.Reader
ciphertext io.ReaderAt
g cipher.AEAD
fileIV []byte

eof bool
ivScratch []byte
buf []byte
pos int
pos int64 // pos is used to transform Read() to ReadAt(pos).
chunk int64
}

// DecryptingReader returns a reader that decrypts on the fly with the given
// key.
func DecryptingReader(ciphertext io.Reader, key []byte) (io.Reader, error) {
type readerAndReaderAt interface {
io.Reader
io.ReaderAt
}

func decryptingReader(ciphertext readerAndReaderAt, key []byte) (sstable.ReadableFile, error) {
gcm, err := aesgcm(key)
if err != nil {
return nil, err
Expand Down Expand Up @@ -202,7 +207,7 @@ func DecryptingReader(ciphertext io.Reader, key []byte) (io.Reader, error) {
return nil, err
}
buf, err = gcm.Open(buf[:0], iv, buf, nil)
return bytes.NewReader(buf), errors.Wrap(err, "failed to decrypt — maybe incorrect key")
return vfs.NewMemFile(buf), errors.Wrap(err, "failed to decrypt — maybe incorrect key")
}
buf := make([]byte, nonceSize, encryptionChunkSizeV2+tagSize+nonceSize)
ivScratch := buf[:nonceSize]
Expand All @@ -211,34 +216,29 @@ func DecryptingReader(ciphertext io.Reader, key []byte) (io.Reader, error) {
return r, err
}

func (r *decryptReader) fill() error {
if r.eof {
return io.EOF
// fill loads the requested chunk into the buffer.
func (r *decryptReader) fill(chunk int64) error {
if chunk == r.chunk {
return nil // this chunk is already loaded in buf.
}
r.pos = 0
r.buf = r.buf[:cap(r.buf)]
r.chunk++
var read int
for read < len(r.buf) {
n, err := r.ciphertext.Read(r.buf[read:])
read += n

// If we've reached end of the ciphertext, we still need to need to unseal
// the current chunk (even if it was empty, to detect truncations).
if err == io.EOF {
r.eof = true
break
}
if err != nil {
return err
}
r.chunk = -1 // invalidate the current buffered chunk while we fill it.
ciphertextChunkSize := int64(encryptionChunkSizeV2) + tagSize
// Load the region of ciphertext that corresponds to chunk.
n, err := r.ciphertext.ReadAt(r.buf[:cap(r.buf)], headerSize+chunk*ciphertextChunkSize)
if err != nil && err != io.EOF {
return err
}
var err error
r.buf, err = r.g.Open(r.buf[:0], r.chunkIV(r.chunk), r.buf[:read], nil)
if r.eof && len(r.buf) >= encryptionChunkSizeV2 {
return errors.Wrap(io.ErrUnexpectedEOF, "encrypted file appears truncated")
r.buf = r.buf[:n]

// Decrypt the ciphertext chunk into buf.
buf, err := r.g.Open(r.buf[:0], r.chunkIV(chunk), r.buf, nil)
if err != nil {
return errors.Wrap(err, "failed to decrypt — maybe incorrect key")
}
return errors.Wrap(err, "failed to decrypt — maybe incorrect key")
r.buf = buf
r.chunk = chunk
return err
}

func (r *decryptReader) chunkIV(num int64) []byte {
Expand All @@ -247,16 +247,48 @@ func (r *decryptReader) chunkIV(num int64) []byte {
return r.ivScratch
}

func (r *decryptReader) Read(p []byte) (int, error) {
if r.pos >= len(r.buf) {
if err := r.fill(); err != nil {
r.chunk = -1
return 0, err
func (r *decryptReader) ReadAt(p []byte, offset int64) (int, error) {
if offset < 0 {
return 0, errors.New("bad offset")
}

var read int
for {
chunk := offset / int64(encryptionChunkSizeV2)
offsetInChunk := offset % int64(encryptionChunkSizeV2)

if err := r.fill(chunk); err != nil {
return read, err
}

// If the decrypted chunk is too small to contain offset, that implies EOF.
if offsetInChunk >= int64(len(r.buf)) {
return read, io.EOF
}

// Copy from the chunk.
n := copy(p[read:], r.buf[offsetInChunk:])
read += n

// Return if we've fulfilled the request.
if read == len(p) {
return read, nil
}

// Return EOF if this was the last chunk (<chunksize).
if len(r.buf) < encryptionChunkSizeV2 {
return read, io.EOF
}

// Move offset by how much we read and go again.
offset += int64(n)
}
read := copy(p, r.buf[r.pos:])
r.pos += read
return read, nil
}

func (r *decryptReader) Read(p []byte) (int, error) {
n, err := r.ReadAt(p, r.pos)
r.pos += int64(n)
return n, err
}

func (r *decryptReader) Close() error {
Expand All @@ -266,6 +298,23 @@ func (r *decryptReader) Close() error {
return nil
}

// Size returns the size of the file.
func (r *decryptReader) Stat() (os.FileInfo, error) {
stater, ok := r.ciphertext.(interface{ Stat() (os.FileInfo, error) })
if !ok {
return nil, errors.Newf("%T does not support stat", r.ciphertext)
}
stat, err := stater.Stat()
if err != nil {
return nil, err
}

size := stat.Size()
size -= headerSize
size -= tagSize * ((size / (int64(encryptionChunkSizeV2) + tagSize)) + 1)
return sizeStat(size), nil
}

func aesgcm(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
Expand Down
128 changes: 126 additions & 2 deletions pkg/ccl/storageccl/encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,99 @@ func TestEncryptDecrypt(t *testing.T) {
require.EqualError(t, err, "file does not appear to be encrypted")
})

t.Run("ReadAt", func(t *testing.T) {
rng, _ := randutil.NewTestPseudoRand()

encryptionChunkSizeV2 = 32

plaintext := randutil.RandBytes(rng, 256)
plainReader := bytes.NewReader(plaintext)

ciphertext, err := encryptFile(plaintext, key, true)
require.NoError(t, err)

r, err := decryptingReader(bytes.NewReader(ciphertext), key)
require.NoError(t, err)

t.Run("start", func(t *testing.T) {
expected := make([]byte, 24)
got := make([]byte, len(expected))

expectedN, expectedErr := plainReader.ReadAt(expected, 0)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, 0)

require.Equal(t, expectedN, gotN)
require.Equal(t, expectedErr, gotErr)
require.Equal(t, expected, got)
})

t.Run("spanning", func(t *testing.T) {
expected := make([]byte, 24)
got := make([]byte, len(expected))

expectedN, expectedErr := plainReader.ReadAt(expected, 30)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, 30)

require.Equal(t, expectedN, gotN)
require.Equal(t, expectedErr, gotErr)
require.Equal(t, expected, got)

expectedEmpty := make([]byte, 0)
gotEmpty := make([]byte, 0)
expectedEmptyN, expectedEmptyErr := plainReader.ReadAt(expectedEmpty, 30)
gotEmptyN, gotEmptyErr := r.(io.ReaderAt).ReadAt(gotEmpty, 30)

require.Equal(t, expectedEmptyN, gotEmptyN)
require.Equal(t, expectedEmptyErr != nil, gotEmptyErr != nil)
require.Equal(t, expectedEmpty, gotEmpty)
})

t.Run("to-end", func(t *testing.T) {
expected := make([]byte, 24)
got := make([]byte, len(expected))

expectedN, expectedErr := plainReader.ReadAt(expected, 256-24)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, 256-24)

require.Equal(t, expectedN, gotN)
require.Equal(t, expectedErr, gotErr)
require.Equal(t, expected, got)
})

t.Run("spanning-end", func(t *testing.T) {
expected := make([]byte, 100)
got := make([]byte, len(expected))

expectedN, expectedErr := plainReader.ReadAt(expected, 180)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, 180)

require.Equal(t, expectedN, gotN)
require.Equal(t, expectedErr, gotErr)
require.Equal(t, expected, got)
})

t.Run("after-end", func(t *testing.T) {
expected := make([]byte, 24)
got := make([]byte, len(expected))

expectedN, _ := plainReader.ReadAt(expected, 300)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, 300)

require.Equal(t, expectedN, gotN)
require.NotNil(t, gotErr)
require.Equal(t, expected, got)

expectedEmpty := make([]byte, 0)
gotEmpty := make([]byte, 0)
expectedEmptyN, expectedEmptyErr := plainReader.ReadAt(expectedEmpty, 300)
gotEmptyN, gotEmptyErr := r.(io.ReaderAt).ReadAt(gotEmpty, 300)

require.Equal(t, expectedEmptyN, gotEmptyN)
require.Equal(t, expectedEmptyErr != nil, gotEmptyErr != nil)
require.Equal(t, expectedEmpty, gotEmpty)
})
})

t.Run("Random", func(t *testing.T) {
rng, _ := randutil.NewTestPseudoRand()
t.Run("DecryptFile", func(t *testing.T) {
Expand All @@ -83,6 +176,37 @@ func TestEncryptDecrypt(t *testing.T) {
})
}
})

t.Run("ReadAt", func(t *testing.T) {
// For each random size of chunk and text, verify random reads.
const chunkSizes, textSizes, reads = 10, 100, 500

for i := 0; i < chunkSizes; i++ {
encryptionChunkSizeV2 = rng.Intn(1024*24) + 1
for j := 0; j < textSizes; j++ {
plaintext := randutil.RandBytes(rng, rng.Intn(1024*32))
plainReader := bytes.NewReader(plaintext)
ciphertext, err := encryptFile(plaintext, key, encryptionChunkSizeV2 > 0)
require.NoError(t, err)
r, err := decryptingReader(bytes.NewReader(ciphertext), key)
require.NoError(t, err)
for k := 0; k < reads; k++ {
start := rng.Int63n(int64(float64(len(plaintext)) * 1.1))
expected := make([]byte, rng.Int63n(int64(len(plaintext))/2))
got := make([]byte, len(expected))
expectedN, expectedErr := plainReader.ReadAt(expected, start)
gotN, gotErr := r.(io.ReaderAt).ReadAt(got, start)
require.Equal(t, expectedN, gotN)
if start < int64(len(plaintext)) {
require.Equal(t, expectedErr, gotErr)
} else {
require.Equal(t, expectedErr != nil, gotErr != nil)
}
require.Equal(t, expected[:expectedN], got[:gotN])
}
}
}
})
})
_ = EncryptFileChunked // suppress unused warning.
}
Expand Down Expand Up @@ -147,11 +271,11 @@ func BenchmarkEncryption(b *testing.B) {
ciphertext := bytes.NewReader(ciphertextOriginal[chunkSizeNum])
for i := 0; i < b.N; i++ {
ciphertext.Reset(ciphertextOriginal[chunkSizeNum])
r, err := DecryptingReader(ciphertext, key)
r, err := decryptingReader(ciphertext, key)
if err != nil {
b.Fatal(err)
}
_, err = io.Copy(ioutil.Discard, r)
_, err = io.Copy(ioutil.Discard, r.(io.Reader))
if err != nil {
b.Fatal(err)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/ccl/storageccl/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func evalExport(
}

if args.Encryption != nil {
// TODO(dt): cluster version gate use EncryptFileChunked.
data, err = EncryptFile(data, args.Encryption.Key)
if err != nil {
return result.Result{}, err
Expand Down
Loading

0 comments on commit 6e26eb5

Please sign in to comment.