Skip to content

Commit

Permalink
Tproxy writefile changes (#465)
Browse files Browse the repository at this point in the history
modify how we do writefile for acl token and proxy id file
Co-authored-by: Nitya Dhanushkodi <[email protected]>
Co-authored-by: Ashwin Venkatesh <[email protected]>
  • Loading branch information
kschoche authored Mar 25, 2021
1 parent a23abd1 commit 76d9730
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 12 deletions.
19 changes: 16 additions & 3 deletions subcommand/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,23 @@ func ConsulLogin(client *api.Client, bearerTokenFile, authMethodName, tokenSinkF
return fmt.Errorf("error logging in: %s", err)
}

// Write the token out to file with permissions so consul-k8s user can read.
payload := []byte(tok.SecretID)
if err := ioutil.WriteFile(tokenSinkFile, payload, 0444); err != nil {
if err := WriteFileWithPerms(tokenSinkFile, tok.SecretID, 0444); err != nil {
return fmt.Errorf("error writing token to file sink: %v", err)
}
return nil
}

// WriteFileWithPerms will write payload as the contents of the outputFile and set permissions after writing the contents. This function is necessary since using ioutil.WriteFile() alone will create the new file with the requested permissions prior to actually writing the file, so you can't set read-only permissions.
func WriteFileWithPerms(outputFile, payload string, mode os.FileMode) error {
// os.WriteFile truncates existing files and overwrites them, but only if they are writable.
// If the file exists it will already likely be read-only. Remove it first.
if _, err := os.Stat(outputFile); err == nil {
if err = os.Remove(outputFile); err != nil {
return fmt.Errorf("unable to delete existing file: %s", err)
}
}
if err := ioutil.WriteFile(outputFile, []byte(payload), os.ModePerm); err != nil {
return fmt.Errorf("unable to write file: %s", err)
}
return os.Chmod(outputFile, mode)
}
53 changes: 53 additions & 0 deletions subcommand/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"

"github.com/hashicorp/consul/api"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -107,6 +109,57 @@ func TestConsulLogin_TokenFileUnwritable(t *testing.T) {
require.Contains(err.Error(), "error writing token to file sink")
}

func TestWriteFileWithPerms_InvalidOutputFile(t *testing.T) {
t.Parallel()
rand.Seed(time.Now().UnixNano())
randFileName := fmt.Sprintf("/tmp/tmp/tmp/%d", rand.Int())
t.Cleanup(func() {
os.Remove(randFileName)
})
err := WriteFileWithPerms(randFileName, "", os.FileMode(0444))
require.Errorf(t, err, "unable to create file: %s", randFileName)
}

func TestWriteFileWithPerms_OutputFileExists(t *testing.T) {
t.Parallel()
rand.Seed(time.Now().UnixNano())
randFileName := fmt.Sprintf("/tmp/%d", rand.Int())
err := ioutil.WriteFile(randFileName, []byte("foo"), os.FileMode(0444))
require.NoError(t, err)
t.Cleanup(func() {
os.Remove(randFileName)
})
payload := "abcd"
err = WriteFileWithPerms(randFileName, payload, os.FileMode(0444))
require.NoError(t, err)
data, err := ioutil.ReadFile(randFileName)
require.NoError(t, err)
require.Equal(t, payload, string(data))
}

func TestWriteFileWithPerms(t *testing.T) {
t.Parallel()
payload := "foo-foo-foo-foo"
rand.Seed(time.Now().UnixNano())
randFileName := fmt.Sprintf("/tmp/%d", rand.Int())
t.Cleanup(func() {
os.Remove(randFileName)
})
// Issue the write.
mode := os.FileMode(0444)
err := WriteFileWithPerms(randFileName, payload, mode)
require.NoError(t, err)
file, err := os.Stat(randFileName)
require.NoError(t, err)
// Validate the size and mode are correct.
require.Equal(t, file.Mode(), mode)
require.Equal(t, file.Size(), int64(len(payload)))
// Validate the data was written correctly.
data, err := ioutil.ReadFile(randFileName)
require.NoError(t, err)
require.Equal(t, payload, string(data))
}

// startMockServer starts an httptest server used to mock a Consul server's
// /v1/acl/login endpoint. apiCallCounter will be incremented on each call to /v1/acl/login.
// It returns a consul client pointing at the server.
Expand Down
4 changes: 2 additions & 2 deletions subcommand/connect-init/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package connectinit
import (
"flag"
"fmt"
"io/ioutil"
"os"
"sync"
"time"

Expand Down Expand Up @@ -158,7 +158,7 @@ func (c *Command) Run(args []string) int {
return 1
}
// Write the proxy ID to the shared volume so `consul connect envoy` can use it for bootstrapping.
err = ioutil.WriteFile(c.proxyIDFile, []byte(proxyID), 0444)
err = common.WriteFileWithPerms(c.proxyIDFile, proxyID, os.FileMode(0444))
if err != nil {
c.UI.Error(fmt.Sprintf("Unable to write proxy ID to file: %s", err))
return 1
Expand Down
25 changes: 18 additions & 7 deletions subcommand/connect-init/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"

Expand Down Expand Up @@ -66,8 +67,12 @@ func TestRun_ServicePollingWithACLsAndTLS(t *testing.T) {
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
bearerFile := common.WriteTempFile(t, serviceAccountJWTToken)
proxyFile := common.WriteTempFile(t, "")
tokenFile := common.WriteTempFile(t, "")
tokenFile := fmt.Sprintf("/tmp/%d1", rand.Int())
proxyFile := fmt.Sprintf("/tmp/%d2", rand.Int())
t.Cleanup(func() {
os.Remove(proxyFile)
os.Remove(tokenFile)
})

var caFile, certFile, keyFile string
// Start Consul server with ACLs enabled and default deny policy.
Expand Down Expand Up @@ -205,7 +210,10 @@ func TestRun_ServicePollingOnly(t *testing.T) {
}
for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
proxyFile := common.WriteTempFile(t, "")
proxyFile := fmt.Sprintf("/tmp/%d", rand.Int())
t.Cleanup(func() {
os.Remove(proxyFile)
})

var caFile, certFile, keyFile string
// Start Consul server with TLS enabled if required.
Expand Down Expand Up @@ -422,7 +430,10 @@ func TestRun_ServicePollingErrors(t *testing.T) {

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
proxyFile := common.WriteTempFile(t, "")
proxyFile := fmt.Sprintf("/tmp/%d", rand.Int())
t.Cleanup(func() {
os.Remove(proxyFile)
})

// Start Consul server.
server, err := testutil.NewTestServerConfigT(t, nil)
Expand Down Expand Up @@ -528,7 +539,7 @@ func TestRun_InvalidProxyFile(t *testing.T) {
proxyIDFile: randFileName,
serviceRegistrationPollingAttempts: 3,
}
expErr := fmt.Sprintf("Unable to write proxy ID to file: open %s: no such file or directory\n", randFileName)
expErr := fmt.Sprintf("Unable to write proxy ID to file: unable to write file: open %s: no such file or directory\n", randFileName)
flags := []string{"-http-addr", server.HTTPAddr}
flags = append(flags, defaultTestFlags...)
code := cmd.Run(flags)
Expand Down Expand Up @@ -773,7 +784,7 @@ xtr5PSwH1DusYfVaGH2O
"Tags": [],
"Meta": {
"k8s-namespace": "default",
"pod-name": "counting"
"pod-name": "counting-pod"
},
"Port": 9001,
"Address": "10.32.3.26",
Expand Down Expand Up @@ -801,7 +812,7 @@ xtr5PSwH1DusYfVaGH2O
"Tags": [],
"Meta": {
"k8s-namespace": "default",
"pod-name": "counting"
"pod-name": "counting-pod"
},
"Port": 20000,
"Address": "10.32.3.26",
Expand Down

0 comments on commit 76d9730

Please sign in to comment.