Skip to content

Commit

Permalink
Refactor common.Login (#1101)
Browse files Browse the repository at this point in the history
* convert function args to a struct
* add some missing tests
* move logic that is only relevant for connect out
  • Loading branch information
ishustava authored Mar 22, 2022
1 parent 32d513d commit f8c7780
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 61 deletions.
16 changes: 11 additions & 5 deletions control-plane/subcommand/acl-init/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,26 @@ func (c *Command) Run(args []string) int {
return 1
}

meta := map[string]string{
"component": c.flagComponentName,
loginParams := common.LoginParams{
AuthMethod: c.flagACLAuthMethod,
Datacenter: c.flagPrimaryDatacenter,
BearerTokenFile: c.bearerTokenFile,
TokenSinkFile: c.flagTokenSinkFile,
Meta: map[string]string{
"component": c.flagComponentName,
},
}
secret, err = common.ConsulLogin(c.consulClient, cfg, c.flagACLAuthMethod, c.flagPrimaryDatacenter, "", c.bearerTokenFile, "", c.flagTokenSinkFile, meta, c.logger)
secret, err = common.ConsulLogin(c.consulClient, loginParams, c.logger)
if err != nil {
c.logger.Error("Consul login failed", "error", err)
return 1
}
c.logger.Info("Successfully read ACL token from the server")
} else {
// Use k8s secret to obtain token
// Use k8s secret to obtain token.

// Check if the client secret exists yet
// If not, wait until it does
// If not, wait until it does.
for {
var err error
secret, err = c.getSecret(c.flagSecretName)
Expand Down
74 changes: 42 additions & 32 deletions control-plane/subcommand/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/cenkalti/backoff"
"github.com/go-logr/logr"
"github.com/hashicorp/consul-k8s/control-plane/consul"
godiscover "github.com/hashicorp/consul-k8s/control-plane/helper/go-discover"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/go-discover"
Expand Down Expand Up @@ -86,71 +85,83 @@ func ValidateUnprivilegedPort(flagName, flagValue string) error {
return nil
}

// LoginParams are parameters used to log in to consul.
type LoginParams struct {
// AuthMethod is the name of the auth method.
AuthMethod string
// Datacenter is the datacenter for the login request.
Datacenter string
// Namespace is the namespace for the login request.
Namespace string
// BearerTokenFile is the file where the bearer token is stored.
BearerTokenFile string
// TokenSinkFile is the file where to write the token received from Consul.
TokenSinkFile string
// Meta is the metadata to set on the token.
Meta map[string]string

// numRetries is only used in tests to make them run faster.
numRetries uint64
}

// ConsulLogin issues an ACL().Login to Consul and writes out the token to tokenSinkFile.
// The logic of this is taken from the `consul login` command.
func ConsulLogin(client *api.Client, cfg *api.Config, authMethodName, datacenter, namespace, bearerTokenFile, serviceAccountName, tokenSinkFile string, meta map[string]string, log hclog.Logger) (string, error) {
func ConsulLogin(client *api.Client, params LoginParams, log hclog.Logger) (string, error) {
// Read the bearerTokenFile.
data, err := ioutil.ReadFile(bearerTokenFile)
data, err := ioutil.ReadFile(params.BearerTokenFile)
if err != nil {
return "", fmt.Errorf("unable to read bearerTokenFile: %v, err: %v", bearerTokenFile, err)
return "", fmt.Errorf("unable to read bearer token file: %v, err: %v", params.BearerTokenFile, err)
}
bearerToken := strings.TrimSpace(string(data))
if bearerToken == "" {
return "", fmt.Errorf("no bearer token found in %s", bearerTokenFile)
return "", fmt.Errorf("no bearer token found in %q", params.BearerTokenFile)
}

if params.numRetries == 0 {
params.numRetries = numLoginRetries
}
var token *api.ACLToken
err = backoff.Retry(func() error {
// Do the login.
req := &api.ACLLoginParams{
AuthMethod: authMethodName,
AuthMethod: params.AuthMethod,
BearerToken: bearerToken,
Meta: meta,
Meta: params.Meta,
}
// The datacenter flag will either have the value of the primary datacenter or "". In case of the latter,
// the token will be created in the datacenter of the installation. In case a global token is required,
// the token will be created in the primary datacenter.
tok, _, err := client.ACL().Login(req, &api.WriteOptions{Namespace: namespace, Datacenter: datacenter})
token, _, err = client.ACL().Login(req, &api.WriteOptions{Namespace: params.Namespace, Datacenter: params.Datacenter})
if err != nil {
log.Error("unable to login", "error", err)
return fmt.Errorf("error logging in: %s", err)
}
if tokenSinkFile != "" {
if params.TokenSinkFile != "" {
// Write out the resultant token file.
// Must be 0644 because this is written by the consul-k8s user but needs
// to be readable by the consul user
if err := WriteFileWithPerms(tokenSinkFile, tok.SecretID, 0644); err != nil {
if err = WriteFileWithPerms(params.TokenSinkFile, token.SecretID, 0644); err != nil {
return fmt.Errorf("error writing token to file sink: %v", err)
}
}
return err
}, backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), numLoginRetries))
}, backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), params.numRetries))
if err != nil {
if serviceAccountName == "default" {
log.Warn("The service account name for this Pod is \"default\"." +
" In default installations this is not a supported service account name." +
" The service account name must match the name of the Kubernetes Service" +
" or the consul.hashicorp.com/connect-service annotation.")
}
log.Error("Hit maximum retries for consul login", "error", err)
return "", err
}
// Now update the client so that it will read the ACL token we just fetched.
cfg.TokenFile = tokenSinkFile
client, err = consul.NewClient(cfg)
if err != nil {
log.Error("Unable to update client connection", "error", err)
return "", err
}

log.Info("Consul login complete")

// A workaround to check that the ACL token is replicated to other Consul servers.
//
// A consul client may reach out to a follower instead of a leader to resolve the token during the
// call to get services below. This is because clients talk to servers in the stale consistency mode
// A consul client may reach out to a follower instead of a leader to resolve the token for an API call
// with that token. This is because clients talk to servers in the stale consistency mode
// to decrease the load on the servers (see https://www.consul.io/docs/architecture/consensus#stale).
// In that case, it's possible that the token isn't replicated
// to that server instance yet. The client will then get an "ACL not found" error
// and subsequently cache this not found response. Then our call below
// to get services from the agent will keep hitting the same "ACL not found" error
// and subsequently cache this not found response. Then on any API call with the token,
// we will keep hitting the same "ACL not found" error
// until the cache entry expires (determined by the `acl_token_ttl` which defaults to 30 seconds).
// This is not great because it will delay app start up time by 30 seconds in most cases
// (if you are running 3 servers, then the probability of ending up on a follower is close to 2/3).
Expand All @@ -163,17 +174,16 @@ func ConsulLogin(client *api.Client, cfg *api.Config, authMethodName, datacenter
// Note though that this workaround does not eliminate this problem completely. It's still possible
// for this call and the next call to reach different servers and those servers to have different
// states from each other.
// For example, this call can reach a leader and succeed, while the call below can go to a follower
// For example, this call can reach a leader and succeed, while the next call can go to a follower
// that is still behind the leader and get an "ACL not found" error.
// However, this is a pretty unlikely case because
// clients have sticky connections to a server, and those connections get rebalanced only every 2-3min.
// And so, this workaround should work in a vast majority of cases.
log.Info("Checking that the ACL token exists when reading it in the stale consistency mode")
// Use raft timeout and polling interval to determine the number of retries.
numTokenReadRetries := uint64(raftReplicationTimeout.Milliseconds() / tokenReadPollingInterval.Milliseconds())
var aclLoginToken *api.ACLToken
err = backoff.Retry(func() error {
aclLoginToken, _, err = client.ACL().TokenReadSelf(&api.QueryOptions{AllowStale: true})
_, _, err = client.ACL().TokenReadSelf(&api.QueryOptions{AllowStale: true, Token: token.SecretID})
if err != nil {
log.Error("Unable to read ACL token; retrying", "err", err)
}
Expand All @@ -185,7 +195,7 @@ func ConsulLogin(client *api.Client, cfg *api.Config, authMethodName, datacenter
return "", err
}
log.Info("Successfully read ACL token from the server")
return aclLoginToken.SecretID, nil
return token.SecretID, nil
}

// WriteFileWithPerms will write payload as the contents of the outputFile and set permissions after writing the contents. This function is necessary since using ioutil.WriteFile() alone will create the new file with the requested permissions prior to actually writing the file, so you can't set read-only permissions.
Expand Down
145 changes: 122 additions & 23 deletions control-plane/subcommand/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,54 +54,154 @@ func TestValidateUnprivilegedPort(t *testing.T) {
// TestConsulLogin ensures that our implementation of consul login hits `/v1/acl/login`.
func TestConsulLogin(t *testing.T) {
t.Parallel()
require := require.New(t)

counter := 0
bearerTokenFile := WriteTempFile(t, "foo")
tokenFile := WriteTempFile(t, "")

// This is a common.Logger.
log, err := Logger("INFO", false)
require.NoError(err)
client, cfg := startMockServer(t, &counter)
_, err = ConsulLogin(client, cfg, testAuthMethod, "dc1", "", bearerTokenFile, "", tokenFile, testPodMeta, log)
require.NoError(err)
require.Equal(counter, 1)
require.NoError(t, err)
client := startMockServer(t)
params := LoginParams{
AuthMethod: testAuthMethod,
Datacenter: "dc1",
BearerTokenFile: bearerTokenFile,
TokenSinkFile: tokenFile,
}
_, err = ConsulLogin(client, params, log)
require.NoError(t, err)
// Validate that the token file was written to disk.
data, err := ioutil.ReadFile(tokenFile)
require.NoError(err)
require.Equal(string(data), "b78d37c7-0ca7-5f4d-99ee-6d9975ce4586")
require.NoError(t, err)
require.Equal(t, string(data), "b78d37c7-0ca7-5f4d-99ee-6d9975ce4586")
}

// TestConsulLogin_Retries tests we retry /v1/acl/login call if it fails.
func TestConsulLogin_Retries(t *testing.T) {
t.Parallel()

numLoginCalls := 0
bearerTokenFile := WriteTempFile(t, "foo")
tokenFile := WriteTempFile(t, "")

// This is a common.Logger.
log, err := Logger("INFO", false)
require.NoError(t, err)
// Start the Consul server.
consulServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Record all the API calls made.
if r != nil && r.URL.Path == "/v1/acl/login" && r.Method == "POST" {
if numLoginCalls == 0 {
w.WriteHeader(500)
} else {
w.Write([]byte(testLoginResponse))
}
numLoginCalls++
}
if r != nil && r.URL.Path == "/v1/acl/token/self" && r.Method == "GET" {
w.Write([]byte(testLoginResponse))
}
}))
t.Cleanup(consulServer.Close)

serverURL, err := url.Parse(consulServer.URL)
require.NoError(t, err)
clientConfig := &api.Config{Address: serverURL.String()}
client, err := api.NewClient(clientConfig)
require.NoError(t, err)
params := LoginParams{
AuthMethod: testAuthMethod,
Datacenter: "dc1",
BearerTokenFile: bearerTokenFile,
TokenSinkFile: tokenFile,
}
_, err = ConsulLogin(client, params, log)
require.NoError(t, err)
require.Equal(t, 2, numLoginCalls)
// Validate that the token file was written to disk.
data, err := ioutil.ReadFile(tokenFile)
require.NoError(t, err)
require.Equal(t, string(data), "b78d37c7-0ca7-5f4d-99ee-6d9975ce4586")
}

// TestConsulLogin_TokenNotReplicated tests that if we can't read the token in stale consistency mode
// we return an error.
func TestConsulLogin_TokenNotReplicated(t *testing.T) {
t.Parallel()

bearerTokenFile := WriteTempFile(t, "foo")
tokenFile := WriteTempFile(t, "")

// This is a common.Logger.
log, err := Logger("INFO", false)
require.NoError(t, err)
// Start the Consul server.
consulServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Record all the API calls made.
if r != nil && r.URL.Path == "/v1/acl/login" && r.Method == "POST" {
w.Write([]byte(testLoginResponse))
}
if r != nil && r.URL.Path == "/v1/acl/token/self" && r.Method == "GET" {
w.WriteHeader(500)
}
}))
t.Cleanup(consulServer.Close)

