diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go new file mode 100644 index 000000000000..241da22b7bd2 --- /dev/null +++ b/builtin/logical/transit/backend.go @@ -0,0 +1,35 @@ +package transit + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(map[string]string) (logical.Backend, error) { + return Backend(), nil +} + +func Backend() *framework.Backend { + var b backend + b.Backend = &framework.Backend{ + PathsSpecial: &logical.Paths{ + Root: []string{ + "policy/*", + }, + }, + + Paths: []*framework.Path{ + pathPolicy(), + pathEncrypt(), + pathDecrypt(), + }, + + Secrets: []*framework.Secret{}, + } + + return b.Backend +} + +type backend struct { + *framework.Backend +} diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go new file mode 100644 index 000000000000..0d5deecc1209 --- /dev/null +++ b/builtin/logical/transit/backend_test.go @@ -0,0 +1,132 @@ +package transit + +import ( + "encoding/base64" + "fmt" + "testing" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/mitchellh/mapstructure" +) + +const ( + testPlaintext = "the quick brown fox" +) + +func TestBackend_basic(t *testing.T) { + decryptData := make(map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepWritePolicy(t, "test"), + testAccStepReadPolicy(t, "test", false), + testAccStepEncrypt(t, "test", testPlaintext, decryptData), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeletePolicy(t, "test"), + testAccStepReadPolicy(t, "test", true), + }, + }) +} + +func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "policy/" + name, + } +} + +func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: "policy/" + name, + } +} + +func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "policy/" + name, + Check: func(resp *logical.Response) error { + if resp == nil && !expectNone { + return fmt.Errorf("missing response") + } else if expectNone { + if resp != nil { + return fmt.Errorf("response when expecting none") + } + return nil + } + var d struct { + Name string `mapstructure:"name"` + Key []byte `mapstructure:"key"` + CipherMode string `mapstructure:"cipher_mode"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.Name != name { + return fmt.Errorf("bad: %#v", d) + } + if d.CipherMode != "aes-gcm" { + return fmt.Errorf("bad: %#v", d) + } + if len(d.Key) != 32 { + return fmt.Errorf("bad: %#v", d) + } + return nil + }, + } +} + +func testAccStepEncrypt( + t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "encrypt/" + name, + Data: map[string]interface{}{ + "plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)), + }, + Check: func(resp *logical.Response) error { + var d struct { + Ciphertext string `mapstructure:"ciphertext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Ciphertext == "" { + return fmt.Errorf("missing ciphertext") + } + decryptData["ciphertext"] = d.Ciphertext + return nil + }, + } +} + +func testAccStepDecrypt( + t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "decrypt/" + name, + Data: decryptData, + Check: func(resp *logical.Response) error { + var d struct { + Plaintext string `mapstructure:"plaintext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + // Decode the base64 + plainRaw, err := base64.StdEncoding.DecodeString(d.Plaintext) + if err != nil { + return err + } + + if string(plainRaw) != plaintext { + return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext) + } + return nil + }, + } +} diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go new file mode 100644 index 000000000000..bf0c5e41489e --- /dev/null +++ b/builtin/logical/transit/path_decrypt.go @@ -0,0 +1,100 @@ +package transit + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "strings" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathDecrypt() *framework.Path { + return &framework.Path{ + Pattern: `decrypt/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the policy", + }, + + "ciphertext": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Ciphertext value to decrypt", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathDecryptWrite, + }, + } +} + +func pathDecryptWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + value := d.Get("ciphertext").(string) + if len(value) == 0 { + return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest + } + + // Get the policy + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + + // Error if invalid policy + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + // Guard against a potentially invalid cipher-mode + switch p.CipherMode { + case "aes-gcm": + default: + return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest + } + + // Verify the prefix + if !strings.HasPrefix(value, "vault:v0:") { + return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest + } + + // Decode the base64 + decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:")) + if err != nil { + return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest + } + + // Setup the cipher + aesCipher, err := aes.NewCipher(p.Key) + if err != nil { + return nil, err + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, err + } + + // Extract the nonce and ciphertext + nonce := decoded[:gcm.NonceSize()] + ciphertext := decoded[gcm.NonceSize():] + + // Verify and Decrypt + plain, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest + } + + // Generate the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "plaintext": base64.StdEncoding.EncodeToString(plain), + }, + } + return resp, nil +} diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go new file mode 100644 index 000000000000..e7d23308787a --- /dev/null +++ b/builtin/logical/transit/path_encrypt.go @@ -0,0 +1,104 @@ +package transit + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathEncrypt() *framework.Path { + return &framework.Path{ + Pattern: `encrypt/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the policy", + }, + + "plaintext": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Plaintext value to encrypt", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathEncryptWrite, + }, + } +} + +func pathEncryptWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + value := d.Get("plaintext").(string) + if len(value) == 0 { + return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest + } + + // Decode the plaintext value + plaintext, err := base64.StdEncoding.DecodeString(value) + if err != nil { + return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest + } + + // Get the policy + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + + // Error if invalid policy + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + // Guard against a potentially invalid cipher-mode + switch p.CipherMode { + case "aes-gcm": + default: + return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest + } + + // Setup the cipher + aesCipher, err := aes.NewCipher(p.Key) + if err != nil { + return nil, err + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, err + } + + // Compute random nonce + nonce := make([]byte, gcm.NonceSize()) + _, err = rand.Read(nonce) + if err != nil { + return nil, err + } + + // Encrypt and tag with GCM + out := gcm.Seal(nil, nonce, plaintext, nil) + + // Place the encrypted data after the nonce + full := append(nonce, out...) + + // Convert to base64 + encoded := base64.StdEncoding.EncodeToString(full) + + // Prepend some information + encoded = "vault:v0:" + encoded + + // Generate the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "ciphertext": encoded, + }, + } + return resp, nil +} diff --git a/builtin/logical/transit/path_policy.go b/builtin/logical/transit/path_policy.go new file mode 100644 index 000000000000..ca11f553060e --- /dev/null +++ b/builtin/logical/transit/path_policy.go @@ -0,0 +1,140 @@ +package transit + +import ( + "crypto/rand" + "encoding/json" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// Policy is the struct used to store metadata +type Policy struct { + Name string `json:"name"` + Key []byte `json:"key"` + CipherMode string `json:"cipher"` +} + +func (p *Policy) Serialize() ([]byte, error) { + return json.Marshal(p) +} + +func DeserializePolicy(buf []byte) (*Policy, error) { + p := new(Policy) + if err := json.Unmarshal(buf, p); err != nil { + return nil, err + } + return p, nil +} + +func getPolicy(req *logical.Request, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := req.Storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + p, err := DeserializePolicy(raw.Value) + if err != nil { + return nil, err + } + return p, nil +} + +func pathPolicy() *framework.Path { + return &framework.Path{ + Pattern: `policy/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the policy", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathPolicyWrite, + logical.DeleteOperation: pathPolicyDelete, + logical.ReadOperation: pathPolicyRead, + }, + } +} + +func pathPolicyWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + // Check if the policy already exists + existing, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if existing != nil { + return nil, nil + } + + // Create the policy object + p := &Policy{ + Name: name, + CipherMode: "aes-gcm", + } + + // Generate a 256bit key + p.Key = make([]byte, 32) + _, err = rand.Read(p.Key) + if err != nil { + return nil, err + } + + // Encode the policy + buf, err := p.Serialize() + if err != nil { + return nil, err + } + + // Write the policy into storage + err = req.Storage.Put(&logical.StorageEntry{ + Key: "policy/" + name, + Value: buf, + }) + if err != nil { + return nil, err + } + return nil, nil +} + +func pathPolicyRead( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if p == nil { + return nil, nil + } + + // Return the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "key": p.Key, + "cipher_mode": p.CipherMode, + }, + } + return resp, nil +} + +func pathPolicyDelete( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + err := req.Storage.Delete("policy/" + name) + if err != nil { + return nil, err + } + return nil, nil +} diff --git a/cli/commands.go b/cli/commands.go index acf6b01976a9..73560369c7d9 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -10,6 +10,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/aws" "github.com/hashicorp/vault/builtin/logical/consul" + "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/audit" tokenDisk "github.com/hashicorp/vault/builtin/token/disk" @@ -51,8 +52,9 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { "github": credGitHub.Factory, }, LogicalBackends: map[string]logical.Factory{ - "aws": aws.Factory, - "consul": consul.Factory, + "aws": aws.Factory, + "consul": consul.Factory, + "transit": transit.Factory, }, }, nil },