Add possibility to compress state file when using S3 remote state backend.

Solves hashicorp#20328
SebastianCzoch committed Feb 13, 2019
1 parent 8f3ee18 commit eda63a0
Showing 5 changed files with 180 additions and 111 deletions.
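
Before the diff itself: the change threads a single boolean option through the backend schema, the Backend struct, and the RemoteClient, gzipping the state body on Put and transparently gunzipping (or passing through) on Get. As an illustrative sketch only, not part of the commit, the option can be exercised through the backend's existing test helpers roughly as follows; the import path is assumed from the repository layout at the time, and the bucket and key values are placeholders:

package s3

import (
	"testing"

	"github.com/hashicorp/terraform/backend"
)

// Illustrative sketch: the new "compression" flag is set like any other S3
// backend option and surfaces as the Backend.compression field. Bucket and
// key are placeholders; real runs need the same setup and credentials as the
// acceptance tests below.
func TestBackendConfig_compressionSketch(t *testing.T) {
	config := map[string]interface{}{
		"bucket":      "some-test-bucket",
		"key":         "state",
		"compression": true,
	}

	b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend)
	if !b.compression {
		t.Fatal("expected compression to be enabled on the backend")
	}
}
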
9 changes: 9 additions & 0 deletions backend/remote-state/s3/backend.go
@@ -226,6 +226,13 @@ func New() backend.Backend {
Description: "The maximum number of times an AWS API request is retried on retryable failure.",
Default: 5,
},

"compression": {
Type: schema.TypeBool,
Optional: true,
Description: "Enable gzip compression before sending sate file into bucket",
Default: false,
},
},
}

@@ -248,6 +255,7 @@ type Backend struct {
kmsKeyID string
ddbTable string
workspaceKeyPrefix string
compression bool
}

func (b *Backend) configure(ctx context.Context) error {
@@ -264,6 +272,7 @@ func (b *Backend) configure(ctx context.Context) error {
b.acl = data.Get("acl").(string)
b.kmsKeyID = data.Get("kms_key_id").(string)
b.workspaceKeyPrefix = data.Get("workspace_key_prefix").(string)
b.compression = data.Get("compression").(bool)

b.ddbTable = data.Get("dynamodb_table").(string)
if b.ddbTable == "" {
1 change: 1 addition & 0 deletions backend/remote-state/s3/backend_state.go
@@ -111,6 +111,7 @@ func (b *Backend) remoteClient(name string) (*RemoteClient, error) {
acl: b.acl,
kmsKeyID: b.kmsKeyID,
ddbTable: b.ddbTable,
compression: b.compression,
}

return client, nil
4 changes: 4 additions & 0 deletions backend/remote-state/s3/backend_test.go
@@ -40,6 +40,7 @@ func TestBackendConfig(t *testing.T) {
"key": "state",
"encrypt": true,
"dynamodb_table": "dynamoTable",
"compression": true,
}

b := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(config)).(*Backend)
@@ -53,6 +54,9 @@ func TestBackendConfig(t *testing.T) {
if b.keyName != "state" {
t.Fatalf("Incorrect keyName was populated")
}
if !b.compression {
t.Fatalf("Incorrect compression was populated")
}

credentials, err := b.s3Client.Config.Credentials.Get()
if err != nil {
47 changes: 43 additions & 4 deletions backend/remote-state/s3/client.go
@@ -2,12 +2,14 @@ package s3

import (
"bytes"
"compress/gzip"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"time"

@@ -36,6 +38,7 @@ type RemoteClient struct {
acl string
kmsKeyID string
ddbTable string
compression bool
}

var (
@@ -122,9 +125,14 @@ func (c *RemoteClient) get() (*remote.Payload, error) {
return nil, fmt.Errorf("Failed to read remote state: %s", err)
}

sum := md5.Sum(buf.Bytes())
data, err := c.decompress(buf.Bytes())
if err != nil {
return nil, err
}

sum := md5.Sum(data)
payload := &remote.Payload{
Data: buf.Bytes(),
Data: data,
MD5: sum[:],
}

@@ -138,12 +146,17 @@

func (c *RemoteClient) Put(data []byte) error {
contentType := "application/json"
contentLength := int64(len(data))
objectBody := data
if c.compression {
objectBody = c.compress(data)
contentType = "gzip"
}

contentLength := int64(len(objectBody))
i := &s3.PutObjectInput{
ContentType: &contentType,
ContentLength: &contentLength,
Body: bytes.NewReader(data),
Body: bytes.NewReader(objectBody),
Bucket: &c.bucketName,
Key: &c.path,
}
@@ -383,6 +396,32 @@ func (c *RemoteClient) lockPath() string {
return fmt.Sprintf("%s/%s", c.bucketName, c.path)
}

// decompress gunzips data; input that is not gzip-encoded is returned
// unchanged so that existing uncompressed state files remain readable.
func (c *RemoteClient) decompress(data []byte) ([]byte, error) {
if len(data) == 0 {
return data, nil
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
if err == gzip.ErrHeader {
// not gzipped data; fall back to the raw bytes
return data, nil
}

return nil, err
}
defer gz.Close()

return ioutil.ReadAll(gz)
}

// compress returns the gzip-encoded form of data.
func (c *RemoteClient) compress(data []byte) []byte {
b := &bytes.Buffer{}
gz := gzip.NewWriter(b)
gz.Write(data)
gz.Close()
return b.Bytes()
}
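
As a side note (an illustrative sketch, not part of the commit), the round-trip behaviour of these two helpers, including the gzip.ErrHeader pass-through for plain uncompressed state, can be reproduced with the standard library alone:

package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io/ioutil"
)

// compress mirrors the backend helper: gzip the payload into a buffer.
// Error returns are ignored here to mirror the original; production code
// should check them.
func compress(data []byte) []byte {
	b := &bytes.Buffer{}
	gz := gzip.NewWriter(b)
	gz.Write(data)
	gz.Close()
	return b.Bytes()
}

// decompress mirrors the backend helper: gunzip if possible, otherwise treat
// the input as already-uncompressed data.
func decompress(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return data, nil
	}
	gz, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		if err == gzip.ErrHeader {
			// not gzipped data; fall back to the raw bytes
			return data, nil
		}
		return nil, err
	}
	defer gz.Close()
	return ioutil.ReadAll(gz)
}

func main() {
	state := []byte(`{"version": 4, "serial": 1}`)

	// round trip through gzip
	out, err := decompress(compress(state))
	fmt.Println(bytes.Equal(out, state), err) // true <nil>

	// plain JSON passes through untouched
	out, err = decompress(state)
	fmt.Println(bytes.Equal(out, state), err) // true <nil>
}
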

const errBadChecksumFmt = `state data in S3 does not have the expected content.
This may be caused by unusually long delays in S3 processing a previous state
230 changes: 123 additions & 107 deletions backend/remote-state/s3/client_test.go
@@ -204,114 +204,130 @@ func TestRemoteClient_clientMD5(t *testing.T) {

// verify that a client won't return a state with an incorrect checksum.
func TestRemoteClient_stateChecksum(t *testing.T) {
testACC(t)

bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"

b1 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"dynamodb_table": bucketName,
})).(*Backend)

createS3Bucket(t, b1.s3Client, bucketName)
defer deleteS3Bucket(t, b1.s3Client, bucketName)
createDynamoDBTable(t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(t, b1.dynClient, bucketName)

s1, err := b1.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client1 := s1.(*remote.State).Client

// create an old and new state version to persist
s := state.TestStateInitial()
sf := &statefile.File{State: s}
var oldState bytes.Buffer
if err := statefile.Write(sf, &oldState); err != nil {
t.Fatal(err)
}
sf.Serial++
var newState bytes.Buffer
if err := statefile.Write(sf, &newState); err != nil {
t.Fatal(err)
}

// Use b2 without a dynamodb_table to bypass the lock table to write the state directly.
// client2 will write the "incorrect" state, simulating S3 eventual consistency delays
b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
})).(*Backend)
s2, err := b2.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client2 := s2.(*remote.State).Client

// write the new state through client2 so that there is no checksum yet
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
testCases := map[string]struct {
compression bool
}{
"without compression": {
compression: false,
},
"with compression": {
compression: true,
},
}

// verify that we can pull a state without a checksum
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}

// write the new state back with its checksum
if err := client1.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}

// put an empty state in place to check for panics during get
if err := client2.Put([]byte{}); err != nil {
t.Fatal(err)
}

// remove the timeouts so we can fail immediately
origTimeout := consistencyRetryTimeout
origInterval := consistencyRetryPollInterval
defer func() {
consistencyRetryTimeout = origTimeout
consistencyRetryPollInterval = origInterval
}()
consistencyRetryTimeout = 0
consistencyRetryPollInterval = 0

// fetching an empty state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}

// put the old state in place of the new, without updating the checksum
if err := client2.Put(oldState.Bytes()); err != nil {
t.Fatal(err)
}

// fetching the wrong state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}

// update the state with the correct one after we Get again
testChecksumHook = func() {
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
testChecksumHook = nil
}

consistencyRetryTimeout = origTimeout

// this final Get will initially fail the checksum verification; the above
// callback will update the state with the correct version, and Get should
// retry automatically.
if _, err := client1.Get(); err != nil {
t.Fatal(err)
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
testACC(t)

bucketName := fmt.Sprintf("terraform-remote-s3-test-%x", time.Now().Unix())
keyName := "testState"

b1 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
"dynamodb_table": bucketName,
"compression": tc.compression,
})).(*Backend)

createS3Bucket(t, b1.s3Client, bucketName)
defer deleteS3Bucket(t, b1.s3Client, bucketName)
createDynamoDBTable(t, b1.dynClient, bucketName)
defer deleteDynamoDBTable(t, b1.dynClient, bucketName)

s1, err := b1.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client1 := s1.(*remote.State).Client

// create an old and new state version to persist
s := state.TestStateInitial()
sf := &statefile.File{State: s}
var oldState bytes.Buffer
if err := statefile.Write(sf, &oldState); err != nil {
t.Fatal(err)
}
sf.Serial++
var newState bytes.Buffer
if err := statefile.Write(sf, &newState); err != nil {
t.Fatal(err)
}

// Use b2 without a dynamodb_table to bypass the lock table to write the state directly.
// client2 will write the "incorrect" state, simulating S3 eventual consistency delays
b2 := backend.TestBackendConfig(t, New(), backend.TestWrapConfig(map[string]interface{}{
"bucket": bucketName,
"key": keyName,
})).(*Backend)
s2, err := b2.StateMgr(backend.DefaultStateName)
if err != nil {
t.Fatal(err)
}
client2 := s2.(*remote.State).Client

// write the new state through client2 so that there is no checksum yet
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}

// verify that we can pull a state without a checksum
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}

// write the new state back with its checksum
if err := client1.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}

// put an empty state in place to check for panics during get
if err := client2.Put([]byte{}); err != nil {
t.Fatal(err)
}

// remove the timeouts so we can fail immediately
origTimeout := consistencyRetryTimeout
origInterval := consistencyRetryPollInterval
defer func() {
consistencyRetryTimeout = origTimeout
consistencyRetryPollInterval = origInterval
}()
consistencyRetryTimeout = 0
consistencyRetryPollInterval = 0

// fetching an empty state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}

// put the old state in place of the new, without updating the checksum
if err := client2.Put(oldState.Bytes()); err != nil {
t.Fatal(err)
}

// fetching the wrong state through client1 should now error out due to a
// mismatched checksum.
if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) {
t.Fatalf("expected state checksum error: got %s", err)
}

// update the state with the correct one after we Get again
testChecksumHook = func() {
if err := client2.Put(newState.Bytes()); err != nil {
t.Fatal(err)
}
testChecksumHook = nil
}

consistencyRetryTimeout = origTimeout

// this final Get will initially fail the checksum verification; the above
// callback will update the state with the correct version, and Get should
// retry automatically.
if _, err := client1.Get(); err != nil {
t.Fatal(err)
}
})
}
}