serverURL, err := url.Parse(consulServer.URL)
require.NoError(t, err)
clientConfig := &api.Config{Address: serverURL.String()}
client, err := api.NewClient(clientConfig)
require.NoError(t, err)
params := LoginParams{
AuthMethod: testAuthMethod,
Datacenter: "dc1",
BearerTokenFile: bearerTokenFile,
TokenSinkFile: tokenFile,
}
_, err = ConsulLogin(client, params, log)
require.EqualError(t, err, "Unexpected response code: 500 ()")
}

func TestConsulLogin_EmptyBearerTokenFile(t *testing.T) {
t.Parallel()
require := require.New(t)

bearerTokenFile := WriteTempFile(t, "")
_, err := ConsulLogin(nil, nil, testAuthMethod, "", "", bearerTokenFile, "", "", testPodMeta, hclog.NewNullLogger())
require.EqualError(err, fmt.Sprintf("no bearer token found in %s", bearerTokenFile))
params := LoginParams{
BearerTokenFile: bearerTokenFile,
}
_, err := ConsulLogin(nil, params, hclog.NewNullLogger())
require.EqualError(err, fmt.Sprintf("no bearer token found in %q", bearerTokenFile))
}

func TestConsulLogin_BearerTokenFileDoesNotExist(t *testing.T) {
t.Parallel()
require := require.New(t)
randFileName := fmt.Sprintf("/foo/%d/%d", rand.Int(), rand.Int())
_, err := ConsulLogin(nil, nil, testAuthMethod, "", "", randFileName, "", "", testPodMeta, hclog.NewNullLogger())
params := LoginParams{
BearerTokenFile: randFileName,
}
_, err := ConsulLogin(nil, params, hclog.NewNullLogger())
require.Error(err)
require.Contains(err.Error(), "unable to read bearerTokenFile")
require.Contains(err.Error(), "unable to read bearer token file")
}

