Skip to content

Commit

Permalink
Add ChaCha20-Poly1305 support to transit (#3975)
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai authored Feb 14, 2018
1 parent 901f98f commit ef00a69
Show file tree
Hide file tree
Showing 16 changed files with 3,147 additions and 51 deletions.
36 changes: 29 additions & 7 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"fmt"
"math/rand"
"os"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -283,6 +284,13 @@ func TestBackend_datakey(t *testing.T) {
}

func TestBackend_rotation(t *testing.T) {
defer os.Setenv("TRANSIT_ACC_KEY_TYPE", "")
testBackendRotation(t)
os.Setenv("TRANSIT_ACC_KEY_TYPE", "CHACHA")
testBackendRotation(t)
}

func testBackendRotation(t *testing.T) {
decryptData := make(map[string]interface{})
encryptHistory := make(map[int]map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Expand Down Expand Up @@ -365,13 +373,17 @@ func TestBackend_basic_derived(t *testing.T) {
}

func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest.TestStep {
return logicaltest.TestStep{
ts := logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "keys/" + name,
Data: map[string]interface{}{
"derived": derived,
},
}
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
ts.Data["type"] = "chacha20-poly1305"
}
return ts
}

func testAccStepListPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
Expand Down Expand Up @@ -509,7 +521,11 @@ func testAccStepReadPolicyWithVersions(t *testing.T, name string, expectNone, de
if d.Name != name {
return fmt.Errorf("bad name: %#v", d)
}
if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
if os.Getenv("TRANSIT_ACC_KEY_TYPE") == "CHACHA" {
if d.Type != keysutil.KeyType(keysutil.KeyType_ChaCha20_Poly1305).String() {
return fmt.Errorf("bad key type: %#v", d)
}
} else if d.Type != keysutil.KeyType(keysutil.KeyType_AES256_GCM96).String() {
return fmt.Errorf("bad key type: %#v", d)
}
// Should NOT get a key back
Expand Down Expand Up @@ -826,14 +842,19 @@ func TestKeyUpgrade(t *testing.T) {
}

func TestDerivedKeyUpgrade(t *testing.T) {
testDerivedKeyUpgrade(t, keysutil.KeyType_AES256_GCM96)
testDerivedKeyUpgrade(t, keysutil.KeyType_ChaCha20_Poly1305)
}

func testDerivedKeyUpgrade(t *testing.T, keyType keysutil.KeyType) {
storage := &logical.InmemStorage{}
key, _ := uuid.GenerateRandomBytes(32)
keyContext, _ := uuid.GenerateRandomBytes(32)

p := &keysutil.Policy{
Name: "test",
Key: key,
Type: keysutil.KeyType_AES256_GCM96,
Type: keyType,
Derived: true,
}

Expand Down Expand Up @@ -883,11 +904,12 @@ func TestDerivedKeyUpgrade(t *testing.T) {
}

func TestConvergentEncryption(t *testing.T) {
testConvergentEncryptionCommon(t, 0)
testConvergentEncryptionCommon(t, 2)
testConvergentEncryptionCommon(t, 0, keysutil.KeyType_AES256_GCM96)
testConvergentEncryptionCommon(t, 2, keysutil.KeyType_AES256_GCM96)
testConvergentEncryptionCommon(t, 2, keysutil.KeyType_ChaCha20_Poly1305)
}

func testConvergentEncryptionCommon(t *testing.T, ver int) {
func testConvergentEncryptionCommon(t *testing.T, ver int, keyType keysutil.KeyType) {
var b *backend
sysView := logical.TestSystemView()
storage := &logical.InmemStorage{}
Expand Down Expand Up @@ -920,7 +942,7 @@ func testConvergentEncryptionCommon(t *testing.T, ver int) {

p := &keysutil.Policy{
Name: "testkey",
Type: keysutil.KeyType_AES256_GCM96,
Type: keyType,
Derived: true,
ConvergentEncryption: true,
ConvergentVersion: ver,
Expand Down
2 changes: 2 additions & 0 deletions builtin/logical/transit/path_backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
func TestTransit_BackupRestore(t *testing.T) {
// Test encryption/decryption after a restore for supported keys
testBackupRestore(t, "aes256-gcm96", "encrypt-decrypt")
testBackupRestore(t, "chacha20-poly1305", "encrypt-decrypt")
testBackupRestore(t, "rsa-2048", "encrypt-decrypt")
testBackupRestore(t, "rsa-4096", "encrypt-decrypt")

Expand All @@ -21,6 +22,7 @@ func TestTransit_BackupRestore(t *testing.T) {

// Test HMAC/verification after a restore for all key types
testBackupRestore(t, "aes256-gcm96", "hmac-verify")
testBackupRestore(t, "chacha20-poly1305", "hmac-verify")
testBackupRestore(t, "ecdsa-p256", "hmac-verify")
testBackupRestore(t, "ed25519", "hmac-verify")
testBackupRestore(t, "rsa-2048", "hmac-verify")
Expand Down
2 changes: 2 additions & 0 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
switch keyType {
case "aes256-gcm96":
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "chacha20-poly1305":
polReq.KeyType = keysutil.KeyType_ChaCha20_Poly1305
case "ecdsa-p256":
return logical.ErrorResponse(fmt.Sprintf("key type %v not supported for this operation", keyType)), logical.ErrInvalidRequest
default:
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st

case exportTypeEncryptionKey:
switch policy.Type {
case keysutil.KeyType_AES256_GCM96:
case keysutil.KeyType_AES256_GCM96, keysutil.KeyType_ChaCha20_Poly1305:
return strings.TrimSpace(base64.StdEncoding.EncodeToString(key.Key)), nil

case keysutil.KeyType_RSA2048, keysutil.KeyType_RSA4096:
Expand Down
2 changes: 2 additions & 0 deletions builtin/logical/transit/path_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import (

func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) {
verifyExportsCorrectVersion(t, "encryption-key", "aes256-gcm96")
verifyExportsCorrectVersion(t, "encryption-key", "chacha20-poly1305")
verifyExportsCorrectVersion(t, "signing-key", "ecdsa-p256")
verifyExportsCorrectVersion(t, "signing-key", "ed25519")
verifyExportsCorrectVersion(t, "hmac-key", "aes256-gcm96")
verifyExportsCorrectVersion(t, "hmac-key", "chacha20-poly1305")
verifyExportsCorrectVersion(t, "hmac-key", "ecdsa-p256")
verifyExportsCorrectVersion(t, "hmac-key", "ed25519")
}
Expand Down
4 changes: 3 additions & 1 deletion builtin/logical/transit/path_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
switch keyType {
case "aes256-gcm96":
polReq.KeyType = keysutil.KeyType_AES256_GCM96
case "chacha20-poly1305":
polReq.KeyType = keysutil.KeyType_ChaCha20_Poly1305
case "ecdsa-p256":
polReq.KeyType = keysutil.KeyType_ECDSA_P256
case "ed25519":
Expand Down Expand Up @@ -247,7 +249,7 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f
}

switch p.Type {
case keysutil.KeyType_AES256_GCM96:
case keysutil.KeyType_AES256_GCM96, keysutil.KeyType_ChaCha20_Poly1305:
retKeys := map[string]int64{}
for k, v := range p.Keys {
retKeys[k] = v.DeprecatedCreationTime
Expand Down
2 changes: 1 addition & 1 deletion helper/keysutil/lock_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (lm *LockManager) getPolicyCommon(ctx context.Context, req PolicyRequest, l
}

switch req.KeyType {
case KeyType_AES256_GCM96:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
if req.Convergent && !req.Derived {
lm.UnlockPolicy(lock, lockType)
return nil, nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
Expand Down
104 changes: 69 additions & 35 deletions helper/keysutil/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"time"

"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/hkdf"

Expand All @@ -48,6 +49,7 @@ const (
KeyType_ED25519
KeyType_RSA2048
KeyType_RSA4096
KeyType_ChaCha20_Poly1305
)

const ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
Expand Down Expand Up @@ -75,15 +77,15 @@ type KeyType int

func (kt KeyType) EncryptionSupported() bool {
switch kt {
case KeyType_AES256_GCM96, KeyType_RSA2048, KeyType_RSA4096:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA4096:
return true
}
return false
}

func (kt KeyType) DecryptionSupported() bool {
switch kt {
case KeyType_AES256_GCM96, KeyType_RSA2048, KeyType_RSA4096:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA4096:
return true
}
return false
Expand All @@ -107,7 +109,7 @@ func (kt KeyType) HashSignatureInput() bool {

func (kt KeyType) DerivationSupported() bool {
switch kt {
case KeyType_AES256_GCM96, KeyType_ED25519:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_ED25519:
return true
}
return false
Expand All @@ -117,6 +119,8 @@ func (kt KeyType) String() string {
switch kt {
case KeyType_AES256_GCM96:
return "aes256-gcm96"
case KeyType_ChaCha20_Poly1305:
return "chacha20-poly1305"
case KeyType_ECDSA_P256:
return "ecdsa-p256"
case KeyType_ED25519:
Expand Down Expand Up @@ -569,7 +573,7 @@ func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) {
}

switch p.Type {
case KeyType_AES256_GCM96:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
n, err := derBytes.ReadFrom(limReader)
if err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
Expand Down Expand Up @@ -622,47 +626,62 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string,
var ciphertext []byte

switch p.Type {
case KeyType_AES256_GCM96:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
// Derive the key that should be used
key, err := p.DeriveKey(context, ver)
if err != nil {
return "", err
}

// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}
var aead cipher.AEAD

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
switch p.Type {
case KeyType_AES256_GCM96:
// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

aead = gcm

case KeyType_ChaCha20_Poly1305:
cha, err := chacha20poly1305.New(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

aead = cha
}

if p.ConvergentEncryption {
switch p.ConvergentVersion {
case 1:
if len(nonce) != gcm.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", gcm.NonceSize())}
if len(nonce) != aead.NonceSize() {
return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())}
}
default:
nonceHmac := hmac.New(sha256.New, context)
nonceHmac.Write(plaintext)
nonceSum := nonceHmac.Sum(nil)
nonce = nonceSum[:gcm.NonceSize()]
nonce = nonceSum[:aead.NonceSize()]
}
} else {
// Compute random nonce
nonce, err = uuid.GenerateRandomBytes(gcm.NonceSize())
nonce, err = uuid.GenerateRandomBytes(aead.NonceSize())
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}
}

// Encrypt and tag with GCM
ciphertext = gcm.Seal(nil, nonce, plaintext, nil)
// Encrypt and tag with AEAD
ciphertext = aead.Seal(nil, nonce, plaintext, nil)

// Place the encrypted data after the nonce
if !p.ConvergentEncryption || p.ConvergentVersion > 1 {
Expand Down Expand Up @@ -736,25 +755,40 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
var plain []byte

switch p.Type {
case KeyType_AES256_GCM96:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
key, err := p.DeriveKey(context, ver)
if err != nil {
return "", err
}

// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}
var aead cipher.AEAD

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
switch p.Type {
case KeyType_AES256_GCM96:
// Setup the cipher
aesCipher, err := aes.NewCipher(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

aead = gcm

case KeyType_ChaCha20_Poly1305:
cha, err := chacha20poly1305.New(key)
if err != nil {
return "", errutil.InternalError{Err: err.Error()}
}

aead = cha
}

if len(decoded) < gcm.NonceSize() {
if len(decoded) < aead.NonceSize() {
return "", errutil.UserError{Err: "invalid ciphertext length"}
}

Expand All @@ -763,12 +797,12 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
if p.ConvergentEncryption && p.ConvergentVersion < 2 {
ciphertext = decoded
} else {
nonce = decoded[:gcm.NonceSize()]
ciphertext = decoded[gcm.NonceSize():]
nonce = decoded[:aead.NonceSize()]
ciphertext = decoded[aead.NonceSize():]
}

// Verify and Decrypt
plain, err = gcm.Open(nil, nonce, ciphertext, nil)
plain, err = aead.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"}
}
Expand Down Expand Up @@ -1040,7 +1074,7 @@ func (p *Policy) Rotate(ctx context.Context, storage logical.Storage) (retErr er
entry.HMACKey = hmacKey

switch p.Type {
case KeyType_AES256_GCM96:
case KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
// Generate a 256bit key
newKey, err := uuid.GenerateRandomBytes(32)
if err != nil {
Expand Down
Loading

0 comments on commit ef00a69

Please sign in to comment.