Skip to content

Commit

Permalink
Check that all required fields in Transit API are present. (#14074)
Browse files Browse the repository at this point in the history
* Check that all required fields in Transit API are present.

* Check for missing plaintext/ciphertext in batched Transit operations.
  • Loading branch information
victorr authored Feb 22, 2022
1 parent 5759798 commit 395cd7b
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 17 deletions.
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
var batchInputItems []BatchRequestItem
var err error
if batchInputRaw != nil {
err = decodeBatchRequestItems(batchInputRaw, &batchInputItems)
err = decodeDecryptBatchRequestItems(batchInputRaw, &batchInputItems)
if err != nil {
return nil, fmt.Errorf("failed to parse batch input: %w", err)
}
Expand Down
26 changes: 21 additions & 5 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,19 @@ to the min_encryption_version configured on the key.`,
}
}

func decodeEncryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
return decodeBatchRequestItems(src, true, false, dst)
}

func decodeDecryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
return decodeBatchRequestItems(src, false, true, dst)
}

// decodeBatchRequestItems is a fast path alternative to mapstructure.Decode to decode []BatchRequestItem.
// It aims to behave as closely possible to the original mapstructure.Decode and will return the same errors.
// Note, however, that an error will also be returned if one of the required fields is missing.
// https://github.com/hashicorp/vault/pull/8775/files#r437709722
func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiphertext bool, dst *[]BatchRequestItem) error {
if src == nil || dst == nil {
return nil
}
Expand Down Expand Up @@ -173,15 +182,18 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' expected type 'string', got unconvertible type '%T'", i, item["ciphertext"]))
}
} else if requireCiphertext {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' missing ciphertext to decrypt", i))
}

// don't allow "null" to be passed in for the plaintext value
if v, has := item["plaintext"]; has {
if casted, ok := v.(string); ok {
(*dst)[i].Plaintext = casted
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' expected type 'string', got unconvertible type '%T'", i, item["plaintext"]))
}
} else if requirePlaintext {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' missing plaintext to encrypt", i))
}

if v, has := item["nonce"]; has {
Expand Down Expand Up @@ -240,7 +252,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
batchInputRaw := d.Raw["batch_input"]
var batchInputItems []BatchRequestItem
if batchInputRaw != nil {
err = decodeBatchRequestItems(batchInputRaw, &batchInputItems)
err = decodeEncryptBatchRequestItems(batchInputRaw, &batchInputItems)
if err != nil {
return nil, fmt.Errorf("failed to parse batch input: %w", err)
}
Expand All @@ -249,14 +261,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("missing batch input to process"), logical.ErrInvalidRequest
}
} else {
valueRaw, ok := d.GetOk("plaintext")
valueRaw, ok := d.Raw["plaintext"]
if !ok {
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}
plaintext, ok := valueRaw.(string)
if !ok {
return logical.ErrorResponse("expected plaintext of type 'string', got unconvertible type '%T'", valueRaw), logical.ErrInvalidRequest
}

batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Plaintext: valueRaw.(string),
Plaintext: plaintext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
Expand Down
120 changes: 113 additions & 7 deletions builtin/logical/transit/path_encrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,75 @@ import (
"github.com/mitchellh/mapstructure"
)

func TestTransit_MissingPlaintext(t *testing.T) {
var resp *logical.Response
var err error

b, s := createBackendWithStorage(t)

// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

encData := map[string]interface{}{
"plaintext": nil,
}

encReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "encrypt/existing_key",
Storage: s,
Data: encData,
}
resp, err = b.HandleRequest(context.Background(), encReq)
if resp == nil || !resp.IsError() {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
}

func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
var resp *logical.Response
var err error

b, s := createBackendWithStorage(t)

// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

batchInput := []interface{}{
map[string]interface{}{}, // Note that there is no map entry for plaintext
}

batchData := map[string]interface{}{
"batch_input": batchInput,
}
batchReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "encrypt/upserted_key",
Storage: s,
Data: batchData,
}
resp, err = b.HandleRequest(context.Background(), batchReq)
if err == nil {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
}

// Case1: Ensure that batch encryption did not affect the normal flow of
// encrypting the plaintext with a pre-existing key.
func TestTransit_BatchEncryptionCase1(t *testing.T) {
Expand Down Expand Up @@ -607,10 +676,12 @@ func TestTransit_BatchEncryptionCase13(t *testing.T) {
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.
func TestTransit_decodeBatchRequestItems(t *testing.T) {
tests := []struct {
name string
src interface{}
dest []BatchRequestItem
wantErrContains string
name string
src interface{}
requirePlaintext bool
requireCiphertext bool
dest []BatchRequestItem
wantErrContains string
}{
// basic edge cases of nil values
{name: "nil-nil", src: nil, dest: nil},
Expand Down Expand Up @@ -729,16 +800,51 @@ func TestTransit_decodeBatchRequestItems(t *testing.T) {
src: []interface{}{map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "nonce": "null"}},
dest: []BatchRequestItem{},
},
// required fields
{
name: "required_plaintext_present",
src: []interface{}{map[string]interface{}{"plaintext": ""}},
requirePlaintext: true,
dest: []BatchRequestItem{},
},
{
name: "required_plaintext_missing",
src: []interface{}{map[string]interface{}{}},
requirePlaintext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing plaintext",
},
{
name: "required_ciphertext_present",
src: []interface{}{map[string]interface{}{"ciphertext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}},
requireCiphertext: true,
dest: []BatchRequestItem{},
},
{
name: "required_ciphertext_missing",
src: []interface{}{map[string]interface{}{}},
requireCiphertext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing ciphertext",
},
{
name: "required_plaintext_and_ciphertext_missing",
src: []interface{}{map[string]interface{}{}},
requirePlaintext: true,
requireCiphertext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing ciphertext",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expectedDest := append(tt.dest[:0:0], tt.dest...) // copy of the dest state
expectedErr := mapstructure.Decode(tt.src, &expectedDest)
expectedErr := mapstructure.Decode(tt.src, &expectedDest) != nil || tt.wantErrContains != ""

gotErr := decodeBatchRequestItems(tt.src, &tt.dest)
gotErr := decodeBatchRequestItems(tt.src, tt.requirePlaintext, tt.requireCiphertext, &tt.dest)
gotDest := tt.dest

if expectedErr != nil {
if expectedErr {
if gotErr == nil {
t.Fatal("decodeBatchRequestItems unexpected error value; expected error but got none")
}
Expand Down
11 changes: 10 additions & 1 deletion builtin/logical/transit/path_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,16 @@ Defaults to "sha2-256".`,
}

func (b *backend) pathHashWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
inputB64 := d.Get("input").(string)
rawInput, ok := d.Raw["input"]
if !ok {
return logical.ErrorResponse("input missing"), logical.ErrInvalidRequest
}

inputB64, ok := rawInput.(string)
if !ok {
return logical.ErrorResponse("expected input of type 'string', got unconvertible type '%T'", rawInput), logical.ErrInvalidRequest
}

format := d.Get("format").(string)
algorithm := d.Get("urlalgorithm").(string)
if algorithm == "" {
Expand Down
6 changes: 5 additions & 1 deletion builtin/logical/transit/path_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestTransit_Hash(t *testing.T) {
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
t.Fatalf("bad: did not get error response: %#v", *resp)
}
return
}
Expand Down Expand Up @@ -86,6 +86,10 @@ func TestTransit_Hash(t *testing.T) {
doRequest(req, false, "98rFrYMEIqVAizamCmBiBoe+GAdlo+KJW8O9vYV8nggkbIMGTU42EvDLkn8+rSCEE6uYYkv3sGF68PA/YggJdg==")

// Test bad input/format/algorithm
req.Data["input"] = nil
doRequest(req, true, "")

req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
req.Data["format"] = "base92"
doRequest(req, true, "")

Expand Down
7 changes: 5 additions & 2 deletions builtin/logical/transit/path_trim.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
}
defer p.Unlock()

minAvailableVersionRaw, ok := d.GetOk("min_available_version")
minAvailableVersionRaw, ok := d.Raw["min_available_version"]
if !ok {
return logical.ErrorResponse("missing min_available_version"), nil
}
minAvailableVersion := minAvailableVersionRaw.(int)
minAvailableVersion, ok := minAvailableVersionRaw.(int)
if !ok {
return logical.ErrorResponse("expected min_available_version of type 'int', got unconvertible type '%T'", minAvailableVersionRaw), logical.ErrInvalidRequest
}

originalMinAvailableVersion := p.MinAvailableVersion

Expand Down
14 changes: 14 additions & 0 deletions builtin/logical/transit/path_trim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ func TestTransit_Trim(t *testing.T) {
}
doErrReq(t, req)

// Set min_encryption_version to 0
req.Path = "keys/aes/config"
req.Data = map[string]interface{}{
"min_encryption_version": 0,
}
doReq(t, req)

// Min available version should not be converted to 0 for nil values
req.Path = "keys/aes/trim"
req.Data = map[string]interface{}{
"min_available_version": nil,
}
doErrReq(t, req)

// Set min_encryption_version to 4
req.Path = "keys/aes/config"
req.Data = map[string]interface{}{
Expand Down
3 changes: 3 additions & 0 deletions changelog/14074.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
secrets/transit: Return an error if any required parameter is missing or nil. Do not encrypt nil plaintext as if it was an empty string.
```

0 comments on commit 395cd7b

Please sign in to comment.