func TestConsulLogin_TokenFileUnwritable(t *testing.T) {
t.Parallel()
require := require.New(t)
counter := 0
bearerTokenFile := WriteTempFile(t, "foo")
client, cfg := startMockServer(t, &counter)
client := startMockServer(t)
// This is a common.Logger.
log, err := Logger("INFO", false)
require.NoError(err)
randFileName := fmt.Sprintf("/foo/%d/%d", rand.Int(), rand.Int())
_, err = ConsulLogin(client, cfg, testAuthMethod, "", "", bearerTokenFile, "", randFileName, testPodMeta, log)
params := LoginParams{
AuthMethod: testAuthMethod,
BearerTokenFile: bearerTokenFile,
TokenSinkFile: randFileName,
numRetries: 2,
}
_, err = ConsulLogin(client, params, log)
require.Error(err)
require.Contains(err.Error(), "error writing token to file sink")
}
Expand Down Expand Up @@ -199,15 +299,16 @@ func TestGetResolvedServerAddresses(t *testing.T) {
// startMockServer starts an httptest server used to mock a Consul server's
// /v1/acl/login endpoint. apiCallCounter will be incremented on each call to /v1/acl/login.
// It returns a consul client pointing at the server.
func startMockServer(t *testing.T, apiCallCounter *int) (*api.Client, *api.Config) {

func startMockServer(t *testing.T) *api.Client {
// Start the Consul server.
consulServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Record all the API calls made.
if r != nil && r.URL.Path == "/v1/acl/login" && r.Method == "POST" {
*apiCallCounter++
w.Write([]byte(testLoginResponse))
}
if r != nil && r.URL.Path == "/v1/acl/token/self" && r.Method == "GET" {
w.Write([]byte(testLoginResponse))
}
w.Write([]byte(testLoginResponse))
}))
t.Cleanup(consulServer.Close)

Expand All @@ -217,7 +318,7 @@ func startMockServer(t *testing.T, apiCallCounter *int) (*api.Client, *api.Confi
client, err := api.NewClient(clientConfig)
require.NoError(t, err)

return client, clientConfig
return client
}

const testAuthMethod = "consul-k8s-auth-method"
Expand All @@ -243,5 +344,3 @@ const testLoginResponse = `{
"CreateIndex": 36,
"ModifyIndex": 36
}`

var testPodMeta = map[string]string{"pod": "default/podName"}
Loading

0 comments on commit f8c7780

Please sign in to comment.