Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding transit logical backend #12

Merged
merged 1 commit into from
Apr 16, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions builtin/logical/transit/backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package transit

import (
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

func Factory(map[string]string) (logical.Backend, error) {
return Backend(), nil
}

func Backend() *framework.Backend {
var b backend
b.Backend = &framework.Backend{
PathsSpecial: &logical.Paths{
Root: []string{
"policy/*",
},
},

Paths: []*framework.Path{
pathPolicy(),
pathEncrypt(),
pathDecrypt(),
},

Secrets: []*framework.Secret{},
}

return b.Backend
}

type backend struct {
*framework.Backend
}
132 changes: 132 additions & 0 deletions builtin/logical/transit/backend_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package transit

import (
"encoding/base64"
"fmt"
"testing"

"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
"github.com/mitchellh/mapstructure"
)

const (
testPlaintext = "the quick brown fox"
)

func TestBackend_basic(t *testing.T) {
decryptData := make(map[string]interface{})
logicaltest.Test(t, logicaltest.TestCase{
Backend: Backend(),
Steps: []logicaltest.TestStep{
testAccStepWritePolicy(t, "test"),
testAccStepReadPolicy(t, "test", false),
testAccStepEncrypt(t, "test", testPlaintext, decryptData),
testAccStepDecrypt(t, "test", testPlaintext, decryptData),
testAccStepDeletePolicy(t, "test"),
testAccStepReadPolicy(t, "test", true),
},
})
}

func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "policy/" + name,
}
}

func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.DeleteOperation,
Path: "policy/" + name,
}
}

func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ReadOperation,
Path: "policy/" + name,
Check: func(resp *logical.Response) error {
if resp == nil && !expectNone {
return fmt.Errorf("missing response")
} else if expectNone {
if resp != nil {
return fmt.Errorf("response when expecting none")
}
return nil
}
var d struct {
Name string `mapstructure:"name"`
Key []byte `mapstructure:"key"`
CipherMode string `mapstructure:"cipher_mode"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}

if d.Name != name {
return fmt.Errorf("bad: %#v", d)
}
if d.CipherMode != "aes-gcm" {
return fmt.Errorf("bad: %#v", d)
}
if len(d.Key) != 32 {
return fmt.Errorf("bad: %#v", d)
}
return nil
},
}
}

func testAccStepEncrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "encrypt/" + name,
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)),
},
Check: func(resp *logical.Response) error {
var d struct {
Ciphertext string `mapstructure:"ciphertext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}
if d.Ciphertext == "" {
return fmt.Errorf("missing ciphertext")
}
decryptData["ciphertext"] = d.Ciphertext
return nil
},
}
}

func testAccStepDecrypt(
t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.WriteOperation,
Path: "decrypt/" + name,
Data: decryptData,
Check: func(resp *logical.Response) error {
var d struct {
Plaintext string `mapstructure:"plaintext"`
}
if err := mapstructure.Decode(resp.Data, &d); err != nil {
return err
}

// Decode the base64
plainRaw, err := base64.StdEncoding.DecodeString(d.Plaintext)
if err != nil {
return err
}

if string(plainRaw) != plaintext {
return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext)
}
return nil
},
}
}
100 changes: 100 additions & 0 deletions builtin/logical/transit/path_decrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package transit

import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"strings"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

func pathDecrypt() *framework.Path {
return &framework.Path{
Pattern: `decrypt/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},

"ciphertext": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Ciphertext value to decrypt",
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathDecryptWrite,
},
}
}

func pathDecryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
value := d.Get("ciphertext").(string)
if len(value) == 0 {
return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest
}

// Get the policy
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}

// Error if invalid policy
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}

// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
default:
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
}

// Verify the prefix
if !strings.HasPrefix(value, "vault:v0:") {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}

// Decode the base64
decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:"))
if err != nil {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}

// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
if err != nil {
return nil, err
}

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, err
}

// Extract the nonce and ciphertext
nonce := decoded[:gcm.NonceSize()]
ciphertext := decoded[gcm.NonceSize():]

// Verify and Decrypt
plain, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest
}

// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"plaintext": base64.StdEncoding.EncodeToString(plain),
},
}
return resp, nil
}
104 changes: 104 additions & 0 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package transit

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"

"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)

func pathEncrypt() *framework.Path {
return &framework.Path{
Pattern: `encrypt/(?P<name>\w+)`,
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Name of the policy",
},

"plaintext": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Plaintext value to encrypt",
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
logical.WriteOperation: pathEncryptWrite,
},
}
}

func pathEncryptWrite(
req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
value := d.Get("plaintext").(string)
if len(value) == 0 {
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}

// Decode the plaintext value
plaintext, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest
}

// Get the policy
p, err := getPolicy(req, name)
if err != nil {
return nil, err
}

// Error if invalid policy
if p == nil {
return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest
}

// Guard against a potentially invalid cipher-mode
switch p.CipherMode {
case "aes-gcm":
default:
return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest
}

// Setup the cipher
aesCipher, err := aes.NewCipher(p.Key)
if err != nil {
return nil, err
}

// Setup the GCM AEAD
gcm, err := cipher.NewGCM(aesCipher)
if err != nil {
return nil, err
}

// Compute random nonce
nonce := make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return nil, err
}

// Encrypt and tag with GCM
out := gcm.Seal(nil, nonce, plaintext, nil)

// Place the encrypted data after the nonce
full := append(nonce, out...)

// Convert to base64
encoded := base64.StdEncoding.EncodeToString(full)

// Prepend some information
encoded = "vault:v0:" + encoded

// Generate the response
resp := &logical.Response{
Data: map[string]interface{}{
"ciphertext": encoded,
},
}
return resp, nil
}
Loading